mirror of
https://github.com/lancedb/lancedb.git
synced 2026-07-01 01:50:39 +00:00
Compare commits
45 Commits
python-v0.
...
jack/sopho
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bff911a65d | ||
|
|
3a4cdb7aff | ||
|
|
142ac835d3 | ||
|
|
3f44f93e92 | ||
|
|
9dfa43a9de | ||
|
|
03e895fa5c | ||
|
|
c31e53088e | ||
|
|
434a5be187 | ||
|
|
78aa005093 | ||
|
|
6191542cfe | ||
|
|
6af3088b91 | ||
|
|
e73d4618d8 | ||
|
|
3d92106394 | ||
|
|
5810974b37 | ||
|
|
8b38500b07 | ||
|
|
fd0a3b97d0 | ||
|
|
b9f33ba1c9 | ||
|
|
d4f4fef3ba | ||
|
|
fbe6a5a3fd | ||
|
|
127054069a | ||
|
|
b20931b8f7 | ||
|
|
396d68e490 | ||
|
|
ad37f87387 | ||
|
|
e93476f0e0 | ||
|
|
2b41fce033 | ||
|
|
04948fc4f6 | ||
|
|
ff3c7111b9 | ||
|
|
10fecdf051 | ||
|
|
c9ae93a7fa | ||
|
|
05756f0bbf | ||
|
|
2a0945443e | ||
|
|
39e819b6a7 | ||
|
|
70126943ff | ||
|
|
e01777070d | ||
|
|
3878adc6dc | ||
|
|
3df3043563 | ||
|
|
8a5cd74e48 | ||
|
|
448d5ec20f | ||
|
|
8718345229 | ||
|
|
026fedc286 | ||
|
|
fe287dc98c | ||
|
|
411568b72c | ||
|
|
ebf8d55ede | ||
|
|
0ba70d96c3 | ||
|
|
0749532c3c |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.31.0-beta.1"
|
||||
current_version = "0.31.0-beta.4"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
175
Cargo.lock
generated
175
Cargo.lock
generated
@@ -157,9 +157,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.102"
|
||||
version = "1.0.103"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
|
||||
checksum = "2a4385e2e34eb35d6b3efe798b9eb88096925d87726c0798709bf56d9ed84af3"
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
@@ -1297,15 +1297,6 @@ version = "2.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3"
|
||||
|
||||
[[package]]
|
||||
name = "bitpacking"
|
||||
version = "0.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96a7139abd3d9cebf8cd6f920a389cf3dc9576172e32f4563f188cae3c3eb019"
|
||||
dependencies = [
|
||||
"crunchy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitvec"
|
||||
version = "1.0.1"
|
||||
@@ -1759,7 +1750,7 @@ version = "3.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3023,7 +3014,7 @@ dependencies = [
|
||||
"libc",
|
||||
"option-ext",
|
||||
"redox_users",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3186,9 +3177,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "env_filter"
|
||||
version = "1.0.1"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef"
|
||||
checksum = "900d271a03799a1ee8d1ca9b19893b48ca674a9284fefcfb85f05e74ed314217"
|
||||
dependencies = [
|
||||
"log",
|
||||
"regex",
|
||||
@@ -3196,9 +3187,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.11.10"
|
||||
version = "0.11.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a"
|
||||
checksum = "de671bd27a75a797dc9ae289ba1e77276e75e2026408aab65185384e2d5cd3f6"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
@@ -3240,7 +3231,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3432,8 +3423,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||
|
||||
[[package]]
|
||||
name = "fsst"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"rand 0.9.4",
|
||||
@@ -4475,7 +4466,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4569,7 +4560,7 @@ dependencies = [
|
||||
"portable-atomic-util",
|
||||
"serde_core",
|
||||
"wasm-bindgen",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4735,8 +4726,8 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a"
|
||||
|
||||
[[package]]
|
||||
name = "lance"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"arrow",
|
||||
@@ -4754,7 +4745,6 @@ dependencies = [
|
||||
"async_cell",
|
||||
"aws-credential-types",
|
||||
"aws-sdk-dynamodb",
|
||||
"bitpacking",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -4771,8 +4761,9 @@ dependencies = [
|
||||
"futures",
|
||||
"half",
|
||||
"humantime",
|
||||
"itertools 0.13.0",
|
||||
"itertools 0.14.0",
|
||||
"lance-arrow",
|
||||
"lance-bitpacking",
|
||||
"lance-core",
|
||||
"lance-datafusion",
|
||||
"lance-encoding",
|
||||
@@ -4810,8 +4801,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-arrow"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4832,7 +4823,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "lance-arrow-scalar"
|
||||
version = "58.0.0"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4846,7 +4837,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "lance-arrow-stats"
|
||||
version = "58.0.0"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-schema",
|
||||
@@ -4855,18 +4846,19 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-bitpacking"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrayref",
|
||||
"crunchy",
|
||||
"paste",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lance-core"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4878,7 +4870,7 @@ dependencies = [
|
||||
"datafusion-common",
|
||||
"datafusion-sql",
|
||||
"futures",
|
||||
"itertools 0.13.0",
|
||||
"itertools 0.14.0",
|
||||
"lance-arrow",
|
||||
"lance-derive",
|
||||
"libc",
|
||||
@@ -4904,8 +4896,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datafusion"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4935,8 +4927,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datagen"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4953,8 +4945,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-derive"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -4963,8 +4955,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-encoding"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -4980,7 +4972,7 @@ dependencies = [
|
||||
"futures",
|
||||
"hex",
|
||||
"hyperloglogplus",
|
||||
"itertools 0.13.0",
|
||||
"itertools 0.14.0",
|
||||
"lance-arrow",
|
||||
"lance-bitpacking",
|
||||
"lance-core",
|
||||
@@ -4999,8 +4991,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-file"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -5030,8 +5022,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-index"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"arrow",
|
||||
@@ -5043,7 +5035,6 @@ dependencies = [
|
||||
"async-channel",
|
||||
"async-recursion",
|
||||
"async-trait",
|
||||
"bitpacking",
|
||||
"bitvec",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -5056,11 +5047,12 @@ dependencies = [
|
||||
"fst",
|
||||
"futures",
|
||||
"half",
|
||||
"itertools 0.13.0",
|
||||
"itertools 0.14.0",
|
||||
"jieba-rs",
|
||||
"jsonb",
|
||||
"lance-arrow",
|
||||
"lance-arrow-stats",
|
||||
"lance-bitpacking",
|
||||
"lance-core",
|
||||
"lance-datafusion",
|
||||
"lance-datagen",
|
||||
@@ -5096,8 +5088,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-io"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -5138,8 +5130,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-linalg"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -5155,8 +5147,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -5168,8 +5160,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace-impls"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-ipc",
|
||||
@@ -5223,15 +5215,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-select"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
"arrow-schema",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"itertools 0.13.0",
|
||||
"itertools 0.14.0",
|
||||
"lance-core",
|
||||
"roaring",
|
||||
"tracing",
|
||||
@@ -5239,8 +5231,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-table"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -5279,8 +5271,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-testing"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-schema",
|
||||
@@ -5293,8 +5285,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-tokenizer"
|
||||
version = "9.0.0-beta.2"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
dependencies = [
|
||||
"icu_segmenter",
|
||||
"jieba-rs",
|
||||
@@ -5307,7 +5299,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb"
|
||||
version = "0.31.0-beta.1"
|
||||
version = "0.31.0-beta.4"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"anyhow",
|
||||
@@ -5384,13 +5376,14 @@ dependencies = [
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"url",
|
||||
"urlencoding",
|
||||
"uuid",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-nodejs"
|
||||
version = "0.31.0-beta.1"
|
||||
version = "0.31.0-beta.4"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -5415,7 +5408,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-python"
|
||||
version = "0.34.0-beta.1"
|
||||
version = "0.34.0-beta.4"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -5648,9 +5641,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.32"
|
||||
version = "0.4.33"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "953f07c43838f8e6f9758cab68bf5bed85465e7587ebe0b823f1bcd81978ad3a"
|
||||
checksum = "0ceec5bc11778974d1bcb055b18002eba7f4b3518b6a0081b3af5f21666da9ad"
|
||||
|
||||
[[package]]
|
||||
name = "loom"
|
||||
@@ -5958,9 +5951,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "napi"
|
||||
version = "3.9.3"
|
||||
version = "3.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fbd9f9295f3ff5921e78a71222c3361a8216f7760b1a99a6ad4e8441de18bbb9"
|
||||
checksum = "b41bda2ac390efb5e8d22025d925ccc3f3807d8c1bea6d19b36127247c4b8f83"
|
||||
dependencies = [
|
||||
"bitflags 2.11.1",
|
||||
"chrono",
|
||||
@@ -5983,9 +5976,9 @@ checksum = "c9c366d2c8c60b86fa632df75f745509b52f9128f91a6bad4c796e44abb505e1"
|
||||
|
||||
[[package]]
|
||||
name = "napi-derive"
|
||||
version = "3.5.6"
|
||||
version = "3.5.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89b3f766e04667e6da0e181e2da4f85475d5a6513b7cf6a80bea184e224a5b42"
|
||||
checksum = "61d66f70256ad5aef58659966064471d0ad90e2897bc36a5a5e0389c85aabc1e"
|
||||
dependencies = [
|
||||
"convert_case",
|
||||
"ctor 1.0.5",
|
||||
@@ -5997,9 +5990,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "napi-derive-backend"
|
||||
version = "5.0.4"
|
||||
version = "5.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d5af30503edf933ce7377cf6d4c877a62b0f1107ea05585f1b5e430e88d5baf"
|
||||
checksum = "81b4b08f15eed7a2a20c3f4c6314013fc3ac890a3afa9892b594485299ebdb2d"
|
||||
dependencies = [
|
||||
"convert_case",
|
||||
"proc-macro2",
|
||||
@@ -6092,7 +6085,7 @@ version = "0.50.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7407,8 +7400,8 @@ version = "0.14.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
|
||||
dependencies = [
|
||||
"heck 0.5.0",
|
||||
"itertools 0.14.0",
|
||||
"heck 0.4.1",
|
||||
"itertools 0.11.0",
|
||||
"log",
|
||||
"multimap",
|
||||
"petgraph",
|
||||
@@ -7427,7 +7420,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"itertools 0.14.0",
|
||||
"itertools 0.11.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -7661,7 +7654,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"socket2 0.6.3",
|
||||
"tracing",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8401,7 +8394,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8472,7 +8465,7 @@ dependencies = [
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"webpki-root-certs",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -9034,7 +9027,7 @@ version = "0.8.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
|
||||
dependencies = [
|
||||
"heck 0.5.0",
|
||||
"heck 0.4.1",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -9046,7 +9039,7 @@ version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40"
|
||||
dependencies = [
|
||||
"heck 0.5.0",
|
||||
"heck 0.4.1",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -9479,7 +9472,7 @@ dependencies = [
|
||||
"getrandom 0.4.2",
|
||||
"once_cell",
|
||||
"rustix",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -10127,9 +10120,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.23.3"
|
||||
version = "1.23.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7"
|
||||
checksum = "bf80a72845275afea99e7f2b434723d3bc7e38470fcd1c7ed39a599c73319a53"
|
||||
dependencies = [
|
||||
"getrandom 0.4.2",
|
||||
"js-sys",
|
||||
@@ -10414,7 +10407,7 @@ version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
28
Cargo.toml
28
Cargo.toml
@@ -13,20 +13,20 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.91.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=9.0.0-beta.2", default-features = false, "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=9.0.0-beta.2", default-features = false, "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=9.0.0-beta.2", default-features = false, "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance = { "version" = "=9.0.0-beta.10", default-features = false, "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-core = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-datagen = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-file = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-io = { "version" = "=9.0.0-beta.10", default-features = false, "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-index = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-linalg = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-namespace = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=9.0.0-beta.10", default-features = false, "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-table = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-testing = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-datafusion = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-encoding = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-arrow = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "58.0.0", optional = false }
|
||||
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.31.0-beta.1</version>
|
||||
<version>0.31.0-beta.4</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
29
docs/src/js/enumerations/OAuthFlowType.md
Normal file
29
docs/src/js/enumerations/OAuthFlowType.md
Normal file
@@ -0,0 +1,29 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / OAuthFlowType
|
||||
|
||||
# Enumeration: OAuthFlowType
|
||||
|
||||
OAuth authentication flow types.
|
||||
|
||||
## Enumeration Members
|
||||
|
||||
### AzureManagedIdentity
|
||||
|
||||
```ts
|
||||
AzureManagedIdentity: "azure_managed_identity";
|
||||
```
|
||||
|
||||
Azure Managed Identity via IMDS.
|
||||
|
||||
***
|
||||
|
||||
### ClientCredentials
|
||||
|
||||
```ts
|
||||
ClientCredentials: "client_credentials";
|
||||
```
|
||||
|
||||
Client Credentials grant (service-to-service / M2M).
|
||||
@@ -12,6 +12,7 @@
|
||||
## Enumerations
|
||||
|
||||
- [FullTextQueryType](enumerations/FullTextQueryType.md)
|
||||
- [OAuthFlowType](enumerations/OAuthFlowType.md)
|
||||
- [Occur](enumerations/Occur.md)
|
||||
- [Operator](enumerations/Operator.md)
|
||||
|
||||
@@ -85,6 +86,8 @@
|
||||
- [ListNamespacesResponse](interfaces/ListNamespacesResponse.md)
|
||||
- [LsmWriteSpec](interfaces/LsmWriteSpec.md)
|
||||
- [MergeResult](interfaces/MergeResult.md)
|
||||
- [NativeOAuthConfig](interfaces/NativeOAuthConfig.md)
|
||||
- [OAuthConfig](interfaces/OAuthConfig.md)
|
||||
- [OpenTableOptions](interfaces/OpenTableOptions.md)
|
||||
- [OptimizeOptions](interfaces/OptimizeOptions.md)
|
||||
- [OptimizeStats](interfaces/OptimizeStats.md)
|
||||
|
||||
@@ -64,6 +64,19 @@ client used by manifest-enabled native connections.
|
||||
|
||||
***
|
||||
|
||||
### oauthConfig?
|
||||
|
||||
```ts
|
||||
optional oauthConfig: NativeOAuthConfig;
|
||||
```
|
||||
|
||||
(For LanceDB cloud only): OAuth configuration for IdP-based
|
||||
authentication (e.g., Azure Entra ID). When set, token acquisition
|
||||
and refresh are handled entirely in Rust. TypeScript users should pass
|
||||
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
|
||||
***
|
||||
|
||||
### readConsistencyInterval?
|
||||
|
||||
```ts
|
||||
|
||||
88
docs/src/js/interfaces/NativeOAuthConfig.md
Normal file
88
docs/src/js/interfaces/NativeOAuthConfig.md
Normal file
@@ -0,0 +1,88 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / NativeOAuthConfig
|
||||
|
||||
# Interface: NativeOAuthConfig
|
||||
|
||||
OAuth configuration for LanceDB authentication.
|
||||
|
||||
This is the generated napi-rs binding shape. TypeScript users should prefer
|
||||
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
|
||||
All token acquisition and refresh are handled in the Rust layer.
|
||||
|
||||
## Properties
|
||||
|
||||
### clientId
|
||||
|
||||
```ts
|
||||
clientId: string;
|
||||
```
|
||||
|
||||
Application / Client ID.
|
||||
|
||||
***
|
||||
|
||||
### clientSecret?
|
||||
|
||||
```ts
|
||||
optional clientSecret: string;
|
||||
```
|
||||
|
||||
Client secret (required for client_credentials).
|
||||
|
||||
***
|
||||
|
||||
### flow?
|
||||
|
||||
```ts
|
||||
optional flow: string;
|
||||
```
|
||||
|
||||
Authentication flow: "client_credentials" or "azure_managed_identity".
|
||||
|
||||
***
|
||||
|
||||
### issuerUrl
|
||||
|
||||
```ts
|
||||
issuerUrl: string;
|
||||
```
|
||||
|
||||
OIDC issuer URL or OAuth authority URL.
|
||||
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
|
||||
***
|
||||
|
||||
### managedIdentityClientId?
|
||||
|
||||
```ts
|
||||
optional managedIdentityClientId: string;
|
||||
```
|
||||
|
||||
Client ID for user-assigned managed identity (azure_managed_identity).
|
||||
|
||||
***
|
||||
|
||||
### refreshBufferSecs?
|
||||
|
||||
```ts
|
||||
optional refreshBufferSecs: number;
|
||||
```
|
||||
|
||||
Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
Keep this well below the token TTL; if it is greater than or equal to
|
||||
the TTL, each request refreshes the token.
|
||||
|
||||
***
|
||||
|
||||
### scopes
|
||||
|
||||
```ts
|
||||
scopes: string[];
|
||||
```
|
||||
|
||||
OAuth scopes to request. For Azure managed identity, exactly one scope
|
||||
or resource is required. For example: `["api://{app_id}/.default"]`
|
||||
111
docs/src/js/interfaces/OAuthConfig.md
Normal file
111
docs/src/js/interfaces/OAuthConfig.md
Normal file
@@ -0,0 +1,111 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / OAuthConfig
|
||||
|
||||
# Interface: OAuthConfig
|
||||
|
||||
OAuth configuration for LanceDB authentication.
|
||||
|
||||
This is the public TypeScript OAuth configuration type. The generated
|
||||
`NativeOAuthConfig` type has the same runtime shape but is an implementation
|
||||
detail of the napi-rs binding.
|
||||
|
||||
All token acquisition and refresh are handled in the Rust layer.
|
||||
This config is passed through to Rust via napi-rs.
|
||||
|
||||
## Examples
|
||||
|
||||
```typescript
|
||||
const config: OAuthConfig = {
|
||||
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
clientId: "app-id",
|
||||
clientSecret: "secret",
|
||||
scopes: ["api://lancedb-api/.default"],
|
||||
};
|
||||
```
|
||||
|
||||
```typescript
|
||||
const config: OAuthConfig = {
|
||||
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
clientId: "app-id",
|
||||
scopes: ["api://lancedb-api/.default"],
|
||||
flow: OAuthFlowType.AzureManagedIdentity,
|
||||
};
|
||||
```
|
||||
|
||||
## Properties
|
||||
|
||||
### clientId
|
||||
|
||||
```ts
|
||||
clientId: string;
|
||||
```
|
||||
|
||||
Application / Client ID.
|
||||
|
||||
***
|
||||
|
||||
### clientSecret?
|
||||
|
||||
```ts
|
||||
optional clientSecret: string;
|
||||
```
|
||||
|
||||
Client secret (required for ClientCredentials).
|
||||
|
||||
***
|
||||
|
||||
### flow?
|
||||
|
||||
```ts
|
||||
optional flow: OAuthFlowType;
|
||||
```
|
||||
|
||||
Authentication flow (default: ClientCredentials).
|
||||
|
||||
***
|
||||
|
||||
### issuerUrl
|
||||
|
||||
```ts
|
||||
issuerUrl: string;
|
||||
```
|
||||
|
||||
OIDC issuer URL or OAuth authority URL.
|
||||
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
|
||||
***
|
||||
|
||||
### managedIdentityClientId?
|
||||
|
||||
```ts
|
||||
optional managedIdentityClientId: string;
|
||||
```
|
||||
|
||||
Client ID for user-assigned managed identity (AzureManagedIdentity).
|
||||
|
||||
***
|
||||
|
||||
### refreshBufferSecs?
|
||||
|
||||
```ts
|
||||
optional refreshBufferSecs: number;
|
||||
```
|
||||
|
||||
Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
Keep this well below the token TTL; if it is greater than or equal to
|
||||
the TTL, each request refreshes the token.
|
||||
|
||||
***
|
||||
|
||||
### scopes
|
||||
|
||||
```ts
|
||||
scopes: string[];
|
||||
```
|
||||
|
||||
OAuth scopes to request.
|
||||
For Azure managed identity, exactly one scope or resource is required.
|
||||
For example: `["api://{app_id}/.default"]`
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.31.0-beta.1</version>
|
||||
<version>0.31.0-beta.4</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.31.0-beta.1</version>
|
||||
<version>0.31.0-beta.4</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
@@ -28,7 +28,7 @@
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<arrow.version>15.0.0</arrow.version>
|
||||
<lance-core.version>9.0.0-beta.2</lance-core.version>
|
||||
<lance-core.version>9.0.0-beta.10</lance-core.version>
|
||||
<spotless.skip>false</spotless.skip>
|
||||
<spotless.version>2.30.0</spotless.version>
|
||||
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.31.0-beta.1"
|
||||
version = "0.31.0-beta.4"
|
||||
publish = false
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
|
||||
@@ -52,6 +52,7 @@ export {
|
||||
SplitHashOptions,
|
||||
SplitSequentialOptions,
|
||||
ShuffleOptions,
|
||||
OAuthConfig as NativeOAuthConfig,
|
||||
} from "./native.js";
|
||||
|
||||
export {
|
||||
@@ -130,6 +131,8 @@ export {
|
||||
TokenResponse,
|
||||
} from "./header";
|
||||
|
||||
export { OAuthConfig, OAuthFlowType } from "./oauth";
|
||||
|
||||
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
|
||||
|
||||
export * as embedding from "./embedding";
|
||||
|
||||
76
nodejs/lancedb/oauth.ts
Normal file
76
nodejs/lancedb/oauth.ts
Normal file
@@ -0,0 +1,76 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
/**
 * Authentication flows that can be selected through `OAuthConfig.flow`.
 *
 * Each member's value is the string form of the flow name.
 */
export enum OAuthFlowType {
  /** OAuth Client Credentials grant, for service-to-service (M2M) access. */
  ClientCredentials = "client_credentials",
  /** Azure Managed Identity, with tokens obtained via IMDS. */
  AzureManagedIdentity = "azure_managed_identity",
}
|
||||
|
||||
/**
 * OAuth settings used to authenticate against LanceDB.
 *
 * This is the public, hand-written TypeScript type. The generated
 * `NativeOAuthConfig` type has the same runtime shape, but it is an
 * implementation detail of the napi-rs binding.
 *
 * The object is handed to Rust through napi-rs; acquiring and refreshing
 * tokens happens entirely in the Rust layer.
 *
 * @example Client Credentials (service-to-service):
 * ```typescript
 * const config: OAuthConfig = {
 *   issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
 *   clientId: "app-id",
 *   clientSecret: "secret",
 *   scopes: ["api://lancedb-api/.default"],
 * };
 * ```
 *
 * @example Azure Managed Identity:
 * ```typescript
 * const config: OAuthConfig = {
 *   issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
 *   clientId: "app-id",
 *   scopes: ["api://lancedb-api/.default"],
 *   flow: OAuthFlowType.AzureManagedIdentity,
 * };
 * ```
 */
export interface OAuthConfig {
  /**
   * URL of the OIDC issuer / OAuth authority.
   * For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
   */
  issuerUrl: string;

  /** Application (client) ID. */
  clientId: string;

  /**
   * Scopes to request from the identity provider.
   * Azure managed identity requires exactly one scope or resource,
   * for example `["api://{app_id}/.default"]`.
   */
  scopes: string[];

  /** Which authentication flow to use (default: ClientCredentials). */
  flow?: OAuthFlowType;

  /** Client secret; required by the ClientCredentials flow. */
  clientSecret?: string;

  /** Client ID of a user-assigned managed identity (AzureManagedIdentity). */
  managedIdentityClientId?: string;

  /**
   * How many seconds before expiry a proactive refresh is triggered
   * (default: 300). Keep it well below the token TTL: a value greater than
   * or equal to the TTL makes every request refresh the token.
   */
  refreshBufferSecs?: number;
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.31.0-beta.1",
|
||||
"version": "0.31.0-beta.4",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.31.0-beta.1",
|
||||
"version": "0.31.0-beta.4",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.31.0-beta.1",
|
||||
"version": "0.31.0-beta.4",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.31.0-beta.1",
|
||||
"version": "0.31.0-beta.4",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.31.0-beta.1",
|
||||
"version": "0.31.0-beta.4",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.31.0-beta.1",
|
||||
"version": "0.31.0-beta.4",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.31.0-beta.1",
|
||||
"version": "0.31.0-beta.4",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.31.0-beta.1",
|
||||
"version": "0.31.0-beta.4",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.31.0-beta.1",
|
||||
"version": "0.31.0-beta.4",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.31.0-beta.1",
|
||||
"version": "0.31.0-beta.4",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -112,6 +112,12 @@ impl Connection {
|
||||
|
||||
builder = builder.client_config(rust_config);
|
||||
|
||||
if let Some(oauth_config) = options.oauth_config {
|
||||
let config: lancedb::remote::oauth::OAuthConfig =
|
||||
oauth_config.try_into().default_error()?;
|
||||
builder = builder.oauth_config(config);
|
||||
}
|
||||
|
||||
if let Some(api_key) = options.api_key {
|
||||
builder = builder.api_key(&api_key);
|
||||
}
|
||||
|
||||
@@ -65,6 +65,11 @@ pub struct ConnectionOptions {
|
||||
/// (For LanceDB cloud only): the host to use for LanceDB cloud. Used
|
||||
/// for testing purposes.
|
||||
pub host_override: Option<String>,
|
||||
/// (For LanceDB cloud only): OAuth configuration for IdP-based
|
||||
/// authentication (e.g., Azure Entra ID). When set, token acquisition
|
||||
/// and refresh are handled entirely in Rust. TypeScript users should pass
|
||||
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
pub oauth_config: Option<remote::OAuthConfig>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use lancedb::error::Error;
|
||||
use napi_derive::*;
|
||||
|
||||
/// Timeout configuration for remote HTTP client.
|
||||
@@ -140,6 +141,84 @@ impl From<TlsConfig> for lancedb::remote::TlsConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// OAuth configuration for LanceDB authentication.
|
||||
///
|
||||
/// This is the generated napi-rs binding shape. TypeScript users should prefer
|
||||
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
///
|
||||
/// All token acquisition and refresh is handled in the Rust layer.
|
||||
#[napi(object)]
|
||||
#[derive(Clone)]
|
||||
pub struct OAuthConfig {
|
||||
/// OIDC issuer URL or OAuth authority URL.
|
||||
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
pub issuer_url: String,
|
||||
/// Application / Client ID.
|
||||
pub client_id: String,
|
||||
/// OAuth scopes to request. For Azure managed identity, exactly one scope
|
||||
/// or resource is required. For example: `["api://{app_id}/.default"]`
|
||||
pub scopes: Vec<String>,
|
||||
/// Authentication flow: "client_credentials" or "azure_managed_identity"
|
||||
pub flow: Option<String>,
|
||||
/// Client secret (required for client_credentials).
|
||||
pub client_secret: Option<String>,
|
||||
/// Client ID for user-assigned managed identity (azure_managed_identity).
|
||||
pub managed_identity_client_id: Option<String>,
|
||||
/// Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
/// Keep this well below the token TTL; if it is greater than or equal to
|
||||
/// the TTL, each request refreshes the token.
|
||||
pub refresh_buffer_secs: Option<u32>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OAuthConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OAuthConfig")
|
||||
.field("issuer_url", &self.issuer_url)
|
||||
.field("client_id", &self.client_id)
|
||||
.field("scopes", &self.scopes)
|
||||
.field("flow", &self.flow)
|
||||
.field(
|
||||
"client_secret",
|
||||
&self.client_secret.as_deref().map(|_| "<redacted>"),
|
||||
)
|
||||
.field(
|
||||
"managed_identity_client_id",
|
||||
&self.managed_identity_client_id,
|
||||
)
|
||||
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<OAuthConfig> for lancedb::remote::oauth::OAuthConfig {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(config: OAuthConfig) -> Result<Self, Self::Error> {
|
||||
use lancedb::remote::oauth::OAuthFlow;
|
||||
|
||||
let flow = match config.flow.as_deref().unwrap_or("client_credentials") {
|
||||
"client_credentials" => OAuthFlow::ClientCredentials,
|
||||
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
|
||||
client_id: config.managed_identity_client_id,
|
||||
},
|
||||
other => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!("Unknown OAuth flow type: {other}"),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
issuer_url: config.issuer_url,
|
||||
client_id: config.client_id,
|
||||
client_secret: config.client_secret,
|
||||
scopes: config.scopes,
|
||||
flow,
|
||||
refresh_buffer_secs: config.refresh_buffer_secs.map(|v| v as u64),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
fn from(config: ClientConfig) -> Self {
|
||||
Self {
|
||||
@@ -156,3 +235,45 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a binding-level config in which only `flow` and `client_secret`
    /// vary between tests.
    fn sample_config(flow: &str, client_secret: Option<&str>) -> OAuthConfig {
        OAuthConfig {
            issuer_url: "https://issuer.example.com".to_string(),
            client_id: "client-id".to_string(),
            scopes: vec!["scope".to_string()],
            flow: Some(flow.to_string()),
            client_secret: client_secret.map(str::to_string),
            managed_identity_client_id: None,
            refresh_buffer_secs: None,
        }
    }

    /// An unrecognised flow name is rejected with `InvalidInput`.
    #[test]
    fn test_unknown_oauth_flow_returns_invalid_input() {
        let outcome = lancedb::remote::oauth::OAuthConfig::try_from(sample_config("typo", None));

        assert!(matches!(
            outcome,
            Err(Error::InvalidInput { message })
                if message == "Unknown OAuth flow type: typo"
        ));
    }

    /// The `Debug` output must never contain the secret itself.
    #[test]
    fn test_oauth_config_debug_redacts_client_secret() {
        let config = sample_config("client_credentials", Some("super-secret"));

        let rendered = format!("{config:?}");

        assert!(!rendered.contains("super-secret"));
        assert!(rendered.contains("client_secret: Some(\"<redacted>\")"));
    }
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.34.0-beta.2"
|
||||
current_version = "0.34.0-beta.4"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.34.0-beta.2"
|
||||
version = "0.34.0-beta.4"
|
||||
publish = false
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
|
||||
@@ -17,6 +17,17 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .remote import ClientConfig
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .expr import Expr, col, lit, func
|
||||
from .udf import (
|
||||
udf,
|
||||
table_udf,
|
||||
Udf,
|
||||
JobHandle,
|
||||
JobFailedError,
|
||||
MaterializedView,
|
||||
AsyncJobHandle,
|
||||
AsyncMaterializedView,
|
||||
)
|
||||
from .lineage import Lineage, Node, Edge, FunctionRef
|
||||
from .schema import vector
|
||||
from .table import AsyncTable, Table
|
||||
from ._lancedb import Session
|
||||
@@ -89,6 +100,8 @@ def connect(
|
||||
If presented, connect to LanceDB cloud.
|
||||
Otherwise, connect to a database on file system or cloud storage.
|
||||
Can be set via environment variable `LANCEDB_API_KEY`.
|
||||
OAuth configuration is currently supported only by ``connect_async``;
|
||||
synchronous LanceDB Cloud connections require an API key.
|
||||
region: str, default "us-east-1"
|
||||
The region to use for LanceDB Cloud.
|
||||
host_override: str, optional
|
||||
@@ -340,6 +353,7 @@ async def connect_async(
|
||||
session: Optional[Session] = None,
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
oauth_config=None,
|
||||
) -> AsyncConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
@@ -389,6 +403,10 @@ async def connect_async(
|
||||
namespace_client_properties : dict, optional
|
||||
Additional directory namespace client properties to use with
|
||||
``manifest_enabled=True``.
|
||||
oauth_config : OAuthConfig, optional
|
||||
OAuth configuration for LanceDB Cloud/Enterprise. This is supported by
|
||||
``connect_async`` only; synchronous ``connect`` uses API key
|
||||
authentication for ``db://`` URIs.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -435,11 +453,24 @@ async def connect_async(
|
||||
session,
|
||||
manifest_enabled,
|
||||
namespace_client_properties,
|
||||
oauth_config,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"udf",
|
||||
"table_udf",
|
||||
"Udf",
|
||||
"JobHandle",
|
||||
"JobFailedError",
|
||||
"MaterializedView",
|
||||
"AsyncJobHandle",
|
||||
"AsyncMaterializedView",
|
||||
"Lineage",
|
||||
"Node",
|
||||
"Edge",
|
||||
"FunctionRef",
|
||||
"connect",
|
||||
"connect_async",
|
||||
"connect_namespace",
|
||||
|
||||
@@ -280,6 +280,7 @@ async def connect(
|
||||
session: Optional[Session],
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
oauth_config: Optional[Any] = None,
|
||||
) -> Connection: ...
|
||||
|
||||
class RecordBatchStream:
|
||||
|
||||
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from ._lancedb import Session
|
||||
from .udf import MaterializedView, AsyncMaterializedView
|
||||
|
||||
from .namespace_utils import (
|
||||
_normalize_create_namespace_mode,
|
||||
@@ -562,6 +563,259 @@ class DBConnection(EnforceOverrides):
|
||||
"""
|
||||
raise NotImplementedError("serialize is not supported for this connection type")
|
||||
|
||||
# -- Derived compute: functions, materialized views, jobs -------------
|
||||
# Server-backed features (LanceDB Enterprise / Cloud); local
|
||||
# connections raise NotImplementedError for now.
|
||||
|
||||
def create_function(
    self,
    name,
    language: str = "python",
    return_type: Optional[str] = None,
    body: Optional[str] = None,
    options: Optional[Dict[str, str]] = None,
    *,
    replace: bool = False,
):
    """Register a UDF (CREATE FUNCTION).

    Pass a ``@udf`` / ``@table_udf``-decorated function (preferred):

        db.create_function(embed)

    or the explicit fields described below.

    Parameters
    ----------
    name: str or Udf
        A decorated UDF object, or the function name. When a ``Udf`` is
        given, ``language``, ``return_type``, ``body`` and ``options`` are
        taken from it and the corresponding arguments are ignored.
    language: str
        Implementation language (currently "python").
    return_type: str
        SQL return type, e.g. "FLOAT", "FLOAT[1536]",
        "STRUCT(a FLOAT, b VARCHAR)", "TABLE(chunk VARCHAR, idx INT)".
    body: str
        Function body: source text, or base64 cloudpickle bytes when
        options["body_format"] == "cloudpickle".
    options: dict, optional
        input_columns, pip, num_gpus, batch_size, timeout,
        error_policy, docker_image, body_format, ...
    replace: bool
        Drop an existing function of the same name first. A missing
        function is not an error; any other failure of the drop is raised.
    """
    from .udf import Udf

    if isinstance(name, Udf):
        req = name.create_request()
        name, language, return_type, body, options = (
            req["name"],
            req["language"],
            req["return_type"],
            req["body"],
            req["options"],
        )
    if replace:
        # "Drop if present": only a not-found error is benign. Anything else
        # (permissions, server fault) must surface instead of being masked by
        # a later create failure -- same policy as `_drop_view_if_exists`.
        # NOTE(review): assumes a missing function is reported with the same
        # "not found" / "does not exist" wording as a missing view -- confirm.
        try:
            self.drop_function(name)
        except Exception as e:
            msg = str(e).lower()
            if "not found" not in msg and "does not exist" not in msg:
                raise
    LOOP.run(self._conn.create_function(name, language, return_type, body, options))
|
||||
|
||||
def list_functions(self):
    """Return the registered functions (SHOW FUNCTIONS)."""
    pending = self._conn.list_functions()
    return LOOP.run(pending)
|
||||
|
||||
def drop_function(self, name: str):
    """Remove the registered function called ``name`` (DROP FUNCTION)."""
    pending = self._conn.drop_function(name)
    LOOP.run(pending)
|
||||
|
||||
def create_materialized_view(
    self,
    name: str,
    source=None,
    select=None,
    *,
    query: Optional[str] = None,
    where: Optional[str] = None,
    auto_refresh: bool = False,
    with_no_data: bool = False,
    replace: bool = False,
    partition_by: Optional[str] = None,
) -> "MaterializedView":
    """Create a materialized view (CREATE MATERIALIZED VIEW).

    Returns a `MaterializedView` handle (``.wait()`` blocks until it is
    populated).

    The view body can be given in two ways:

    - ergonomic: pass ``source`` (a table name or table) and ``select``
      items -- column names, expression strings ("embed(body)"),
      (alias, expression) tuples, or ``@udf`` / ``@table_udf`` objects.
      The SELECT is assembled and parsed server-side (one parser, shared
      with SQL).
    - raw: pass ``query=`` with a full SELECT, e.g.
      "SELECT id, embed(body) AS vec FROM articles WHERE id > 1".

    `partition_by` partitions the view's (single) table function on a source
    column. If that column has an IVF vector index the server partitions by
    its index clusters (image-dedup style); otherwise it groups by distinct
    value. (Geneva's `partition_by` and `partition_by_indexed_column` unify
    here -- the engine picks the strategy from the column.)

    With ``replace=True`` an existing view of the same name is dropped first.

    Raises
    ------
    ValueError
        If ``query`` is omitted and ``source`` or ``select`` is missing.
    """
    from .udf import build_view_query, MaterializedView

    view_sql = query
    if view_sql is None:
        if source is None or select is None:
            raise ValueError(
                "create_materialized_view needs either query= or both "
                "source and select"
            )
        view_sql = build_view_query(source, select)
        # NOTE(review): `where=` is applied only to the source/select form
        # here -- confirm it is not also meant to be appended to a raw query=.
        if where:
            view_sql += f" WHERE {where}"

    if replace:
        self._drop_view_if_exists(name)

    # NOTE(review): if `self._conn` is the AsyncConnection (as the comment in
    # `lineage` says), this coroutine resolves to an AsyncMaterializedView
    # handle rather than a bare job id -- confirm MaterializedView accepts it.
    submission = self._conn.create_materialized_view(
        name,
        query=view_sql,
        auto_refresh=auto_refresh,
        with_no_data=with_no_data,
        partition_by=partition_by,
    )
    job_id = LOOP.run(submission)
    return MaterializedView(self, name, job_id=job_id)
|
||||
|
||||
def _drop_view_if_exists(self, name: str) -> None:
    """Drop the materialized view ``name``, tolerating only "not found".

    This is the "drop if present" step of ``replace=True``. Any error whose
    message does not say the view is missing (permissions, server fault) is
    re-raised so it is not masked by a later create failure.
    """
    benign_markers = ("not found", "does not exist")
    try:
        self.drop_materialized_view(name)
    except Exception as err:
        text = str(err).lower()
        if not any(marker in text for marker in benign_markers):
            raise
|
||||
|
||||
def job(self, job_id: str):
    """Return a `JobHandle` bound to this connection and ``job_id``.

    Use it to reconnect to an inflight job by id -- e.g. an id you stored,
    or one returned from the SQL / REST surface. Submit methods
    (`refresh_column`, `MaterializedView.refresh`) already return a handle
    directly, so this is not needed to wait on a fresh submission.
    """
    from .udf import JobHandle

    handle = JobHandle(self, job_id)
    return handle
|
||||
|
||||
def lineage(
    self,
    table: str,
    column: Optional[str] = None,
    *,
    direction: Optional[str] = None,
    depth: Optional[int] = None,
):
    """Derived-compute lineage of a table/view, or of one of its columns.

    Covers upstream sources, downstream dependents, and the function
    version + location that produced each derived column (with a drift
    flag). Returns a `Lineage`.

    `direction` is "upstream" | "downstream" | "both" (server default
    both); `depth` limits column-hops (transitive when omitted).
    """
    # `self._conn` is the AsyncConnection; its async `lineage` (which parses
    # the JSON) is driven on the loop, mirroring create_materialized_view.
    pending = self._conn.lineage(table, column, direction=direction, depth=depth)
    return LOOP.run(pending)
|
||||
|
||||
def _refresh_materialized_view(
    self,
    name: str,
    *,
    full: bool = False,
    src_version: Optional[int] = None,
    num_workers: Optional[int] = None,
    max_workers: Optional[int] = None,
) -> str:
    """Internal: submit a materialized-view refresh and return the job id.

    The public entry point is ``MaterializedView.refresh()`` (which returns
    a `JobHandle`); this method stays private so that refresh is only
    reached through the handle.

    ``full=True`` forces a full rebuild (recompute and replace every row)
    instead of the default incremental refresh.
    """
    refresh_kwargs = {
        "full": full,
        "src_version": src_version,
        "num_workers": num_workers,
        "max_workers": max_workers,
    }
    pending = self._conn._refresh_materialized_view(name, **refresh_kwargs)
    return LOOP.run(pending)
|
||||
|
||||
def explain_refresh_materialized_view(
    self,
    name: str,
    *,
    full: bool = False,
    src_version: Optional[int] = None,
):
    """Plan a refresh without running it (EXPLAIN REFRESH).

    Returns a plan with .has_work / .source_version /
    .last_refreshed_version / .full_refresh / .rebuild / .units_total.
    `full=True` plans a full rebuild (incremental planning needs stable row
    IDs on the source).
    """
    pending = self._conn.explain_refresh_materialized_view(
        name, full=full, src_version=src_version
    )
    return LOOP.run(pending)
|
||||
|
||||
def alter_materialized_view(self, name: str, *, auto_refresh: bool):
    """Change a materialized view's options (ALTER MATERIALIZED VIEW)."""
    pending = self._conn.alter_materialized_view(name, auto_refresh=auto_refresh)
    LOOP.run(pending)
|
||||
|
||||
def drop_materialized_view(self, name: str):
    """Remove a materialized view definition (DROP MATERIALIZED VIEW)."""
    pending = self._conn.drop_materialized_view(name)
    LOOP.run(pending)
|
||||
|
||||
def list_materialized_views(self):
    """Return the registered materialized view definitions."""
    pending = self._conn.list_materialized_views()
    return LOOP.run(pending)
|
||||
|
||||
def list_jobs(self):
    """Return the inflight server-side jobs across the database's tables."""
    pending = self._conn.list_jobs()
    return LOOP.run(pending)
|
||||
|
||||
def get_job(self, job_id: str, table: "str | None" = None):
    """Fetch one server-side job by id (the wait()/status poll path).

    Supplying ``table`` (the job's table) lets the server answer with an
    O(1) single-node read instead of scanning the database's active jobs.

    Returns the job's status, or None if it's unknown or no longer active.
    """
    pending = self._conn.get_job(job_id, table)
    return LOOP.run(pending)
|
||||
|
||||
def cancel_job(self, job_id: str) -> bool:
    """Request cancellation of an inflight server-side job (CANCEL JOB).

    Returns True if a matching inflight job was found and flagged for
    cancellation, False if none was inflight (already finished or unknown
    id). Cancellation is best-effort.
    """
    pending = self._conn.cancel_job(job_id)
    return LOOP.run(pending)
|
||||
|
||||
def job_history(self, job_id: "str | None" = None):
    """Durable history of completed server-side jobs (SHOW JOB HISTORY).

    Pass ``job_id`` to narrow the result to a single job. Unlike
    :meth:`list_jobs` (live, inflight) these are the terminal records.
    """
    pending = self._conn.job_history(job_id)
    return LOOP.run(pending)
|
||||
|
||||
def errors(self, job_id: "str | None" = None, table: "str | None" = None):
    """Per-row UDF errors recorded by ``error_policy=skip`` (SHOW ERRORS).

    The result can optionally be filtered by ``job_id`` and/or ``table``.
    """
    pending = self._conn.errors(job_id, table)
    return LOOP.run(pending)
|
||||
|
||||
|
||||
class LanceDBConnection(DBConnection):
|
||||
"""
|
||||
@@ -1787,6 +2041,200 @@ class AsyncConnection(object):
|
||||
)
|
||||
return AsyncTable(table)
|
||||
|
||||
# -- Derived compute: functions, materialized views, jobs -------------
|
||||
# Server-backed features (LanceDB Enterprise / Cloud); local
|
||||
# connections raise NotImplementedError for now.
|
||||
|
||||
async def create_function(
    self,
    name,
    language: str = "python",
    return_type: Optional[str] = None,
    body: Optional[str] = None,
    options: Optional[Dict[str, str]] = None,
    *,
    replace: bool = False,
):
    """Register a UDF (CREATE FUNCTION).

    Accepts a ``@udf`` / ``@table_udf`` object (preferred) or the explicit
    (name, language, return_type, body, options). When a ``Udf`` is given,
    the other fields are taken from its create request.

    With ``replace=True`` an existing function of the same name is dropped
    first. A missing function is not an error; any other failure of the drop
    is raised.
    """
    from .udf import Udf

    if isinstance(name, Udf):
        req = name.create_request()
        name, language, return_type, body, options = (
            req["name"],
            req["language"],
            req["return_type"],
            req["body"],
            req["options"],
        )
    if replace:
        # "Drop if present": only a not-found error is benign. Anything else
        # (permissions, server fault) must surface instead of being masked by
        # a later create failure -- same policy as the materialized-view path.
        # NOTE(review): assumes a missing function is reported with the same
        # "not found" / "does not exist" wording as a missing view -- confirm.
        try:
            await self.drop_function(name)
        except Exception as e:
            msg = str(e).lower()
            if "not found" not in msg and "does not exist" not in msg:
                raise
    await self._inner.create_function(name, language, return_type, body, options)
|
||||
|
||||
async def list_functions(self):
    """Return the registered functions (SHOW FUNCTIONS)."""
    functions = await self._inner.list_functions()
    return functions
|
||||
|
||||
async def drop_function(self, name: str):
    """Remove the registered function called ``name`` (DROP FUNCTION)."""
    inner = self._inner
    await inner.drop_function(name)
|
||||
|
||||
async def create_materialized_view(
    self,
    name: str,
    source=None,
    select=None,
    *,
    query: Optional[str] = None,
    where: Optional[str] = None,
    auto_refresh: bool = False,
    with_no_data: bool = False,
    replace: bool = False,
    partition_by: Optional[str] = None,
) -> "AsyncMaterializedView":
    """Create a materialized view and return an `AsyncMaterializedView`.

    ``.wait()`` on the returned handle blocks until the view is populated.
    Pass either ``query=`` (a full SELECT) or ``source`` + ``select`` items;
    `partition_by` partitions the view's table function on a source column
    (index-cluster if the column is IVF-indexed, else distinct-value). See
    the sync method for the select grammar.

    With ``replace=True`` an existing view of the same name is dropped
    first; only a not-found error from that drop is tolerated.

    Raises
    ------
    ValueError
        If ``query`` is omitted and ``source`` or ``select`` is missing.
    """
    from .udf import build_view_query, AsyncMaterializedView

    view_sql = query
    if view_sql is None:
        if source is None or select is None:
            raise ValueError(
                "create_materialized_view needs either query= or both "
                "source and select"
            )
        view_sql = build_view_query(source, select)
        # NOTE(review): `where=` is applied only to the source/select form
        # here -- confirm it is not also meant to be appended to a raw query=.
        if where:
            view_sql += f" WHERE {where}"

    if replace:
        benign_markers = ("not found", "does not exist")
        try:
            await self.drop_materialized_view(name)
        except Exception as err:
            text = str(err).lower()
            # Anything other than "the view is missing" must surface.
            if not any(marker in text for marker in benign_markers):
                raise

    job_id = await self._inner.create_materialized_view(
        name,
        view_sql,
        auto_refresh=auto_refresh,
        with_no_data=with_no_data,
        partition_by=partition_by,
    )
    return AsyncMaterializedView(self, name, job_id=job_id)
|
||||
|
||||
def job(self, job_id: str):
    """Return an `AsyncJobHandle` bound to this connection and ``job_id``.

    Use it to re-attach to an existing inflight job (a stored id, or one
    from the SQL / REST surface). Submit methods already return a handle,
    so this is only needed for reconnecting.
    """
    from .udf import AsyncJobHandle

    handle = AsyncJobHandle(self, job_id)
    return handle
|
||||
|
||||
async def lineage(
    self,
    table: str,
    column: Optional[str] = None,
    *,
    direction: Optional[str] = None,
    depth: Optional[int] = None,
):
    """Fetch the derived-compute lineage of a table/view (or one column).

    The raw payload returned by the inner connection's ``table_lineage`` is
    parsed with ``Lineage.from_json``. See the sync `Connection.lineage`.
    Returns a `Lineage`.
    """
    from .lineage import Lineage

    payload = await self._inner.table_lineage(table, column, direction, depth)
    graph = Lineage.from_json(payload)
    return graph
|
||||
|
||||
async def _refresh_materialized_view(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
full: bool = False,
|
||||
src_version: Optional[int] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Internal: submit a refresh, return the job id. The public surface is
|
||||
``AsyncMaterializedView.refresh()`` (returns an `AsyncJobHandle`).
|
||||
|
||||
``full=True`` forces a full rebuild (recompute and replace every row)
|
||||
instead of the default incremental refresh.
|
||||
"""
|
||||
return await self._inner.refresh_materialized_view(
|
||||
name,
|
||||
full=full,
|
||||
src_version=src_version,
|
||||
num_workers=num_workers,
|
||||
max_workers=max_workers,
|
||||
)
|
||||
|
||||
async def explain_refresh_materialized_view(
    self,
    name: str,
    *,
    full: bool = False,
    src_version: Optional[int] = None,
):
    """Plan a refresh of ``name`` without executing it (EXPLAIN REFRESH).

    ``full`` and ``src_version`` are passed through to the inner
    connection; whatever it returns is returned as-is.
    """
    inner = self._inner
    plan = await inner.explain_refresh_materialized_view(
        name,
        full=full,
        src_version=src_version,
    )
    return plan
|
||||
|
||||
async def alter_materialized_view(self, name: str, *, auto_refresh: bool):
    """Update the options of materialized view ``name``.

    ``auto_refresh`` is forwarded positionally to the inner connection.
    Returns None.
    """
    inner = self._inner
    await inner.alter_materialized_view(name, auto_refresh)
|
||||
|
||||
async def drop_materialized_view(self, name: str):
    """Drop the materialized view definition called ``name``. Returns None."""
    inner = self._inner
    await inner.drop_materialized_view(name)
|
||||
|
||||
async def list_materialized_views(self):
    """Return the registered materialized view definitions, as reported by
    the inner connection."""
    views = await self._inner.list_materialized_views()
    return views
|
||||
|
||||
async def list_jobs(self):
    """Return the inflight server-side jobs across the database's tables,
    as reported by the inner connection."""
    jobs = await self._inner.list_jobs()
    return jobs
|
||||
|
||||
async def get_job(self, job_id: str, table: "str | None" = None):
    """Look up a single server-side job by id (the wait()/status poll path).

    Passing ``table`` (the job's table) enables an O(1) server-side lookup.
    Returns the job's status, or None when the job is unknown or no longer
    active.
    """
    status = await self._inner.get_job(job_id, table)
    return status
|
||||
|
||||
async def cancel_job(self, job_id: str) -> bool:
    """Request cancellation of an inflight server-side job (CANCEL JOB).

    Best-effort: returns True when a matching inflight job was found and
    flagged for cancellation, False otherwise.
    """
    flagged = await self._inner.cancel_job(job_id)
    return flagged
|
||||
|
||||
async def job_history(self, job_id: "str | None" = None):
    """Durable history of completed server-side jobs (SHOW JOB HISTORY).

    Reads each table's durable job-history store; ``job_id`` narrows the
    result to one job. In contrast to :meth:`list_jobs` (live, inflight
    jobs) these are terminal records carrying created/updated/completed
    timestamps.
    """
    records = await self._inner.job_history(job_id)
    return records
|
||||
|
||||
async def errors(self, job_id: "str | None" = None, table: "str | None" = None):
    """Per-row UDF errors recorded by ``error_policy=skip`` (SHOW ERRORS).

    The result can optionally be filtered by ``job_id`` and/or ``table``;
    both are forwarded positionally to the inner connection.
    """
    recorded = await self._inner.errors(job_id, table)
    return recorded
|
||||
|
||||
async def rename_table(
|
||||
self,
|
||||
cur_name: str,
|
||||
|
||||
@@ -81,6 +81,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
warnings.warn(
|
||||
"use_token_pooling is deprecated, use pooling_strategy=None instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.pooling_strategy = None
|
||||
|
||||
|
||||
177
python/python/lancedb/lineage.py
Normal file
177
python/python/lancedb/lineage.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
"""Client-side model of derived-compute lineage.
|
||||
|
||||
`Connection.lineage()` / `Table.lineage()` / `MaterializedView.lineage()` return
|
||||
a `Lineage`: the graph of what a column or materialized view derives from
|
||||
(upstream), what derives from it (downstream), and -- for each derived column --
|
||||
the function that produced it, the version it was produced with, and whether
|
||||
that is stale relative to the function the registry now holds.
|
||||
|
||||
The server returns this as JSON (the wire contract); these classes deserialize
|
||||
it. Nothing here talks to the server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
@dataclass
class FunctionRef:
    """Describes the function behind a derived column: its name, versions,
    and where its code/environment lives."""

    name: str
    #: Version stamped on the data at compute time, when known.
    as_computed_version: Optional[str] = None
    #: Version currently registered under this function name.
    current_version: Optional[str] = None
    #: True when the data was produced by an older function than the one the
    #: registry holds now -- silently stale; re-refresh to catch up.
    stale_vs_current: bool = False
    language: Optional[str] = None
    docker_image: Optional[str] = None
    env_digest: Optional[str] = None
    code_uri: Optional[str] = None

    @classmethod
    def _from(cls, d: dict) -> "FunctionRef":
        """Build a `FunctionRef` from its wire dict. ``name`` is required;
        every other key falls back to the field default when absent."""
        fn_name = d["name"]
        optional_keys = (
            "as_computed_version",
            "current_version",
            "language",
            "docker_image",
            "env_digest",
            "code_uri",
        )
        optional_values = {key: d.get(key) for key in optional_keys}
        return cls(
            name=fn_name,
            stale_vs_current=d.get("stale_vs_current", False),
            **optional_values,
        )
|
||||
|
||||
|
||||
@dataclass
class Node:
    """One vertex of a lineage graph: a table, view, column, or function."""

    kind: str  # "table" | "view" | "column" | "function"
    id: str  # "table", "table.column", or "fn:name@version"
    table: Optional[str] = None
    function: Optional[FunctionRef] = None

    @classmethod
    def _from(cls, d: dict) -> "Node":
        """Build a `Node` from its wire dict. ``kind`` and ``id`` are
        required; a truthy ``function`` entry is parsed into a
        `FunctionRef`, otherwise the field is None."""
        raw_function = d.get("function")
        kind = d["kind"]
        ident = d["id"]
        owner = d.get("table")
        produced_by = None
        if raw_function:
            produced_by = FunctionRef._from(raw_function)
        return cls(kind=kind, id=ident, table=owner, function=produced_by)
|
||||
|
||||
|
||||
@dataclass
class Edge:
    """A dependency: `downstream` is derived from `upstream`. `via` names
    the producing function, or is None for a passthrough."""

    downstream: str
    upstream: str
    via: Optional[str] = None

    @classmethod
    def _from(cls, d: dict) -> "Edge":
        """Build an `Edge` from its wire dict; ``downstream`` and
        ``upstream`` are required, ``via`` is optional."""
        derived = d["downstream"]
        source = d["upstream"]
        return cls(derived, source, d.get("via"))
|
||||
|
||||
|
||||
@dataclass
class Lineage:
    """A derived-compute lineage graph (nodes + labeled edges).

    ``target`` is the table/view/column the graph was requested for;
    ``nodes`` and ``edges`` hold the deserialized graph. The class only
    models and renders data -- it never contacts the server.
    """

    target: str
    nodes: List[Node] = field(default_factory=list)
    edges: List[Edge] = field(default_factory=list)

    @classmethod
    def from_json(cls, raw: Union[str, bytes, dict]) -> "Lineage":
        """Deserialize the server payload.

        ``raw`` may be JSON text (``str`` / ``bytes`` / ``bytearray``) or an
        already-decoded dict. Missing or ``null`` ``nodes`` / ``edges`` are
        treated as empty lists; a missing ``target`` becomes ``""``.
        """
        d = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw
        return cls(
            target=d.get("target", ""),
            # `or []` tolerates an explicit JSON null as well as a missing key.
            nodes=[Node._from(n) for n in d.get("nodes") or []],
            edges=[Edge._from(e) for e in d.get("edges") or []],
        )

    def functions(self) -> List[FunctionRef]:
        """The function references attached to nodes in the graph."""
        return [n.function for n in self.nodes if n.function is not None]

    def stale(self) -> List[FunctionRef]:
        """Functions whose as-computed version is behind the current registry
        version -- the columns they produced are silently out of date."""
        return [f for f in self.functions() if f.stale_vs_current]

    def to_dict(self) -> dict:
        """Plain-dict form of the graph; keys whose value is None are
        omitted from every node, function and edge entry."""

        def prune(d: dict) -> dict:
            return {k: v for k, v in d.items() if v is not None}

        return {
            "target": self.target,
            "nodes": [
                prune(
                    {
                        "kind": n.kind,
                        "id": n.id,
                        "table": n.table,
                        "function": prune(vars(n.function)) if n.function else None,
                    }
                )
                for n in self.nodes
            ],
            "edges": [prune(vars(e)) for e in self.edges],
        }

    def to_graphviz(self) -> str:
        """Graphviz DOT for the lineage DAG: columns/tables as nodes, function
        names on edges, drift edges dashed + red.

        Identifiers and labels are emitted as quoted DOT strings with ``"``
        and ``\\`` escaped, so names containing those characters still yield
        a well-formed document.
        """

        def quote(text) -> str:
            # Escape the backslash first so the quote escape is not doubled.
            escaped = str(text).replace("\\", "\\\\").replace('"', '\\"')
            return f'"{escaped}"'

        stale_names = {f.name for f in self.stale()}
        out = [
            "digraph lineage {",
            " rankdir=LR;",
            ' node [fontname="monospace"];',
        ]
        for n in self.nodes:
            # Functions are drawn as edge labels, not as vertices.
            if n.kind == "function":
                continue
            shape = "ellipse" if n.kind in ("table", "view") else "box"
            out.append(f" {quote(n.id)} [shape={shape}];")
        for e in self.edges:
            attrs = ""
            if e.via:
                if e.via in stale_names:
                    attrs = f" [label={quote(e.via)} color=red style=dashed]"
                else:
                    attrs = f" [label={quote(e.via)}]"
            out.append(f" {quote(e.upstream)} -> {quote(e.downstream)}{attrs};")
        out.append("}")
        return "\n".join(out)

    def _repr_html_(self) -> str:
        """HTML summary: a header, an optional staleness warning, and one
        table row per edge. All interpolated names are HTML-escaped."""
        from html import escape

        def esc(value) -> str:
            return escape(str(value))

        warn = ""
        drift = self.stale()
        if drift:
            names = ", ".join(sorted({f.name for f in drift}))
            warn = (
                f'<p style="color:#b00000"><b>stale vs current:</b> {esc(names)} '
                "(re-refresh to catch up)</p>"
            )
        rows = "".join(
            f"<tr><td><code>{esc(e.downstream)}</code></td>"
            f"<td>← {esc(e.via or '')}</td>"
            f"<td><code>{esc(e.upstream)}</code></td></tr>"
            for e in self.edges
        )
        return (
            f"<b>lineage: <code>{esc(self.target)}</code></b>{warn}"
            "<table><tr><th>derived</th><th>via</th><th>from</th></tr>"
            f"{rows}</table>"
        )
|
||||
@@ -373,6 +373,19 @@ def _convert_pyarrow_schema_to_json(schema: pa.Schema) -> JsonArrowSchema:
|
||||
return JsonArrowSchema(fields=fields, metadata=meta)
|
||||
|
||||
|
||||
def _builds_namespace_natively(
|
||||
namespace_client_impl: Optional[str],
|
||||
namespace_client_properties: Optional[Dict[str, str]],
|
||||
) -> bool:
|
||||
"""Whether ``connect_namespace_client`` builds the namespace client natively
|
||||
in Rust (installing the read-freshness context provider) rather than wrapping
|
||||
the pre-built Python client.
|
||||
|
||||
Must mirror Rust ``build_namespace_natively`` in ``python/src/connection.rs``.
|
||||
"""
|
||||
return namespace_client_impl == "rest" and bool(namespace_client_properties)
|
||||
|
||||
|
||||
class LanceNamespaceDBConnection(DBConnection):
|
||||
"""
|
||||
A LanceDB connection that uses a namespace for table management.
|
||||
@@ -432,6 +445,13 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
)
|
||||
self._namespace_client_impl = namespace_client_impl
|
||||
self._namespace_client_properties = namespace_client_properties
|
||||
# When the namespace client is built natively (see Rust
|
||||
# ``build_namespace_natively``), the underlying Rust table performs
|
||||
# QueryTable pushdown through the read-freshness context provider, which
|
||||
# the pure-Python ``query_table`` path bypasses.
|
||||
self._route_pushdown_to_rust = _builds_namespace_natively(
|
||||
namespace_client_impl, namespace_client_properties
|
||||
)
|
||||
self._inner = AsyncConnection(
|
||||
_connect_namespace_client(
|
||||
namespace_client,
|
||||
@@ -543,6 +563,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
namespace_path=namespace_path,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||
_async=async_table,
|
||||
)
|
||||
|
||||
@@ -580,6 +601,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
namespace_path=namespace_path,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||
_async=async_table,
|
||||
)
|
||||
if branch is not None:
|
||||
@@ -875,6 +897,8 @@ class AsyncLanceNamespaceDBConnection:
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
session: Optional[Session] = None,
|
||||
namespace_client_pushdown_operations: Optional[List[str]] = None,
|
||||
namespace_client_impl: Optional[str] = None,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize an async namespace-based LanceDB connection.
|
||||
@@ -900,6 +924,12 @@ class AsyncLanceNamespaceDBConnection:
|
||||
namespace.create_table() instead of using declare_table + local write.
|
||||
|
||||
Default is None (no pushdown, all operations run locally).
|
||||
namespace_client_impl : Optional[str]
|
||||
The namespace implementation name used to create this connection.
|
||||
Required (with ``namespace_client_properties``) for the Rust client to
|
||||
be built natively and install the read-freshness provider.
|
||||
namespace_client_properties : Optional[Dict[str, str]]
|
||||
The namespace properties used to create this connection.
|
||||
"""
|
||||
self._namespace_client = namespace_client
|
||||
self.read_consistency_interval = read_consistency_interval
|
||||
@@ -908,6 +938,14 @@ class AsyncLanceNamespaceDBConnection:
|
||||
self._namespace_client_pushdown_operations = set(
|
||||
namespace_client_pushdown_operations or []
|
||||
)
|
||||
self._namespace_client_impl = namespace_client_impl
|
||||
self._namespace_client_properties = namespace_client_properties
|
||||
# See LanceNamespaceDBConnection: when built natively the Rust table runs
|
||||
# QueryTable pushdown through the read-freshness provider, so defer to it
|
||||
# rather than the urllib3 client (which omits x-lancedb-min-timestamp).
|
||||
self._route_pushdown_to_rust = _builds_namespace_natively(
|
||||
namespace_client_impl, namespace_client_properties
|
||||
)
|
||||
self._inner = AsyncConnection(
|
||||
_connect_namespace_client(
|
||||
namespace_client,
|
||||
@@ -921,8 +959,8 @@ class AsyncLanceNamespaceDBConnection:
|
||||
namespace_client_pushdown_operations=(
|
||||
list(self._namespace_client_pushdown_operations)
|
||||
),
|
||||
namespace_client_impl=None,
|
||||
namespace_client_properties=None,
|
||||
namespace_client_impl=namespace_client_impl,
|
||||
namespace_client_properties=namespace_client_properties,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -992,6 +1030,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
namespace_path=namespace_path,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||
)
|
||||
|
||||
async def open_table(
|
||||
@@ -1029,6 +1068,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
namespace_path=namespace_path,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||
)
|
||||
|
||||
async def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
|
||||
@@ -1387,4 +1427,6 @@ def connect_namespace_async(
|
||||
storage_options=storage_options,
|
||||
session=session,
|
||||
namespace_client_pushdown_operations=namespace_client_pushdown_operations,
|
||||
namespace_client_impl=namespace_client_impl,
|
||||
namespace_client_properties=namespace_client_properties,
|
||||
)
|
||||
|
||||
@@ -48,6 +48,14 @@ class PermutationBuilder:
|
||||
By default, the permutation builder will create a single split that contains all
|
||||
rows in the same order as the base table.
|
||||
"""
|
||||
if not hasattr(table, "_inner"):
|
||||
raise TypeError(
|
||||
f"PermutationBuilder requires a local LanceTable, "
|
||||
f"got {type(table).__name__}. "
|
||||
"The permutation API is not supported on remote tables. "
|
||||
"Remote tables connect to LanceDB Cloud or Enterprise and do not have "
|
||||
"direct access to the underlying Lance dataset needed for permutations."
|
||||
)
|
||||
self._async = async_permutation_builder(table)
|
||||
|
||||
def split_random(
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import List, Optional
|
||||
from lancedb import __version__
|
||||
|
||||
from .header import HeaderProvider
|
||||
from .oauth import OAuthConfig, OAuthFlowType
|
||||
|
||||
__all__ = [
|
||||
"TimeoutConfig",
|
||||
@@ -16,6 +17,8 @@ __all__ = [
|
||||
"TlsConfig",
|
||||
"ClientConfig",
|
||||
"HeaderProvider",
|
||||
"OAuthConfig",
|
||||
"OAuthFlowType",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -124,6 +124,7 @@ class RemoteDBConnection(DBConnection):
|
||||
"request_thread_pool is no longer used and will be removed in "
|
||||
"a future release.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if connection_timeout is not None:
|
||||
@@ -132,6 +133,7 @@ class RemoteDBConnection(DBConnection):
|
||||
"release. Please use client_config.timeout_config.connect_timeout "
|
||||
"instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
client_config.timeout_config.connect_timeout = timedelta(
|
||||
seconds=connection_timeout
|
||||
@@ -142,6 +144,7 @@ class RemoteDBConnection(DBConnection):
|
||||
"read_timeout is deprecated and will be removed in a future release. "
|
||||
"Please use client_config.timeout_config.read_timeout instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)
|
||||
|
||||
|
||||
75
python/python/lancedb/remote/oauth.py
Normal file
75
python/python/lancedb/remote/oauth.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class OAuthFlowType(str, Enum):
    """The OAuth authentication flows that can be selected in `OAuthConfig`.

    Members are also plain strings, so they compare equal to their wire
    value.
    """

    CLIENT_CREDENTIALS = "client_credentials"
    """Client Credentials grant (service-to-service / M2M)."""

    AZURE_MANAGED_IDENTITY = "azure_managed_identity"
    """Azure Managed Identity via IMDS."""
|
||||
|
||||
|
||||
@dataclass
class OAuthConfig:
    """OAuth settings used to authenticate against LanceDB.

    This object only carries configuration: token acquisition and refresh
    are performed in the Rust layer, to which the config is handed over via
    PyO3.

    Parameters
    ----------
    issuer_url : str
        OIDC issuer URL or OAuth authority URL.
        For Azure: ``https://login.microsoftonline.com/{tenant_id}/v2.0``
    client_id : str
        Application / Client ID.
    scopes : List[str]
        OAuth scopes to request.
        For Azure managed identity, exactly one scope or resource is required.
        For example: ``["api://{app_id}/.default"]``
    flow : OAuthFlowType
        Authentication flow to use. Default: CLIENT_CREDENTIALS.
    client_secret : Optional[str]
        Client secret (required for CLIENT_CREDENTIALS). Excluded from the
        dataclass ``repr``.
    managed_identity_client_id : Optional[str]
        Client ID for user-assigned managed identity (AZURE_MANAGED_IDENTITY).
    refresh_buffer_secs : Optional[int]
        Seconds before expiry to trigger proactive refresh (default: 300).
        Keep this well below the token TTL; if it is greater than or equal to
        the TTL, each request refreshes the token.

    Examples
    --------
    Client Credentials (service-to-service):

    >>> config = OAuthConfig(
    ...     issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
    ...     client_id="app-id",
    ...     client_secret="secret",
    ...     scopes=["api://lancedb-api/.default"],
    ... )

    Azure Managed Identity:

    >>> config = OAuthConfig(
    ...     issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
    ...     client_id="app-id",
    ...     scopes=["api://lancedb-api/.default"],
    ...     flow=OAuthFlowType.AZURE_MANAGED_IDENTITY,
    ... )
    """

    issuer_url: str
    client_id: str
    scopes: List[str]
    flow: OAuthFlowType = OAuthFlowType.CLIENT_CREDENTIALS
    # repr=False keeps the secret out of printed/logged representations.
    client_secret: Optional[str] = field(default=None, repr=False)
    managed_identity_client_id: Optional[str] = None
    refresh_buffer_secs: Optional[int] = None
|
||||
@@ -13,10 +13,14 @@ from typing import (
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
Literal,
|
||||
overload,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..udf import JobHandle
|
||||
import warnings
|
||||
|
||||
from lancedb import __version__
|
||||
@@ -845,7 +849,8 @@ class RemoteTable(Table):
|
||||
"""
|
||||
warnings.warn(
|
||||
"cleanup_old_versions() is a no-op on LanceDB Cloud. "
|
||||
"Tables are automatically cleaned up and optimized."
|
||||
"Tables are automatically cleaned up and optimized.",
|
||||
stacklevel=2,
|
||||
)
|
||||
pass
|
||||
|
||||
@@ -857,7 +862,8 @@ class RemoteTable(Table):
|
||||
"""
|
||||
warnings.warn(
|
||||
"compact_files() is a no-op on LanceDB Cloud. "
|
||||
"Tables are automatically compacted and optimized."
|
||||
"Tables are automatically compacted and optimized.",
|
||||
stacklevel=2,
|
||||
)
|
||||
pass
|
||||
|
||||
@@ -874,15 +880,150 @@ class RemoteTable(Table):
|
||||
"""
|
||||
warnings.warn(
|
||||
"optimize() is a no-op on LanceDB Cloud. "
|
||||
"Indices are optimized automatically."
|
||||
"Indices are optimized automatically.",
|
||||
stacklevel=2,
|
||||
)
|
||||
pass
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
return LOOP.run(self._table.count_rows(filter))
|
||||
|
||||
def add_columns(self, transforms: Dict[str, str]) -> AddColumnsResult:
|
||||
return LOOP.run(self._table.add_columns(transforms))
|
||||
def add_columns(
    self,
    transforms: Optional[Dict[str, str]] = None,
    *,
    computed: Optional[Dict[str, tuple]] = None,
) -> Optional[AddColumnsResult]:
    """Add columns to the remote table.

    ``transforms`` (when not None) is forwarded to the underlying table's
    ``add_columns``; ``computed`` (when truthy) is forwarded in a second,
    separate call as ``computed=``. Returns the result of the ``transforms``
    call, or None when no transforms were given.
    """
    outcome = None
    if transforms is not None:
        outcome = LOOP.run(self._table.add_columns(transforms))
    if computed:
        pending = self._table.add_columns(computed=computed)
        LOOP.run(pending)
    return outcome
|
||||
|
||||
def refresh_column(
    self,
    columns,
    *,
    where: Optional[str] = None,
    num_workers: Optional[int] = None,
    max_workers: Optional[int] = None,
    batch_size: Optional[int] = None,
    priority: Optional[str] = None,
) -> "JobHandle":
    """Trigger a recompute of computed columns (REFRESH COLUMN).

    ``columns`` is a single column name or an iterable of names. The
    expression is resolved server-side from each column's stored binding;
    columns bound to the same struct-returning function refresh together.
    Returns a `JobHandle` to wait on, poll, or cancel
    (``tbl.refresh_column("c").wait()``). Server-backed feature
    (LanceDB Enterprise / Cloud).

    num_workers / max_workers / batch_size / priority are per-refresh
    scheduling knobs (how to run THIS refresh) and override any default
    the function carries. `priority` is a Kueue tier
    (training | interactive | backfill).
    """
    from ..udf import JobHandle

    # A lone name is wrapped so the server always receives a list.
    names = [columns] if isinstance(columns, str) else list(columns)
    pending = self._table.refresh_column(
        names,
        where=where,
        num_workers=num_workers,
        max_workers=max_workers,
        batch_size=batch_size,
        priority=priority,
    )
    job_id = LOOP.run(pending)
    return JobHandle(self._job_conn(), job_id)
|
||||
|
||||
def lineage(self, column=None, *, direction=None, depth=None):
    """Derived-compute lineage of this table, or of one of its columns.

    Covers upstream sources, downstream dependents, and the function
    version + location that produced each derived column (with a drift
    flag). Delegates to the job connection's ``lineage`` with this table's
    name. Returns a `Lineage`. See `Connection.lineage`.
    """
    conn = self._job_conn()
    return conn.lineage(self._name, column, direction=direction, depth=depth)
|
||||
|
||||
def _job_conn(self):
    """Return a client connection used to poll jobs spawned by this table.

    The connection is built on first use from the table's serialized
    connection state and then cached on the instance (not pickled -- a
    forked/unpickled table rebuilds it on next use).
    """
    from lancedb import deserialize_conn

    cached = getattr(self, "_job_conn_cache", None)
    if cached is not None:
        return cached
    rebuilt = deserialize_conn(self._serialized_connection_state())
    self._job_conn_cache = rebuilt
    return rebuilt
|
||||
|
||||
def load_columns(
    self,
    source: Union[str, Iterable[str]],
    pk: str,
    columns: Union[Iterable[str], Dict[str, str]],
    *,
    source_format: str = "parquet",
    source_pk: Optional[str] = None,
    on_missing: str = "carry",
    source_storage_options: Optional[Dict[str, str]] = None,
    num_workers: Optional[int] = None,
    max_workers: Optional[int] = None,
    batch_size: Optional[int] = None,
    commit_granularity: Optional[int] = None,
    priority: Optional[str] = None,
) -> str:
    """Fill existing columns from an external source via a primary-key join.

    The distributed-job equivalent of Geneva's ``Table.load_columns()``:
    precomputed values (e.g. embeddings) are imported from Parquet/Lance/IPC
    into this table, matched on a primary key. Returns the load job id.
    Server-backed feature (LanceDB Enterprise / Cloud).

    Parameters
    ----------
    source: str | list[str]
        One source URI or a list of URIs.
    pk: str
        Destination primary-key column. Also the source key unless
        ``source_pk`` is given.
    columns: list[str] | dict[str, str]
        Value columns to load. A list loads same-named columns; a dict maps
        ``{target: source}``.
    source_format: str
        ``"parquet"`` (default), ``"lance"``, or ``"ipc"``.
    source_pk: str, optional
        Source primary-key column when it differs from ``pk``.
    on_missing: str
        Behavior for destination rows with no source match:
        ``"carry"`` (default, keep existing), ``"null"``, or ``"error"``.

    The remaining keyword arguments (``source_storage_options``,
    ``num_workers``, ``max_workers``, ``batch_size``, ``commit_granularity``,
    ``priority``) are forwarded unchanged to the underlying table.
    """
    if isinstance(source, str):
        source = [source]

    # Each mapping is (target, source-or-None); None means "same name".
    if isinstance(columns, dict):
        mappings = list(columns.items())
    else:
        mappings = [(col, None) for col in columns]

    pending = self._table.load_columns(
        list(source),
        source_format,
        pk,
        mappings,
        source_key=source_pk,
        source_storage_options=source_storage_options,
        on_missing=on_missing,
        num_workers=num_workers,
        max_workers=max_workers,
        batch_size=batch_size,
        commit_granularity=commit_granularity,
        priority=priority,
    )
    return LOOP.run(pending)
|
||||
|
||||
def alter_columns(
|
||||
self, *alterations: Iterable[Dict[str, str]]
|
||||
|
||||
@@ -702,6 +702,24 @@ def _normalize_progress(progress):
|
||||
return progress, False
|
||||
|
||||
|
||||
def _computed_groups(computed):
    """Bucket computed columns by their expression, keeping declaration order.

    Struct-returning functions need their columns adjacent so schema order
    matches field order. The ergonomic input forms -- `fn("col")` values and
    tuple keys for struct fan-out -- are accepted via `_normalize_computed`.

    Returns a list of ``(expression, [(name, sql_type), ...])`` tuples in
    order of each expression's first appearance.
    """
    from .udf import _normalize_computed

    grouped = []
    normalized = _normalize_computed(computed)
    for column, (sql_type, expression) in normalized.items():
        # Linear scan for an existing bucket with the same expression.
        bucket = next(
            (cols for expr, cols in grouped if expr == expression),
            None,
        )
        if bucket is None:
            grouped.append((expression, [(column, sql_type)]))
        else:
            bucket.append((column, sql_type))
    return grouped
|
||||
|
||||
|
||||
class Table(ABC):
|
||||
"""
|
||||
A Table is a collection of Records in a LanceDB Database.
|
||||
@@ -807,6 +825,59 @@ class Table(ABC):
|
||||
"""The number of rows in this Table"""
|
||||
return self.count_rows(None)
|
||||
|
||||
def add_computed_column(
    self,
    columns,
    fn,
    args: Optional[List[str]] = None,
    types=None,
) -> None:
    """Declare computed column(s) bound to a UDF. Nothing is computed here
    (the agent fills them lazily, or refresh_column() triggers a run).

    ``columns`` is one column name, or a tuple/list of names for a
    STRUCT-returning function. ``fn`` is a `Udf` (types default to its
    return signature) or a function name string (``types`` is then
    required). The resulting binding is passed to
    ``add_columns(computed=...)``.

    .. deprecated::
        A computed column is an expression over a registered function, so
        bind it as one: ``add_columns(computed={"vec": embed("data")})``.
        ``embed("data")`` applies the function to the `data` column and
        infers the type from the function's return signature -- the
        function never couples to a particular column. Prefer that form.

    Raises
    ------
    ValueError
        If several columns are requested from a non-STRUCT `Udf`, if
        ``types`` is missing for a name-string ``fn``, or if the number of
        types does not match the number of columns.
    """
    import warnings

    warnings.warn(
        "add_computed_column is deprecated; use add_columns(computed="
        '{"vec": embed("data")}).',
        DeprecationWarning,
        stacklevel=2,
    )
    from .udf import Udf, struct_field_types

    call_args = args or []
    multi = isinstance(columns, (tuple, list))

    if isinstance(fn, Udf):
        expr = fn.expression(*call_args)
        if types is None and not multi:
            types = fn.returns
        elif types is None:
            if not fn.returns.upper().startswith("STRUCT"):
                raise ValueError("several columns need a STRUCT-returning function")
            types = struct_field_types(fn.returns)
    elif types is None:
        raise ValueError("pass types= when fn is a name string")
    else:
        expr = f"{fn}({', '.join(call_args)})"

    if not multi:
        computed = {columns: (types, expr)}
    else:
        if len(types) != len(columns):
            raise ValueError(f"{len(columns)} columns but {len(types)} output types")
        computed = {
            column: (sql_type, expr) for column, sql_type in zip(columns, types)
        }
    self.add_columns(computed=computed)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
|
||||
@@ -2022,6 +2093,7 @@ class LanceTable(Table):
|
||||
namespace_client: Optional[Any] = None,
|
||||
managed_versioning: Optional[bool] = None,
|
||||
pushdown_operations: Optional[set] = None,
|
||||
route_pushdown_to_rust: bool = False,
|
||||
_async: AsyncTable = None,
|
||||
):
|
||||
if namespace_path is None:
|
||||
@@ -2031,6 +2103,14 @@ class LanceTable(Table):
|
||||
self._location = location # Store location for use in _dataset_path
|
||||
self._namespace_client = namespace_client
|
||||
self._pushdown_operations = pushdown_operations or set()
|
||||
# When the connection built the namespace client natively (e.g. an
|
||||
# enterprise "rest" connection), the underlying Rust table already
|
||||
# executes QueryTable pushdown itself -- and, unlike this Python urllib3
|
||||
# path, it routes through the read-freshness context provider that emits
|
||||
# the ``x-lancedb-min-timestamp`` header. So we must defer pushdown to
|
||||
# Rust instead of calling the Python ``namespace_client.query_table``
|
||||
# directly, or reads silently bypass read-freshness (stale results).
|
||||
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||
if _async is not None:
|
||||
self._table = _async
|
||||
else:
|
||||
@@ -2241,6 +2321,7 @@ class LanceTable(Table):
|
||||
namespace_path=self._namespace_path,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._pushdown_operations,
|
||||
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||
location=self._location,
|
||||
_async=async_table,
|
||||
)
|
||||
@@ -2391,8 +2472,11 @@ class LanceTable(Table):
|
||||
Returns
|
||||
-------
|
||||
pa.Table"""
|
||||
if _should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
if (
|
||||
_should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
)
|
||||
and not self._route_pushdown_to_rust
|
||||
):
|
||||
return self._execute_query(Query()).read_all()
|
||||
|
||||
@@ -3344,6 +3428,7 @@ class LanceTable(Table):
|
||||
location: Optional[str] = None,
|
||||
namespace_client: Optional[Any] = None,
|
||||
pushdown_operations: Optional[set] = None,
|
||||
route_pushdown_to_rust: bool = False,
|
||||
):
|
||||
"""
|
||||
Create a new table.
|
||||
@@ -3406,21 +3491,24 @@ class LanceTable(Table):
|
||||
self._location = location
|
||||
self._namespace_client = namespace_client
|
||||
self._pushdown_operations = pushdown_operations or set()
|
||||
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||
|
||||
if data_storage_version is not None:
|
||||
warnings.warn(
|
||||
"setting data_storage_version directly on create_table is deprecated. ",
|
||||
"setting data_storage_version directly on create_table is deprecated. "
|
||||
"Use database_options instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if storage_options is None:
|
||||
storage_options = {}
|
||||
storage_options["new_table_data_storage_version"] = data_storage_version
|
||||
if enable_v2_manifest_paths is not None:
|
||||
warnings.warn(
|
||||
"setting enable_v2_manifest_paths directly on create_table is ",
|
||||
"setting enable_v2_manifest_paths directly on create_table is "
|
||||
"deprecated. Use database_options instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if storage_options is None:
|
||||
storage_options = {}
|
||||
@@ -3517,6 +3605,7 @@ class LanceTable(Table):
|
||||
_should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
)
|
||||
and not self._route_pushdown_to_rust
|
||||
and self.current_branch() is None
|
||||
):
|
||||
from lancedb.namespace import _execute_server_side_query
|
||||
@@ -3692,9 +3781,68 @@ class LanceTable(Table):
|
||||
return LOOP.run(self._table.index_stats(index_name))
|
||||
|
||||
def add_columns(
|
||||
self, transforms: Dict[str, str] | pa.field | List[pa.field] | pa.Schema
|
||||
) -> AddColumnsResult:
|
||||
return LOOP.run(self._table.add_columns(transforms))
|
||||
self,
|
||||
transforms: Dict[str, str]
|
||||
| pa.field
|
||||
| List[pa.field]
|
||||
| pa.Schema
|
||||
| None = None,
|
||||
*,
|
||||
computed: Optional[Dict] = None,
|
||||
) -> Optional[AddColumnsResult]:
|
||||
result = None
|
||||
if transforms is not None:
|
||||
result = LOOP.run(self._table.add_columns(transforms))
|
||||
if computed:
|
||||
# computed binds an expression over a registered function to a
|
||||
# column: {col: fn("input_col")} -- fn("input_col") yields the
|
||||
# expression and carries the inferred type; a tuple key fans a
|
||||
# STRUCT return out to several columns. Declares the binding only;
|
||||
# the server fills the values (server-backed). The legacy
|
||||
# {col: (sql_type, expression)} tuple form is still accepted.
|
||||
result_unused = LOOP.run(self._table.add_columns(computed=computed))
|
||||
del result_unused
|
||||
return result
|
||||
|
||||
def refresh_column(
    self,
    columns,
    *,
    where: Optional[str] = None,
    num_workers: Optional[int] = None,
    max_workers: Optional[int] = None,
    batch_size: Optional[int] = None,
    priority: Optional[str] = None,
) -> "JobHandle":
    """Start a recompute of one or more computed columns (REFRESH COLUMN).

    Each column's expression is resolved server-side from its stored
    binding; columns bound to the same struct-returning function refresh
    together. Server-backed feature (LanceDB Enterprise / Cloud).

    Parameters
    ----------
    columns:
        A single column name, or an iterable of column names.
    where:
        Optional filter forwarded to the underlying table call.
    num_workers, max_workers, batch_size, priority:
        Per-refresh scheduling knobs (how to run THIS refresh); they
        override any default the function carries. ``priority`` is a Kueue
        tier (training | interactive | backfill).

    Returns
    -------
    JobHandle
        A handle to wait on, poll, or cancel, e.g.
        ``tbl.refresh_column("col").wait()`` -- mirrors
        `MaterializedView.refresh()`.
    """
    from .udf import JobHandle

    # A bare string is one column name, not an iterable of characters.
    column_list = [columns] if isinstance(columns, str) else list(columns)
    pending = self._table.refresh_column(
        column_list,
        where=where,
        num_workers=num_workers,
        max_workers=max_workers,
        batch_size=batch_size,
        priority=priority,
    )
    job_id = LOOP.run(pending)
    return JobHandle(self._conn, job_id, table=self.name)
|
||||
|
||||
def alter_columns(
|
||||
self, *alterations: Iterable[Dict[str, str]]
|
||||
@@ -4258,6 +4406,7 @@ class AsyncTable:
|
||||
namespace_path: Optional[List[str]] = None,
|
||||
namespace_client: Optional[Any] = None,
|
||||
pushdown_operations: Optional[set] = None,
|
||||
route_pushdown_to_rust: bool = False,
|
||||
):
|
||||
"""Create a new AsyncTable object.
|
||||
|
||||
@@ -4270,6 +4419,9 @@ class AsyncTable:
|
||||
self._namespace_path = namespace_path or []
|
||||
self._namespace_client = namespace_client
|
||||
self._pushdown_operations = pushdown_operations or set()
|
||||
# See LanceTable.__init__: defer QueryTable pushdown to Rust (which emits
|
||||
# the read-freshness header) for natively-built namespace clients.
|
||||
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||
|
||||
def _set_namespace_context(
|
||||
self,
|
||||
@@ -4277,10 +4429,12 @@ class AsyncTable:
|
||||
namespace_path: Optional[List[str]] = None,
|
||||
namespace_client: Optional[Any] = None,
|
||||
pushdown_operations: Optional[set] = None,
|
||||
route_pushdown_to_rust: bool = False,
|
||||
) -> "AsyncTable":
|
||||
self._namespace_path = namespace_path or []
|
||||
self._namespace_client = namespace_client
|
||||
self._pushdown_operations = pushdown_operations or set()
|
||||
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
@@ -4490,8 +4644,11 @@ class AsyncTable:
|
||||
-------
|
||||
pa.Table
|
||||
"""
|
||||
if _should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
if (
|
||||
_should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
)
|
||||
and not self._route_pushdown_to_rust
|
||||
):
|
||||
return (await self._execute_query(Query())).read_all()
|
||||
|
||||
@@ -5175,8 +5332,11 @@ class AsyncTable:
|
||||
batch_size: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> pa.RecordBatchReader:
|
||||
if _should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
if (
|
||||
_should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
)
|
||||
and not self._route_pushdown_to_rust
|
||||
):
|
||||
from lancedb.namespace import _execute_server_side_query
|
||||
|
||||
@@ -5360,9 +5520,44 @@ class AsyncTable:
|
||||
|
||||
return await self._inner.update(updates_sql, where)
|
||||
|
||||
async def refresh_column(
    self,
    columns,
    *,
    where: Optional[str] = None,
    num_workers: Optional[int] = None,
    max_workers: Optional[int] = None,
    batch_size: Optional[int] = None,
    priority: Optional[str] = None,
) -> str:
    """Start a recompute of computed columns (REFRESH COLUMN).

    Server-backed feature. ``columns`` may be a single name or an iterable
    of names; ``where`` is forwarded to the inner table as ``where_clause``.

    num_workers / max_workers / batch_size / priority are per-refresh
    scheduling knobs (how to run THIS refresh); they override any default
    the function carries. `priority` is a Kueue tier
    (training | interactive | backfill).

    Returns the refresh job id.
    """
    # A bare string is one column name, not an iterable of characters.
    column_list = [columns] if isinstance(columns, str) else list(columns)
    scheduling = {
        "num_workers": num_workers,
        "max_workers": max_workers,
        "batch_size": batch_size,
        "priority": priority,
    }
    return await self._inner.refresh_column(
        column_list, where_clause=where, **scheduling
    )
|
||||
|
||||
async def add_columns(
|
||||
self, transforms: dict[str, str] | pa.field | List[pa.field] | pa.Schema
|
||||
) -> AddColumnsResult:
|
||||
self,
|
||||
transforms: dict[str, str]
|
||||
| pa.field
|
||||
| List[pa.field]
|
||||
| pa.Schema
|
||||
| None = None,
|
||||
*,
|
||||
computed: Optional[Dict] = None,
|
||||
) -> Optional[AddColumnsResult]:
|
||||
"""
|
||||
Add new columns with defined values.
|
||||
|
||||
@@ -5381,6 +5576,7 @@ class AsyncTable:
|
||||
version: the new version number of the table after adding columns.
|
||||
|
||||
"""
|
||||
result = None
|
||||
if isinstance(transforms, pa.Field):
|
||||
transforms = [transforms]
|
||||
if isinstance(transforms, list) and all(
|
||||
@@ -5388,9 +5584,69 @@ class AsyncTable:
|
||||
):
|
||||
transforms = pa.schema(transforms)
|
||||
if isinstance(transforms, pa.Schema):
|
||||
return await self._inner.add_columns_with_schema(transforms)
|
||||
result = await self._inner.add_columns_with_schema(transforms)
|
||||
elif transforms is not None:
|
||||
result = await self._inner.add_columns(list(transforms.items()))
|
||||
if computed:
|
||||
# computed binds an expression over a registered function to a
|
||||
# column: {col: fn("input_col")} -- fn("input_col") yields the
|
||||
# expression and carries the inferred type; a tuple key fans a
|
||||
# STRUCT return out to several columns. Declares the binding only;
|
||||
# the server fills the values (server-backed). The legacy
|
||||
# {col: (sql_type, expression)} tuple form is still accepted.
|
||||
for expression, cols in _computed_groups(computed):
|
||||
await self._inner.add_computed_columns(cols, expression)
|
||||
return result
|
||||
|
||||
async def add_computed_column(
|
||||
self,
|
||||
columns,
|
||||
fn,
|
||||
args: Optional[List[str]] = None,
|
||||
types=None,
|
||||
) -> None:
|
||||
"""Declare computed column(s) bound to a UDF (async).
|
||||
|
||||
.. deprecated::
|
||||
Use ``add_columns(computed={"col": fn("input_col")})`` -- a computed
|
||||
column is an expression over a registered function, so bind it that
|
||||
way instead of coupling the UDF to the column here.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"add_computed_column is deprecated; use add_columns(computed="
|
||||
'{"col": fn("input_col")}).',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from .udf import Udf, struct_field_types
|
||||
|
||||
multi = isinstance(columns, (tuple, list))
|
||||
if isinstance(fn, Udf):
|
||||
expr = fn.expression(*(args or []))
|
||||
if types is None:
|
||||
if multi:
|
||||
if not fn.returns.upper().startswith("STRUCT"):
|
||||
raise ValueError(
|
||||
"several columns need a STRUCT-returning function"
|
||||
)
|
||||
types = struct_field_types(fn.returns)
|
||||
else:
|
||||
types = fn.returns
|
||||
else:
|
||||
return await self._inner.add_columns(list(transforms.items()))
|
||||
if types is None:
|
||||
raise ValueError("pass types= when fn is a name string")
|
||||
expr = f"{fn}({', '.join(args or [])})"
|
||||
if multi:
|
||||
if len(types) != len(columns):
|
||||
raise ValueError(
|
||||
f"{len(columns)} columns but {len(types)} output types"
|
||||
)
|
||||
computed = {c: (t, expr) for c, t in zip(columns, types)}
|
||||
else:
|
||||
computed = {columns: (types, expr)}
|
||||
await self.add_columns(computed=computed)
|
||||
|
||||
async def alter_columns(
|
||||
self, *alterations: Iterable[dict[str, Any]]
|
||||
@@ -5662,6 +5918,7 @@ class AsyncTable:
|
||||
"The 'retrain' parameter is deprecated and will be removed in a "
|
||||
"future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return await self._inner.optimize(
|
||||
|
||||
753
python/python/lancedb/udf.py
Normal file
753
python/python/lancedb/udf.py
Normal file
@@ -0,0 +1,753 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
"""UDF authoring for LanceDB derived compute (server-backed).
|
||||
|
||||
`@udf` / `@table_udf` turn a plain Python function into a registrable
|
||||
server-side UDF: a cloudpickled (or source) body, a SQL signature inferred
|
||||
from type hints, and the runtime options (pip deps, GPUs, batching, ...).
|
||||
Register and use them through the existing connection/table API:
|
||||
|
||||
import lancedb
|
||||
from lancedb import udf, table_udf
|
||||
|
||||
db = lancedb.connect("db://my_db", api_key="...", host_override="...")
|
||||
|
||||
@udf(pip=["torch>=2.0"], num_gpus=1)
|
||||
def embed(text: str) -> list[float]:
|
||||
return model.encode(text).tolist()
|
||||
|
||||
db.create_function(embed) # CREATE FUNCTION (once)
|
||||
tbl = db.open_table("docs")
|
||||
tbl.add_columns(computed={"vec": embed("text")}) # bind embed(text) -> vec
|
||||
tbl.refresh_column("vec").wait() # materialize (returns a JobHandle)
|
||||
view = db.create_materialized_view("chunks", tbl, ["id", chunk_fn])
|
||||
|
||||
`embed("text")` applies the registered function to the `text` column and yields
|
||||
the expression `embed(text)`; the function itself stays decoupled from any
|
||||
column, so the same `embed` works on any column or table.
|
||||
|
||||
These operations are server-backed (LanceDB Enterprise / Cloud); the
|
||||
decorator itself works locally (define + call), only registration needs a
|
||||
remote connection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
import typing
|
||||
|
||||
# -- type hints -> SQL type strings -------------------------------------
|
||||
|
||||
# Lookup from a Python scalar type hint to the SQL type string used when
# rendering UDF signatures; consulted by `sql_type` for both bare scalars
# and list item types.
_SCALARS = {
    int: "BIGINT",
    # Pragmatic default for ML workloads: python float maps to FLOAT
    # (Float32). Use an explicit `returns=` for DOUBLE.
    float: "FLOAT",
    str: "VARCHAR",
    bool: "BOOLEAN",
    bytes: "BLOB",
}
|
||||
|
||||
|
||||
class TypeInferenceError(TypeError):
    """Raised when a SQL type cannot be derived from a Python type hint."""
|
||||
|
||||
|
||||
def sql_type(hint) -> str:
    """Render a Python type hint as a SQL type string.

    Scalars come from ``_SCALARS``; ``list[scalar]`` becomes ``SCALAR[]``;
    a TypedDict or dataclass becomes ``STRUCT(name TYPE, ...)`` with each
    field rendered recursively.

    Raises
    ------
    TypeInferenceError
        For a list whose item type is not a known scalar, or for any hint
        that is neither a scalar, a list, nor struct-shaped.
    """
    if hint in _SCALARS:
        return _SCALARS[hint]

    if typing.get_origin(hint) in (list, typing.List):
        type_args = typing.get_args(hint)
        # An unparameterized list has no item type; treat it as unknown.
        (element,) = type_args if type_args else (None,)
        if element not in _SCALARS:
            raise TypeInferenceError(
                f"unsupported list item type {element!r}; use an explicit returns="
            )
        return f"{_SCALARS[element]}[]"

    struct = _struct_fields(hint)
    if struct is None:
        raise TypeInferenceError(
            f"cannot infer a SQL type for {hint!r}; pass an explicit type string"
        )
    rendered = [f"{name} {sql_type(field_hint)}" for name, field_hint in struct]
    return "STRUCT(" + ", ".join(rendered) + ")"
|
||||
|
||||
|
||||
def _struct_fields(hint):
|
||||
"""(name, hint) pairs for a TypedDict or dataclass, else None."""
|
||||
if dataclasses.is_dataclass(hint):
|
||||
return [(f.name, f.type) for f in dataclasses.fields(hint)]
|
||||
# TypedDict detection: a dict subclass with __annotations__.
|
||||
if (
|
||||
isinstance(hint, type)
|
||||
and issubclass(hint, dict)
|
||||
and typing.get_type_hints(hint)
|
||||
):
|
||||
return list(typing.get_type_hints(hint).items())
|
||||
return None
|
||||
|
||||
|
||||
def return_type(fn, override: "str | None", table: bool) -> str:
    """Determine the SQL return type of ``fn``.

    An explicit ``override`` wins over the return annotation. For a table
    function the result is rendered as ``TABLE(...)``: an override spelled
    ``STRUCT(...)`` is rewritten to ``TABLE(...)``, and an annotation must be
    a TypedDict or dataclass, optionally wrapped in ``list[...]``.

    Raises
    ------
    TypeInferenceError
        If a table function's override is neither TABLE nor STRUCT, if the
        function has no return annotation and no override, or if a table
        function's annotation is not struct-shaped.
    """
    if override is not None:
        declared = override.strip()
        upper = declared.upper()
        if not table or upper.startswith("TABLE"):
            return declared
        if not upper.startswith("STRUCT"):
            raise TypeInferenceError(
                "a table function's returns= must be TABLE(...) or STRUCT(...)"
            )
        # Same field list, different wrapper keyword.
        return "TABLE" + declared[len("STRUCT") :]

    annotation = typing.get_type_hints(fn).get("return")
    if annotation is None:
        raise TypeInferenceError(
            f"function {fn.__name__!r} needs a return annotation or returns="
        )
    if not table:
        return sql_type(annotation)

    # Accept list[Row] / Row where Row is a TypedDict or dataclass.
    row = annotation
    if typing.get_origin(row) in (list, typing.List):
        (row,) = typing.get_args(row)
    row_fields = _struct_fields(row)
    if row_fields is None:
        raise TypeInferenceError(
            "a table function must return rows shaped as a TypedDict or "
            "dataclass (optionally list-wrapped); or pass returns=..."
        )
    columns = [f"{name} {sql_type(field_hint)}" for name, field_hint in row_fields]
    return "TABLE(" + ", ".join(columns) + ")"
|
||||
|
||||
|
||||
def param_types(fn) -> "list[tuple[str, str]]":
    """List ``(name, sql type)`` for every parameter of ``fn``.

    Types come from the parameter annotations. Each UDF parameter binds to
    a source column of the same name by default.

    Raises
    ------
    TypeInferenceError
        If ``fn`` takes ``*args`` / ``**kwargs`` or a parameter lacks an
        annotation (or its annotation has no SQL equivalent).
    """
    annotations = typing.get_type_hints(fn)
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    def describe(name, param):
        if param.kind in variadic:
            raise TypeInferenceError("*args/**kwargs are not supported in UDFs")
        annotation = annotations.get(name)
        if annotation is None:
            raise TypeInferenceError(
                f"parameter {name!r} of {fn.__name__!r} needs a type annotation"
            )
        return name, sql_type(annotation)

    return [
        describe(name, param)
        for name, param in inspect.signature(fn).parameters.items()
    ]
|
||||
|
||||
|
||||
# -- column expressions -------------------------------------------------
|
||||
|
||||
|
||||
class ColumnExpr(str):
    """Expression string built by applying a registered function to column
    names, e.g. ``embed("data") -> "embed(data)"``.

    Being a ``str`` subclass, it can be used anywhere a plain expression
    string is expected (views, SQL, logging). On top of that it records the
    function's declared return type so ``add_columns(computed=...)`` can
    declare the column without a hand-written type.
    """

    # SQL return type of the applied function, or None when unknown.
    data_type: "str | None"
    # Per-field SQL types of a STRUCT return (used to fan one expression out
    # to several columns), or None.
    field_types: "list[str] | None"

    def __new__(cls, expr: str, data_type=None, field_types=None):
        instance = str.__new__(cls, expr)
        instance.data_type, instance.field_types = data_type, field_types
        return instance
|
||||
|
||||
|
||||
def _normalize_computed(computed: dict) -> dict:
    """Convert a user-facing ``computed=`` mapping to the canonical
    ``{name: (sql_type, expression)}`` form.

    Per entry:
    - A `ColumnExpr` value (from ``fn("col")``) supplies both the expression
      and the SQL type, taken from the function's return type. With a tuple
      key (``("chunk", "idx")``) a STRUCT return is fanned out to one
      ``(type, expression)`` entry per field, in declared order.
    - Any other value (the legacy ``(sql_type, expression)`` tuple, e.g. for
      bare-name function strings) is passed through unchanged.

    Raises
    ------
    ValueError
        If a multi-column key is paired with a non-struct expression, if the
        number of columns and struct fields differ, or if a single-column
        `ColumnExpr` carries no type.
    """
    normalized: dict = {}
    for target, value in computed.items():
        if not isinstance(value, ColumnExpr):
            normalized[target] = value
            continue

        expression = str(value)
        if not isinstance(target, (tuple, list)):
            if value.data_type is None:
                raise ValueError(f"cannot infer a type for {expression}; pass types=")
            normalized[target] = (value.data_type, expression)
            continue

        struct_types = value.field_types
        if not struct_types:
            raise ValueError(
                f"columns {tuple(target)} need a STRUCT-returning function; "
                f"{expression} returns a single value"
            )
        if len(struct_types) != len(target):
            raise ValueError(
                f"{len(target)} columns but {len(struct_types)} struct fields "
                f"in {expression}"
            )
        for column, column_type in zip(target, struct_types):
            normalized[column] = (column_type, expression)
    return normalized
|
||||
|
||||
|
||||
# -- the @udf / @table_udf decorators -----------------------------------
|
||||
|
||||
|
||||
class Udf:
    """A Python function packaged as a registrable server-side UDF.

    Wraps ``fn`` together with its SQL signature (inferred from type hints
    unless ``returns=`` is given) and a flat ``options`` mapping of runtime
    settings (pip/conda dependencies, resources, batching, ...). Instances
    stay callable: real values run the function locally, column-name strings
    build a `ColumnExpr`.
    """

    def __init__(
        self,
        fn,
        *,
        returns: "str | None" = None,
        table: bool = False,
        name: "str | None" = None,
        pip: "list[str] | None" = None,
        pip_index_url: "str | None" = None,
        pip_extra_index_urls: "list[str] | None" = None,
        find_links: "list[str] | None" = None,
        requirements: "str | list[str] | None" = None,
        conda: "list[str] | None" = None,
        conda_channels: "list[str] | None" = None,
        env: "dict[str, str] | list[str] | None" = None,
        num_cpus: "int | None" = None,
        num_gpus: "int | None" = None,
        batch_size: "int | None" = None,
        timeout: "float | None" = None,
        error_policy: "str | None" = None,
        max_skip_ratio: "float | None" = None,
        retries: "int | None" = None,
        docker_image: "str | None" = None,
        description: "str | None" = None,
        prefer_source: bool = False,
    ):
        """Capture ``fn``, infer its signature and collect runtime options.

        Raises
        ------
        ValueError
            If ``conda`` is combined with ``pip``/``requirements``, or
            ``conda_channels`` is given without ``conda``.
        TypeInferenceError
            Propagated from signature inference (``param_types`` /
            ``return_type``).
        """
        functools.update_wrapper(self, fn)
        self.fn = fn
        self.name = name or fn.__name__
        self.table = table
        self.params = param_types(fn)
        self.returns = return_type(fn, returns, table)
        self.prefer_source = prefer_source
        self.options: "dict[str, str]" = {}
        if conda and (pip or requirements):
            raise ValueError("pass conda or pip/requirements, not both")
        if conda_channels and not conda:
            raise ValueError("conda_channels requires conda")
        # List-valued options are flattened to comma-separated strings.
        if pip:
            self.options["pip"] = ",".join(pip)
        if pip_extra_index_urls:
            self.options["pip_extra_index_urls"] = ",".join(pip_extra_index_urls)
        if find_links:
            self.options["find_links"] = ",".join(find_links)
        if requirements:
            self.options["requirements"] = _format_requirements(requirements)
        if conda:
            self.options["conda"] = ",".join(conda)
        if conda_channels:
            self.options["conda_channels"] = ",".join(conda_channels)
        if env:
            self.options["env"] = _format_env(env)
        # Scalar options are recorded only when explicitly set.
        for key, val in [
            ("pip_index_url", pip_index_url),
            ("num_cpus", num_cpus),
            ("num_gpus", num_gpus),
            ("batch_size", batch_size),
            ("timeout", timeout),
            ("error_policy", error_policy),
            ("max_skip_ratio", max_skip_ratio),
            ("retries", retries),
            ("docker_image", docker_image),
        ]:
            if val is not None:
                self.options[key] = str(val)
        # Keep the source in the description (when available) so the
        # catalog stays inspectable even for pickled bodies.
        if description is not None:
            self.options["description"] = description
        else:
            try:
                self.options["description"] = textwrap.dedent(inspect.getsource(fn))
            except (OSError, TypeError):
                # Source not retrievable; go without a description.
                pass

    def __call__(self, *args, **kwargs):
        """Call with real values to run locally; call with column-name
        strings to build an expression for backfills and views, e.g.
        ``embed("data")`` -> the expression ``embed(data)`` (a `ColumnExpr`
        carrying the function's return type for `add_columns(computed=...)`).

        The expression path is taken only when there is at least one
        positional argument, all positionals are strings, and no keyword
        arguments are given.
        """
        if args and all(isinstance(a, str) for a in args) and not kwargs:
            return self.expression(*args)
        return self.fn(*args, **kwargs)

    def expression(self, *columns: str) -> "ColumnExpr":
        """The expression applying this function to `columns` (default: the
        function's own parameter names). Returns a `ColumnExpr` -- a string
        that also carries the declared return type (and struct field types)."""
        cols = columns or [p for p, _ in self.params]
        expr = f"{self.name}({', '.join(cols)})"
        field_types = None
        if self.returns.upper().startswith("STRUCT"):
            field_types = struct_field_types(self.returns)
        return ColumnExpr(expr, data_type=self.returns, field_types=field_types)

    def _body(self) -> "tuple[str, str]":
        """(body literal, body_format). Source when requested and
        retrievable; cloudpickle otherwise (handles closures)."""
        if self.prefer_source:
            try:
                src = textwrap.dedent(inspect.getsource(self.fn))
            except (OSError, TypeError):
                src = None
            if src is not None:
                lines = src.splitlines(keepends=True)
                # Drop everything before the ``def`` line so the stored body
                # is a plain function definition. Cutting at ``def`` (rather
                # than skipping lines that start with "@") also removes the
                # continuation lines of a decorator call spanning several
                # lines.
                start = next(
                    (
                        i
                        for i, line in enumerate(lines)
                        if re.match(r"(async\s+)?def\b", line.lstrip())
                    ),
                    None,
                )
                if start is None:
                    # No ``def`` line (e.g. a lambda): only strip leading
                    # decorator lines.
                    start = 0
                    while start < len(lines) and lines[start].lstrip().startswith(
                        "@"
                    ):
                        start += 1
                return "".join(lines[start:]), "source"
        import cloudpickle

        raw = cloudpickle.dumps(self.fn)
        return base64.b64encode(raw).decode("ascii"), "cloudpickle"

    def _body_and_options(self) -> "tuple[str, dict[str, str]]":
        """The body literal plus the finalized options (body_format /
        python_version / cloudpickle-pip bookkeeping for a non-source
        body)."""
        body, body_format = self._body()
        # Work on a copy so repeated calls never mutate self.options.
        options = dict(self.options)
        if body_format != "source":
            options["body_format"] = body_format
            # Pickled code objects only load under the same interpreter
            # minor version; record ours so the worker can fail with a
            # clear message instead of a bytecode error.
            options["python_version"] = self.pickle_environment()
            # The worker deserializes the body with cloudpickle; make sure
            # the job's pip environment provides it. Conda bakes inject
            # cloudpickle server-side, so do not create an invalid pip+conda
            # declaration here.
            if "conda" not in options:
                pip = [d for d in options.get("pip", "").split(",") if d]
                if not any(d.startswith("cloudpickle") for d in pip):
                    pip.append("cloudpickle")
                options["pip"] = ",".join(pip)
        return body, options

    def create_request(self) -> dict:
        """Keyword arguments for `connection.create_function`."""
        body, options = self._body_and_options()
        return {
            "name": self.name,
            "language": "python",
            "return_type": self.returns,
            "body": body,
            "options": options,
        }

    def create_statement(self) -> str:
        """The equivalent `CREATE FUNCTION` SQL (for SQL-surface callers)."""
        params = ", ".join(f"{n} {t}" for n, t in self.params)
        body, options = self._body_and_options()
        with_clause = ""
        if options:
            # Sorted so the rendered statement is deterministic.
            rendered = ", ".join(
                f"{k} = '{_escape(v)}'" for k, v in sorted(options.items())
            )
            with_clause = f" WITH ({rendered})"
        return (
            f"CREATE FUNCTION {self.name}({params}) RETURNS {self.returns} "
            f"LANGUAGE python AS '{_escape_body(body)}'{with_clause}"
        )

    def pickle_environment(self) -> str:
        """Python version the body pickles under -- workers should match
        the minor version for cloudpickle compatibility."""
        return f"{sys.version_info.major}.{sys.version_info.minor}"
|
||||
|
||||
|
||||
def _escape(s: str) -> str:
|
||||
return str(s).replace("'", "''")
|
||||
|
||||
|
||||
def _format_requirements(requirements: "str | list[str]") -> str:
|
||||
if isinstance(requirements, str):
|
||||
return requirements
|
||||
return "\n".join(str(req) for req in requirements)
|
||||
|
||||
|
||||
def _format_env(env: "dict[str, str] | list[str]") -> str:
|
||||
if isinstance(env, dict):
|
||||
return "; ".join(f"{key}={value}" for key, value in env.items())
|
||||
return "; ".join(str(entry) for entry in env)
|
||||
|
||||
|
||||
def _escape_body(body: str) -> str:
|
||||
# The server unescapes \n / \t in single-quoted bodies; encode real
|
||||
# newlines accordingly and escape quotes.
|
||||
return (
|
||||
body.replace("\\", "\\\\")
|
||||
.replace("'", "''")
|
||||
.replace("\n", "\\n")
|
||||
.replace("\t", "\\t")
|
||||
)
|
||||
|
||||
|
||||
def udf(fn=None, **kwargs):
    """Decorator that turns a function into a scalar (or struct-returning)
    UDF. Usable bare or with options:

        @udf
        def doubled(val: int) -> float: ...

        @udf(pip=["torch>=2"], num_gpus=1)
        def embed(body: str) -> list[float]: ...

    Keyword arguments are forwarded to `Udf`.
    """

    def decorate(func):
        return Udf(func, **kwargs)

    # Bare ``@udf`` passes the function directly; ``@udf(...)`` passes none
    # and expects a decorator back.
    if fn is None:
        return decorate
    return decorate(fn)
|
||||
|
||||
|
||||
def table_udf(fn=None, **kwargs):
    """Decorator for a table function (UDTF): each input row may emit zero
    or more output rows. Only usable in materialized views.

        class Chunk(TypedDict):
            chunk: str
            chunk_idx: int

        @table_udf
        def chunker(body: str) -> list[Chunk]: ...

    Keyword arguments are forwarded to `Udf`, with ``table=True`` always set.
    """
    options = {**kwargs, "table": True}

    def decorate(func):
        return Udf(func, **options)

    # Bare ``@table_udf`` passes the function directly; ``@table_udf(...)``
    # passes none and expects a decorator back.
    if fn is None:
        return decorate
    return decorate(fn)
|
||||
|
||||
|
||||
# -- view / job handles (thin references over a connection) -------------
|
||||
|
||||
|
||||
def struct_field_types(returns: str) -> "list[str]":
    """Field type strings of a STRUCT(...) SQL type, in declared order.

    The field list is taken from between the first ``(`` and the closing
    ``)``, so whitespace between the keyword and the parenthesis
    (``STRUCT (a INT)``) is accepted. Commas nested inside ``()``, ``[]`` or
    ``<>`` do not split fields. An empty field list yields ``[]``.

    Raises
    ------
    ValueError
        If ``returns`` has no parenthesized field list, or a field is not of
        the form ``name TYPE``.
    """
    text = returns.strip()
    open_at = text.find("(")
    if open_at < 0 or not text.endswith(")"):
        raise ValueError(f"not a STRUCT(...) type: {returns!r}")
    inner = text[open_at + 1 : -1]
    if not inner.strip():
        return []

    # Split on top-level commas only; depth tracks nested brackets.
    fields, depth, start = [], 0, 0
    for i, c in enumerate(inner):
        if c in "([<":
            depth += 1
        elif c in ")]>":
            depth -= 1
        elif c == "," and depth == 0:
            fields.append(inner[start:i].strip())
            start = i + 1
    fields.append(inner[start:].strip())

    # Each field is "name TYPE"; drop the name.
    types = []
    for field in fields:
        parts = field.split(None, 1)
        if len(parts) != 2:
            raise ValueError(f"malformed STRUCT field {field!r} in {returns!r}")
        types.append(parts[1])
    return types
|
||||
|
||||
|
||||
def build_view_query(source, select) -> str:
    """Build the ``SELECT ... FROM ...`` text for a view.

    ``source`` is a table name, or any object with a ``name`` attribute.
    Each item of ``select`` is one of: a column name or expression string
    (used verbatim), an ``(alias, expression)`` tuple, or a @udf/@table_udf
    object (rendered with its default expression).
    """

    def render(item):
        if isinstance(item, Udf):
            return item.expression()
        if isinstance(item, tuple):
            alias, expr = item
            if isinstance(expr, Udf):
                expr = expr.expression()
            return f"{expr} AS {alias}"
        return item

    table_name = getattr(source, "name", source)
    projection = [render(item) for item in select]
    return f"SELECT {', '.join(projection)} FROM {table_name}"
|
||||
|
||||
|
||||
def _job_id_matches(handle_id: str, listed_id: str) -> bool:
|
||||
# The refresh/backfill endpoints return the submission id (a uuid), but
|
||||
# the agent names the manifest job "<table>-<type>-<first 8 of the
|
||||
# submission id>" -- which is what list_jobs and cancel report. Match the
|
||||
# canonical id directly, or by that submission prefix.
|
||||
if listed_id == handle_id:
|
||||
return True
|
||||
prefix = handle_id[:8]
|
||||
return len(prefix) >= 4 and prefix in listed_id
|
||||
|
||||
|
||||
class MaterializedView:
    """Handle on a materialized view: a connection plus the view's name.

    Every operation is a call on the connection, bound to that name.

    ``create_materialized_view`` hands one of these back. ``job_id`` is the
    job that performs the initial population (None when the view was created
    with no data), so ``db.create_materialized_view(...).wait()`` blocks until
    the view is populated.
    """

    def __init__(self, conn, name: str, job_id: "str | None" = None):
        self.conn, self.name = conn, name
        #: initial-population job id from create, or None (with_no_data).
        self.job_id = job_id

    def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
        """Wait for the initial-population job started by create.

        Returns "finished" immediately when there is no such job."""
        if self.job_id is not None:
            population = JobHandle(self.conn, self.job_id, table=self.name)
            return population.wait(timeout=timeout, poll=poll)
        return "finished"

    def refresh(self, full: bool = False) -> "JobHandle":
        """Start a refresh and return a `JobHandle` for it, which can be
        waited on, polled, or cancelled (``view.refresh().wait()``).

        With ``full=True`` the view is rebuilt from scratch (every row is
        recomputed and replaced) rather than refreshed incrementally. A full
        rebuild preserves the view's indexes -- they are reindexed by the
        distributed indexer.
        """
        submitted = self.conn._refresh_materialized_view(self.name, full=full)
        return JobHandle(self.conn, submitted, table=self.name)

    def explain_refresh(self, full: bool = False):
        """Return the refresh plan without executing it (EXPLAIN REFRESH)."""
        plan = self.conn.explain_refresh_materialized_view(self.name, full=full)
        return plan

    def alter(self, auto_refresh: bool) -> None:
        """Change the view's ``auto_refresh`` setting."""
        self.conn.alter_materialized_view(self.name, auto_refresh=auto_refresh)

    def drop(self) -> None:
        """Drop the materialized view."""
        self.conn.drop_materialized_view(self.name)

    # A materialized view is a first-class table: it can be indexed and
    # searched like any other. The helpers below open the materialized dataset
    # by name and delegate to it. Indexes declared this way are recorded
    # against the view, so the engine re-applies them after a full refresh
    # rebuilds the dataset (a full refresh overwrites the dataset, which would
    # otherwise drop its indices).
    def _table(self):
        """Open the table that backs this view."""
        return self.conn.open_table(self.name)

    def create_index(self, *args, **kwargs):
        """Build an index on the materialized view (see Table.create_index)."""
        backing = self._table()
        return backing.create_index(*args, **kwargs)

    def create_scalar_index(self, *args, **kwargs):
        """Build a scalar index on the materialized view."""
        backing = self._table()
        return backing.create_scalar_index(*args, **kwargs)

    def create_fts_index(self, *args, **kwargs):
        """Build a full-text-search index on the materialized view."""
        backing = self._table()
        return backing.create_fts_index(*args, **kwargs)

    def search(self, *args, **kwargs):
        """Search the materialized view (vector / FTS / hybrid)."""
        backing = self._table()
        return backing.search(*args, **kwargs)

    def lineage(self, column=None, *, direction=None, depth=None):
        """Lineage of the view, or of one of its columns, as a `Lineage`.

        Delegates to the backing table; the server already includes the
        view's sources and downstream dependents."""
        backing = self._table()
        return backing.lineage(column, direction=direction, depth=depth)
|
||||
|
||||
|
||||
# Pre-compiled pattern matching "<digits>/<digits>" and capturing both numbers.
_PROGRESS = re.compile(r"(\d+)/(\d+)")
|
||||
|
||||
|
||||
class JobFailedError(RuntimeError):
    """Error raised by ``JobHandle.wait()`` once the server reports ``failed``.

    It carries the server-side error text, so a doomed backfill (e.g. a
    multi-column ``REFRESH COLUMN`` of a scalar UDF) surfaces its real cause
    promptly rather than leaving the caller blocked until ``wait()`` times out.

    Attributes:
        job_id: id of the job that failed.
        error: the server's error text, or None when none was available.
    """

    def __init__(self, job_id: str, error: "str | None"):
        # A missing or empty error is shown as "unknown error" in the message
        # only; the attribute keeps the original value.
        detail = error or "unknown error"
        super().__init__(f"job {job_id} failed: {detail}")
        self.job_id = job_id
        self.error = error
|
||||
|
||||
|
||||
class JobHandle:
    """A reference to an inflight server-side job, with polling helpers."""

    #: How long an unseen job is treated as still materializing (submission
    #: -> agent cycle -> manifest write is async).
    GRACE_SECONDS = 20.0

    def __init__(self, conn, job_id: str, table: "str | None" = None):
        self.conn = conn
        self.id = job_id
        #: The job's table, when known (refresh_column / MV refresh). Lets the
        #: server resolve this job with an O(1) single-node read; without it the
        #: lookup scans the database's active jobs (still correct).
        self.table = table
        # Start of the grace window during which a missing job reads as pending.
        self._created = time.monotonic()
        # Becomes True the first time the server returns a record for the job.
        self._seen = False

    def _job(self):
        """Fetch the current job record, or None when the server has none."""
        # Poll by id (one job), not list_jobs (every active job): the server
        # matches the submission/manifest id and reads just this table's node.
        return self.conn.get_job(self.id, self.table)

    def _state_of(self, job) -> str:
        """Translate one ``_job()`` snapshot into a state string.

        A present record yields its ``state``. A missing record is "pending"
        while the job has never been seen and the grace window is still open,
        and "finished" otherwise.
        """
        if job is not None:
            self._seen = True
            return job.state
        if not self._seen and time.monotonic() - self._created < self.GRACE_SECONDS:
            return "pending"
        return "finished"

    def status(self) -> str:
        """The state reported by the server (pending / running / cancelling /
        stale / failed), or 'finished' once the job has left the inflight
        listing."""
        return self._state_of(self._job())

    def progress(self) -> "tuple[int, int] | None":
        """(units_done, units_total) while running, else None."""
        job = self._job()
        if job is not None and job.units_total is not None:
            return job.units_done or 0, job.units_total
        return None

    def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
        """Poll until the job reaches a terminal state.

        Args:
            timeout: maximum number of seconds to wait.
            poll: seconds between polls (capped at 0.5s while pending).

        Returns:
            "finished" or "stale".

        Raises:
            JobFailedError: the server reported the job as ``failed``.
            TimeoutError: the job was still not terminal after ``timeout``.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            # Take ONE snapshot per iteration so the state, the error text and
            # the committed flag all come from the same read. Re-reading after
            # seeing "failed" could return None and lose the server error.
            job = self._job()
            state = self._state_of(job)
            if state in ("finished", "stale"):
                return state
            if state == "failed":
                # Terminal failure -- surface the server error now, don't block
                # until `timeout`. `finalize` wrote it to the job's status node.
                raise JobFailedError(self.id, job.error if job is not None else None)
            if state == "pending":
                delay = min(poll, 0.5)
            elif job is not None and job.committed:
                return "finished"
            else:
                delay = poll
            # Never sleep past the deadline, so `timeout` is honoured even when
            # it is smaller than `poll`.
            time.sleep(max(0.0, min(delay, deadline - time.monotonic())))
        raise TimeoutError(f"job {self.id} still {self.status()} after {timeout}s")

    def cancel(self) -> None:
        """Ask the server to cancel the job."""
        # Cancel by the canonical manifest id (what cancel matches), found
        # via the submission prefix; fall back to the raw id.
        job = self._job()
        self.conn.cancel_job(job.job_id if job is not None else self.id)
|
||||
|
||||
|
||||
class AsyncMaterializedView:
    """Async handle on a materialized view: an async connection plus a name."""

    def __init__(self, conn, name: str, job_id: "str | None" = None):
        self.conn, self.name = conn, name
        #: initial-population job id from create, or None (with_no_data).
        self.job_id = job_id

    async def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
        """Wait for the initial-population job started by create.

        Returns "finished" immediately when there is no such job."""
        if self.job_id is not None:
            population = AsyncJobHandle(self.conn, self.job_id, table=self.name)
            return await population.wait(timeout=timeout, poll=poll)
        return "finished"

    async def refresh(self, full: bool = False) -> "AsyncJobHandle":
        """Start a refresh and return an `AsyncJobHandle` for it, which can be
        waited on, polled, or cancelled.

        With ``full=True`` the view is fully rebuilt rather than refreshed
        incrementally (indexes are preserved and reindexed by the distributed
        indexer).
        """
        submitted = await self.conn._refresh_materialized_view(self.name, full=full)
        return AsyncJobHandle(self.conn, submitted, table=self.name)

    async def explain_refresh(self, full: bool = False):
        """Return the connection's refresh plan for this view."""
        plan = await self.conn.explain_refresh_materialized_view(self.name, full=full)
        return plan

    async def alter(self, auto_refresh: bool) -> None:
        """Change the view's ``auto_refresh`` setting."""
        await self.conn.alter_materialized_view(self.name, auto_refresh=auto_refresh)

    async def drop(self) -> None:
        """Drop the materialized view."""
        await self.conn.drop_materialized_view(self.name)

    async def lineage(self, column=None, *, direction=None, depth=None):
        """Lineage of the view, or of one of its columns, as a `Lineage`."""
        result = await self.conn.lineage(
            self.name, column, direction=direction, depth=depth
        )
        return result
|
||||
|
||||
|
||||
class AsyncJobHandle:
    """Async reference to an inflight server-side job, with polling helpers."""

    #: How long an unseen job is still reported as "pending" rather than
    #: "finished".
    GRACE_SECONDS = 20.0

    def __init__(self, conn, job_id: str, table: "str | None" = None):
        self.conn = conn
        self.id = job_id
        #: See JobHandle.table -- enables an O(1) by-id lookup when known.
        self.table = table
        # Start of the grace window during which a missing job reads as pending.
        self._created = time.monotonic()
        # Becomes True the first time the server returns a record for the job.
        self._seen = False

    async def _job(self):
        """Fetch the current job record, or None when the server has none."""
        # Poll by id, not list_jobs (see JobHandle._job).
        return await self.conn.get_job(self.id, self.table)

    def _state_of(self, job) -> str:
        """Translate one ``_job()`` snapshot into a state string.

        A present record yields its ``state``. A missing record is "pending"
        while the job has never been seen and the grace window is still open,
        and "finished" otherwise.
        """
        if job is not None:
            self._seen = True
            return job.state
        if not self._seen and time.monotonic() - self._created < self.GRACE_SECONDS:
            return "pending"
        return "finished"

    async def status(self) -> str:
        """The state reported by the server, or 'finished' once the job has
        left the inflight listing."""
        return self._state_of(await self._job())

    async def progress(self) -> "tuple[int, int] | None":
        """(units_done, units_total) while the job reports a total, else None."""
        job = await self._job()
        if job is not None and job.units_total is not None:
            return job.units_done or 0, job.units_total
        return None

    async def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
        """Poll until the job reaches a terminal state.

        Args:
            timeout: maximum number of seconds to wait.
            poll: seconds between polls (capped at 0.5s while pending).

        Returns:
            "finished" or "stale".

        Raises:
            JobFailedError: the server reported the job as ``failed``.
            TimeoutError: the job was still not terminal after ``timeout``.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            # Take ONE snapshot per iteration so the state, the error text and
            # the committed flag all come from the same read. Re-reading after
            # seeing "failed" could return None and lose the server error.
            job = await self._job()
            state = self._state_of(job)
            if state in ("finished", "stale"):
                return state
            if state == "failed":
                # Terminal failure -- surface the server error now, don't block
                # until `timeout`. `finalize` wrote it to the job's status node.
                raise JobFailedError(self.id, job.error if job is not None else None)
            if state == "pending":
                delay = min(poll, 0.5)
            elif job is not None and job.committed:
                return "finished"
            else:
                delay = poll
            # Never sleep past the deadline, so `timeout` is honoured even when
            # it is smaller than `poll`.
            await asyncio.sleep(max(0.0, min(delay, deadline - time.monotonic())))
        raise TimeoutError(
            f"job {self.id} still {await self.status()} after {timeout}s"
        )

    async def cancel(self) -> None:
        """Ask the server to cancel the job, using the id from the job record
        when one is found and the handle's own id otherwise."""
        job = await self._job()
        await self.conn.cancel_job(job.job_id if job is not None else self.id)
|
||||
92
python/python/tests/test_job_handle.py
Normal file
92
python/python/tests/test_job_handle.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
"""JobHandle.wait() terminal-state handling.
|
||||
|
||||
Regression coverage for the cluster backfill-failure hang: the server reports a
|
||||
doomed job as ``state="failed"`` within seconds, but ``wait()`` used to ignore
|
||||
``failed`` and block until its (default 3600s) timeout. These tests pin that a
|
||||
``failed`` job raises ``JobFailedError`` promptly, carrying the server error.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from lancedb.udf import JobHandle, AsyncJobHandle, JobFailedError
|
||||
|
||||
|
||||
class FakeJobInfo:
    """Stand-in for the pyo3 builtins.JobInfo, carrying only the fields that
    wait()/status() read. ``units_done`` is always None and ``job_id`` is
    always "job-1"."""

    def __init__(self, state, error=None, committed=False, units_total=None):
        self.job_id = "job-1"
        self.state, self.error = state, error
        self.committed = committed
        self.units_total, self.units_done = units_total, None
|
||||
|
||||
|
||||
class FakeConn:
    """Connection double whose get_job() replays a scripted list of JobInfo
    (or None) snapshots, repeating the final one once the script runs out, so
    that wait() polls a deterministic timeline. ``calls`` counts the lookups."""

    def __init__(self, snapshots):
        self._snaps = list(snapshots)
        self.calls = 0

    def get_job(self, job_id, table=None):
        final = len(self._snaps) - 1
        # Clamp the call counter to the last scripted snapshot.
        position = self.calls if self.calls < final else final
        self.calls += 1
        return self._snaps[position]
|
||||
|
||||
|
||||
class AsyncFakeConn(FakeConn):
    """Awaitable flavour of FakeConn: same scripted timeline, async get_job()."""

    async def get_job(self, job_id, table=None):
        snapshot = FakeConn.get_job(self, job_id, table)
        return snapshot
|
||||
|
||||
|
||||
def test_wait_raises_on_failed_promptly():
    """pending -> failed: wait() must raise the server error, not TimeoutError."""
    server_error = "multi-column backfill needs a STRUCT"
    timeline = [None, FakeJobInfo("failed", error=server_error)]
    handle = JobHandle(FakeConn(timeline), "job-1", table="t")

    started = time.monotonic()
    with pytest.raises(JobFailedError) as caught:
        handle.wait(timeout=30, poll=0.01)
    elapsed = time.monotonic() - started

    assert elapsed < 5  # prompt, nowhere near the 30s timeout
    failure = caught.value
    assert "STRUCT" in str(failure)
    assert failure.error == server_error
    assert failure.job_id == "job-1"
|
||||
|
||||
|
||||
def test_wait_returns_finished_on_success():
    """running -> gone from the inflight listing: wait() returns normally."""
    timeline = [FakeJobInfo("running", units_total=2), None]
    handle = JobHandle(FakeConn(timeline), "job-1", table="t")
    # Mark the job as already observed, so a None now means "finished" and is
    # not treated as the grace period.
    handle._seen = True
    outcome = handle.wait(timeout=30, poll=0.01)
    assert outcome == "finished"
|
||||
|
||||
|
||||
def test_wait_returns_finished_on_committed():
    """A job that is still listed but already committed resolves to finished."""
    still_listed = FakeJobInfo("running", committed=True, units_total=2)
    handle = JobHandle(FakeConn([still_listed]), "job-1", table="t")
    handle._seen = True
    outcome = handle.wait(timeout=30, poll=0.01)
    assert outcome == "finished"
|
||||
|
||||
|
||||
def test_async_wait_raises_on_failed_promptly():
    """Async counterpart: pending -> failed raises JobFailedError promptly."""
    timeline = [None, FakeJobInfo("failed", error="boom")]
    handle = AsyncJobHandle(AsyncFakeConn(timeline), "job-1", table="t")

    async def scenario():
        started = time.monotonic()
        with pytest.raises(JobFailedError) as caught:
            await handle.wait(timeout=30, poll=0.01)
        elapsed = time.monotonic() - started
        assert elapsed < 5
        assert caught.value.error == "boom"

    asyncio.run(scenario())
|
||||
@@ -65,6 +65,9 @@ def _namespace_lance_table(namespace_client: _NamespaceClient) -> LanceTable:
|
||||
table._namespace_path = ["geneva"]
|
||||
table._namespace_client = namespace_client
|
||||
table._pushdown_operations = {"QueryTable"}
|
||||
# This test exercises the Python-side pushdown path (non-native client), so
|
||||
# pushdown is not routed to Rust.
|
||||
table._route_pushdown_to_rust = False
|
||||
return table
|
||||
|
||||
|
||||
@@ -805,6 +808,37 @@ class TestPushdownOperations:
|
||||
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
|
||||
assert len(db._namespace_client_pushdown_operations) == 0
|
||||
|
||||
def test_route_pushdown_to_rust_for_native_rest(self):
    """A rest connection that is built natively has to hand QueryTable
    pushdown over to Rust, so that reads carry the x-lancedb-min-timestamp
    read-freshness header."""
    properties = {"uri": "http://localhost:12345"}
    conn = lancedb.connect_namespace(
        "rest",
        properties,
        namespace_client_pushdown_operations=["QueryTable"],
    )
    assert conn._route_pushdown_to_rust is True
|
||||
|
||||
def test_route_pushdown_to_rust_false_for_dir(self):
    """A dir connection is not native, so it stays on the Python pushdown path."""
    properties = {"root": self.temp_dir}
    conn = lancedb.connect_namespace("dir", properties)
    assert conn._route_pushdown_to_rust is False
|
||||
|
||||
def test_async_route_pushdown_to_rust_for_native_rest(self):
    """Regression test for the async path omitting the freshness header: the
    async connection must not silently bypass the read-freshness fix, so a
    natively-built rest connection defers pushdown to Rust."""
    properties = {"uri": "http://localhost:12345"}
    conn = lancedb.connect_namespace_async(
        "rest",
        properties,
        namespace_client_pushdown_operations=["QueryTable"],
    )
    assert conn._route_pushdown_to_rust is True
|
||||
|
||||
def test_async_route_pushdown_to_rust_false_for_dir(self):
    """An async dir connection is not native, so it stays on the Python
    pushdown path."""
    properties = {"root": self.temp_dir}
    conn = lancedb.connect_namespace_async("dir", properties)
    assert conn._route_pushdown_to_rust is False
|
||||
|
||||
def test_lance_table_to_arrow_uses_query_pushdown(self):
|
||||
namespace_client = _NamespaceClient()
|
||||
table = _namespace_lance_table(namespace_client)
|
||||
|
||||
@@ -18,7 +18,10 @@ use lancedb::{
|
||||
connection::Connection as LanceConnection,
|
||||
connection::NamespaceClientPushdownOperation,
|
||||
database::namespace::LanceNamespaceDatabase,
|
||||
database::{CreateTableMode, Database, ReadConsistency},
|
||||
database::{
|
||||
CreateFunctionRequest, CreateMaterializedViewRequest, CreateTableMode, Database,
|
||||
ReadConsistency, RefreshMaterializedViewRequest, TableLineageRequest,
|
||||
},
|
||||
};
|
||||
use pyo3::{
|
||||
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
@@ -27,6 +30,92 @@ use pyo3::{
|
||||
types::{PyDict, PyDictMethods},
|
||||
};
|
||||
|
||||
/// A registered function, as returned by `list_functions`.
///
/// `get_all` generates a Python getter for every field.
#[pyclass(get_all)]
#[derive(Clone)]
pub struct FunctionInfo {
    pub name: String,
    pub language: String,
    pub return_type: String,
    pub description: String,
}
|
||||
|
||||
/// A registered materialized view definition, as returned by
/// `list_materialized_views`.
///
/// `get_all` generates a Python getter for every field.
#[pyclass(get_all)]
#[derive(Clone)]
pub struct MaterializedViewInfo {
    pub name: String,
    pub source_table: String,
    pub projection: Vec<String>,
    pub udf_columns: Vec<String>,
    // `None` when the definition carries no filter.
    pub filter: Option<String>,
    pub auto_refresh: bool,
}
|
||||
|
||||
/// One inflight server-side job, as returned by `list_jobs` and `get_job`.
///
/// `get_all` generates a Python getter for every field.
#[pyclass(get_all)]
#[derive(Clone)]
pub struct JobInfo {
    pub table: String,
    pub job_id: String,
    pub job_type: String,
    pub state: String,
    pub column: Option<String>,
    pub age_seconds: Option<i64>,
    pub command: Option<String>,
    pub units_done: Option<i64>,
    pub units_total: Option<i64>,
    pub committed: bool,
    pub rows_skipped: u64,
    pub error: Option<String>,
}
|
||||
|
||||
/// One durable, completed/terminal server-side job record (SHOW JOB HISTORY),
/// as returned by `job_history`.
///
/// `get_all` generates a Python getter for every field.
#[pyclass(get_all)]
#[derive(Clone)]
pub struct JobHistoryEntry {
    pub table: String,
    pub job_id: String,
    pub job_type: String,
    pub state: String,
    pub column: Option<String>,
    // NOTE(review): the `_ms` suffix suggests millisecond timestamps -- confirm
    // the epoch/unit against the server API.
    pub created_ms: i64,
    pub updated_ms: i64,
    pub completed_ms: Option<i64>,
    pub rows_processed: Option<i64>,
    pub rows_skipped: Option<i64>,
    pub error: Option<String>,
    pub events: Option<String>,
}
|
||||
|
||||
/// One per-row UDF error recorded by `error_policy=skip` (SHOW ERRORS), as
/// returned by `errors`.
///
/// `get_all` generates a Python getter for every field.
#[pyclass(get_all)]
#[derive(Clone)]
pub struct JobErrorEntry {
    pub job_id: String,
    pub table: String,
    pub column: String,
    pub error_type: String,
    pub error_message: String,
    pub fragment_id: Option<i64>,
    pub source_row_id: Option<i64>,
    pub table_version: Option<i64>,
    pub age_seconds: Option<i64>,
}
|
||||
|
||||
/// The plan a REFRESH MATERIALIZED VIEW would execute (EXPLAIN REFRESH), as
/// returned by `explain_refresh_materialized_view`.
///
/// `get_all` generates a Python getter for every field.
#[pyclass(get_all)]
#[derive(Clone)]
pub struct MvRefreshPlan {
    pub table_name: String,
    pub has_work: bool,
    pub source_version: u64,
    // `None` when no refreshed version is reported.
    pub last_refreshed_version: Option<u64>,
    pub full_refresh: bool,
    pub rebuild: bool,
    pub units_total: u64,
}
|
||||
|
||||
#[pyclass]
|
||||
pub struct Connection {
|
||||
inner: Option<LanceConnection>,
|
||||
@@ -310,6 +399,308 @@ impl Connection {
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, language, return_type, body, options=None))]
|
||||
pub fn create_function(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
language: String,
|
||||
return_type: String,
|
||||
body: String,
|
||||
options: Option<HashMap<String, String>>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.create_function(CreateFunctionRequest {
|
||||
name,
|
||||
language,
|
||||
return_type,
|
||||
body,
|
||||
options: options.unwrap_or_default(),
|
||||
})
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn list_functions(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let functions = inner.list_functions().await.infer_error()?;
|
||||
Ok(functions
|
||||
.into_iter()
|
||||
.map(|f| FunctionInfo {
|
||||
name: f.name,
|
||||
language: f.language,
|
||||
return_type: f.return_type,
|
||||
description: f.description,
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn drop_function(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.drop_function(&name).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, query, auto_refresh=false, with_no_data=false, partition_by=None))]
|
||||
pub fn create_materialized_view(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
query: String,
|
||||
auto_refresh: bool,
|
||||
with_no_data: bool,
|
||||
partition_by: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.create_materialized_view(CreateMaterializedViewRequest {
|
||||
name,
|
||||
query,
|
||||
auto_refresh,
|
||||
with_no_data,
|
||||
partition_by,
|
||||
})
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, full=false, src_version=None, num_workers=None, max_workers=None))]
|
||||
pub fn refresh_materialized_view(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
full: bool,
|
||||
src_version: Option<u64>,
|
||||
num_workers: Option<u32>,
|
||||
max_workers: Option<u32>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.refresh_materialized_view(RefreshMaterializedViewRequest {
|
||||
name,
|
||||
full,
|
||||
src_version,
|
||||
num_workers,
|
||||
max_workers,
|
||||
})
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
/// Derived-compute lineage of a table/view (or column), returned as the
|
||||
/// server's lineage JSON string (the Python layer parses it).
|
||||
pub fn table_lineage(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
column: Option<String>,
|
||||
direction: Option<String>,
|
||||
depth: Option<u32>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.table_lineage(TableLineageRequest {
|
||||
name,
|
||||
column,
|
||||
direction,
|
||||
depth,
|
||||
})
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, full=false, src_version=None))]
|
||||
pub fn explain_refresh_materialized_view(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
full: bool,
|
||||
src_version: Option<u64>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let p = inner
|
||||
.explain_refresh_materialized_view(&name, full, src_version)
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(MvRefreshPlan {
|
||||
table_name: p.table_name,
|
||||
has_work: p.has_work,
|
||||
source_version: p.source_version,
|
||||
last_refreshed_version: p.last_refreshed_version,
|
||||
full_refresh: p.full_refresh,
|
||||
rebuild: p.rebuild,
|
||||
units_total: p.units_total,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn alter_materialized_view(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
auto_refresh: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.alter_materialized_view(&name, auto_refresh)
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn drop_materialized_view(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.drop_materialized_view(&name).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn list_materialized_views(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let views = inner.list_materialized_views().await.infer_error()?;
|
||||
Ok(views
|
||||
.into_iter()
|
||||
.map(|v| MaterializedViewInfo {
|
||||
name: v.name,
|
||||
source_table: v.source_table,
|
||||
projection: v.projection,
|
||||
udf_columns: v.udf_columns,
|
||||
filter: v.filter,
|
||||
auto_refresh: v.auto_refresh,
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn list_jobs(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let jobs = inner.list_jobs().await.infer_error()?;
|
||||
Ok(jobs
|
||||
.into_iter()
|
||||
.map(|j| JobInfo {
|
||||
table: j.table,
|
||||
job_id: j.job_id,
|
||||
job_type: j.job_type,
|
||||
state: j.state,
|
||||
column: j.column,
|
||||
age_seconds: j.age_seconds,
|
||||
command: j.command,
|
||||
units_done: j.units_done,
|
||||
units_total: j.units_total,
|
||||
committed: j.committed,
|
||||
rows_skipped: j.rows_skipped,
|
||||
error: j.error,
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn cancel_job(self_: PyRef<'_, Self>, job_id: String) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.cancel_job(&job_id).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (job_id, table=None))]
|
||||
pub fn get_job(
|
||||
self_: PyRef<'_, Self>,
|
||||
job_id: String,
|
||||
table: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let job = inner
|
||||
.get_job(&job_id, table.as_deref())
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(job.map(|j| JobInfo {
|
||||
table: j.table,
|
||||
job_id: j.job_id,
|
||||
job_type: j.job_type,
|
||||
state: j.state,
|
||||
column: j.column,
|
||||
age_seconds: j.age_seconds,
|
||||
command: j.command,
|
||||
units_done: j.units_done,
|
||||
units_total: j.units_total,
|
||||
committed: j.committed,
|
||||
rows_skipped: j.rows_skipped,
|
||||
error: j.error,
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (job_id=None))]
|
||||
pub fn job_history(
|
||||
self_: PyRef<'_, Self>,
|
||||
job_id: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let rows = inner.job_history(job_id.as_deref()).await.infer_error()?;
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| JobHistoryEntry {
|
||||
table: r.table,
|
||||
job_id: r.job_id,
|
||||
job_type: r.job_type,
|
||||
state: r.state,
|
||||
column: r.column,
|
||||
created_ms: r.created_ms,
|
||||
updated_ms: r.updated_ms,
|
||||
completed_ms: r.completed_ms,
|
||||
rows_processed: r.rows_processed,
|
||||
rows_skipped: r.rows_skipped,
|
||||
error: r.error,
|
||||
events: r.events,
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (job_id=None, table=None))]
|
||||
pub fn errors(
|
||||
self_: PyRef<'_, Self>,
|
||||
job_id: Option<String>,
|
||||
table: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let rows = inner
|
||||
.errors(job_id.as_deref(), table.as_deref())
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|e| JobErrorEntry {
|
||||
job_id: e.job_id,
|
||||
table: e.table,
|
||||
column: e.column,
|
||||
error_type: e.error_type,
|
||||
error_message: e.error_message,
|
||||
fragment_id: e.fragment_id,
|
||||
source_row_id: e.source_row_id,
|
||||
table_version: e.table_version,
|
||||
age_seconds: e.age_seconds,
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (cur_name, new_name, cur_namespace_path=None, new_namespace_path=None))]
|
||||
pub fn rename_table(
|
||||
self_: PyRef<'_, Self>,
|
||||
@@ -539,7 +930,7 @@ impl Connection {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None))]
|
||||
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None, oauth_config=None))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn connect(
|
||||
py: Python<'_>,
|
||||
@@ -553,6 +944,7 @@ pub fn connect(
|
||||
session: Option<crate::session::Session>,
|
||||
manifest_enabled: bool,
|
||||
namespace_client_properties: Option<HashMap<String, String>>,
|
||||
oauth_config: Option<crate::oauth::PyOAuthConfig>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
future_into_py(py, async move {
|
||||
let mut builder = lancedb::connect(&uri);
|
||||
@@ -582,6 +974,11 @@ pub fn connect(
|
||||
if let Some(client_config) = client_config {
|
||||
builder = builder.client_config(client_config.into());
|
||||
}
|
||||
if let Some(oauth_config) = oauth_config {
|
||||
let config: lancedb::remote::oauth::OAuthConfig =
|
||||
oauth_config.try_into().infer_error()?;
|
||||
builder = builder.oauth_config(config);
|
||||
}
|
||||
if let Some(session) = session {
|
||||
builder = builder.session(session.inner.clone());
|
||||
}
|
||||
@@ -610,24 +1007,38 @@ pub fn connect_namespace_client(
|
||||
namespace_client_impl: Option<String>,
|
||||
namespace_client_properties: Option<HashMap<String, String>>,
|
||||
) -> PyResult<Connection> {
|
||||
let namespace_client = extract_namespace_arc(py, namespace_client)?;
|
||||
let read_consistency_interval = read_consistency_interval.map(Duration::from_secs_f64);
|
||||
let namespace_client_pushdown_operations =
|
||||
parse_namespace_client_pushdown_operations(namespace_client_pushdown_operations)?;
|
||||
let ns_impl = namespace_client_impl.unwrap_or_else(|| "python".to_string());
|
||||
let ns_properties = namespace_client_properties.unwrap_or_default();
|
||||
let storage_options = storage_options.unwrap_or_default();
|
||||
let session = session.map(|s| s.inner.clone());
|
||||
|
||||
let database = LanceNamespaceDatabase::from_namespace_client(
|
||||
namespace_client,
|
||||
ns_impl,
|
||||
ns_properties,
|
||||
storage_options,
|
||||
read_consistency_interval,
|
||||
session,
|
||||
namespace_client_pushdown_operations,
|
||||
);
|
||||
// Prefer building the namespace natively from (impl, properties) so the
|
||||
// read-freshness provider installed
|
||||
let database = if build_namespace_natively(namespace_client_impl.as_deref(), &ns_properties) {
|
||||
let ns_impl = namespace_client_impl.expect("impl present per build_namespace_natively");
|
||||
crate::runtime::block_on(LanceNamespaceDatabase::connect(
|
||||
&ns_impl,
|
||||
ns_properties,
|
||||
storage_options,
|
||||
read_consistency_interval,
|
||||
session,
|
||||
namespace_client_pushdown_operations,
|
||||
))
|
||||
.infer_error()?
|
||||
} else {
|
||||
let namespace_client = extract_namespace_arc(py, namespace_client)?;
|
||||
LanceNamespaceDatabase::from_namespace_client(
|
||||
namespace_client,
|
||||
namespace_client_impl.unwrap_or_else(|| "python".to_string()),
|
||||
ns_properties,
|
||||
storage_options,
|
||||
read_consistency_interval,
|
||||
session,
|
||||
namespace_client_pushdown_operations,
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Connection::new(LanceConnection::new(
|
||||
Arc::new(database),
|
||||
@@ -635,6 +1046,16 @@ pub fn connect_namespace_client(
|
||||
)))
|
||||
}
|
||||
|
||||
/// Decides whether the namespace should be constructed natively from
/// `(impl, properties)` rather than by wrapping a client that was already
/// built. Native construction is what allows the read-freshness provider to
/// be installed.
///
/// Returns `true` only when the impl is exactly `"rest"` and at least one
/// property is present; every other combination returns `false`.
fn build_namespace_natively(
    namespace_client_impl: Option<&str>,
    namespace_client_properties: &HashMap<String, String>,
) -> bool {
    // Without properties there is nothing to build a connection from.
    if namespace_client_properties.is_empty() {
        return false;
    }
    namespace_client_impl == Some("rest")
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
pub struct PyClientConfig {
|
||||
user_agent: String,
|
||||
@@ -733,3 +1154,36 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns borrowed key/value pairs into an owned property map.
    fn props(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        let mut map = HashMap::with_capacity(pairs.len());
        for &(key, value) in pairs {
            map.insert(key.to_string(), value.to_string());
        }
        map
    }

    #[test]
    fn native_build_only_for_rest_with_properties() {
        let rest_props = props(&[("uri", "http://localhost:10024")]);
        let dir_props = props(&[("root", "/tmp")]);
        let no_props = HashMap::new();

        // rest + non-empty properties -> build natively (installs the
        // read-freshness provider so checkout_latest() busts the server cache).
        let rest_with_props = build_namespace_natively(Some("rest"), &rest_props);
        assert!(rest_with_props);

        // dir is local (no server cache) -> wrap the pre-built client unchanged.
        let dir_with_props = build_namespace_natively(Some("dir"), &dir_props);
        assert!(!dir_with_props);

        // No impl: only a pre-built client was handed in -> wrap it as-is.
        let missing_impl = build_namespace_natively(None, &rest_props);
        assert!(!missing_impl);

        // rest but no properties: nothing to build a connection from -> wrap.
        let rest_without_props = build_namespace_natively(Some("rest"), &no_props);
        assert!(!rest_without_props);
    }
}
|
||||
|
||||
@@ -26,6 +26,7 @@ pub mod expr;
|
||||
pub mod header;
|
||||
pub mod index;
|
||||
pub mod namespace;
|
||||
pub mod oauth;
|
||||
pub mod permutation;
|
||||
pub mod query;
|
||||
pub mod runtime;
|
||||
@@ -40,6 +41,11 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
.write_style("LANCEDB_LOG_STYLE");
|
||||
env_logger::init_from_env(env);
|
||||
m.add_class::<Connection>()?;
|
||||
m.add_class::<connection::FunctionInfo>()?;
|
||||
m.add_class::<connection::MaterializedViewInfo>()?;
|
||||
m.add_class::<connection::JobInfo>()?;
|
||||
m.add_class::<connection::JobHistoryEntry>()?;
|
||||
m.add_class::<connection::JobErrorEntry>()?;
|
||||
m.add_class::<Session>()?;
|
||||
m.add_class::<Table>()?;
|
||||
m.add_class::<IndexConfig>()?;
|
||||
|
||||
72
python/src/oauth.rs
Normal file
72
python/src/oauth.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use pyo3::FromPyObject;
|
||||
|
||||
use lancedb::error::Error;
|
||||
use lancedb::remote::oauth::{OAuthConfig, OAuthFlow};
|
||||
|
||||
/// Python-side OAuth configuration, extracted via FromPyObject.
|
||||
/// Maps to `lancedb.remote.oauth.OAuthConfig` Python dataclass.
|
||||
#[derive(FromPyObject)]
|
||||
pub struct PyOAuthConfig {
|
||||
pub issuer_url: String,
|
||||
pub client_id: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub flow: String,
|
||||
pub client_secret: Option<String>,
|
||||
pub managed_identity_client_id: Option<String>,
|
||||
pub refresh_buffer_secs: Option<u64>,
|
||||
}
|
||||
|
||||
impl TryFrom<PyOAuthConfig> for OAuthConfig {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(py: PyOAuthConfig) -> Result<Self, Self::Error> {
|
||||
let flow = match py.flow.as_str() {
|
||||
"client_credentials" => OAuthFlow::ClientCredentials,
|
||||
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
|
||||
client_id: py.managed_identity_client_id,
|
||||
},
|
||||
other => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!("Unknown OAuth flow type: {other}"),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
issuer_url: py.issuer_url,
|
||||
client_id: py.client_id,
|
||||
client_secret: py.client_secret,
|
||||
scopes: py.scopes,
|
||||
flow,
|
||||
refresh_buffer_secs: py.refresh_buffer_secs,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// An unrecognised `flow` string must surface as `InvalidInput` carrying
    /// the offending value in its message.
    #[test]
    fn test_unknown_oauth_flow_returns_invalid_input() {
        let issuer_url = String::from("https://issuer.example.com");
        let client_id = String::from("client-id");
        let scopes = vec![String::from("scope")];
        let flow = String::from("typo");

        let config = PyOAuthConfig {
            issuer_url,
            client_id,
            scopes,
            flow,
            client_secret: None,
            managed_identity_client_id: None,
            refresh_buffer_secs: None,
        };

        let err = OAuthConfig::try_from(config).unwrap_err();
        let is_expected = match err {
            Error::InvalidInput { message } => message == "Unknown OAuth flow type: typo",
            _ => false,
        };
        assert!(is_expected);
    }
}
|
||||
@@ -56,6 +56,15 @@ fn get_runtime() -> &'static runtime::Runtime {
|
||||
unsafe { &*new_ptr }
|
||||
}
|
||||
|
||||
/// Block the current thread on a future using the shared runtime.
|
||||
///
|
||||
/// For sync `#[pyfunction]`s that need to drive an async operation (e.g.
|
||||
/// building a namespace client). Must not be called from within the runtime's
|
||||
/// own worker threads.
|
||||
pub fn block_on<F: std::future::Future>(fut: F) -> F::Output {
|
||||
get_runtime().block_on(fut)
|
||||
}
|
||||
|
||||
/// Runs in async-signal context after `fork()` in the child. We can only
|
||||
/// touch atomics here; we deliberately leak the previous runtime because
|
||||
/// dropping a tokio `Runtime` would try to join its (now-dead) worker
|
||||
|
||||
@@ -17,8 +17,8 @@ use arrow::{
|
||||
pyarrow::{FromPyArrow, PyArrowType, ToPyArrow},
|
||||
};
|
||||
use lancedb::table::{
|
||||
AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, NewColumnTransform,
|
||||
OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable,
|
||||
AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, LoadColumnsRequest,
|
||||
NewColumnTransform, OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable,
|
||||
};
|
||||
use pyo3::{
|
||||
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
@@ -1060,6 +1060,83 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_computed_columns(
|
||||
self_: PyRef<'_, Self>,
|
||||
columns: Vec<(String, String)>,
|
||||
expression: String,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.add_computed_columns(&columns, &expression)
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (columns, where_clause=None, num_workers=None, max_workers=None, batch_size=None, priority=None))]
|
||||
pub fn refresh_column(
|
||||
self_: PyRef<'_, Self>,
|
||||
columns: Vec<String>,
|
||||
where_clause: Option<String>,
|
||||
num_workers: Option<u32>,
|
||||
max_workers: Option<u32>,
|
||||
batch_size: Option<u32>,
|
||||
priority: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.refresh_column(
|
||||
&columns,
|
||||
where_clause,
|
||||
num_workers,
|
||||
max_workers,
|
||||
batch_size,
|
||||
priority,
|
||||
)
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[pyo3(signature = (source_uris, source_format, target_key, columns, source_key=None, source_storage_options=None, on_missing=None, num_workers=None, max_workers=None, batch_size=None, commit_granularity=None, priority=None))]
|
||||
pub fn load_columns(
|
||||
self_: PyRef<'_, Self>,
|
||||
source_uris: Vec<String>,
|
||||
source_format: String,
|
||||
target_key: String,
|
||||
columns: Vec<(String, Option<String>)>,
|
||||
source_key: Option<String>,
|
||||
source_storage_options: Option<std::collections::HashMap<String, String>>,
|
||||
on_missing: Option<String>,
|
||||
num_workers: Option<u32>,
|
||||
max_workers: Option<u32>,
|
||||
batch_size: Option<u32>,
|
||||
commit_granularity: Option<u32>,
|
||||
priority: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
let request = LoadColumnsRequest {
|
||||
source_uris,
|
||||
source_format,
|
||||
source_storage_options,
|
||||
target_key,
|
||||
source_key,
|
||||
columns,
|
||||
on_missing,
|
||||
num_workers,
|
||||
max_workers,
|
||||
batch_size,
|
||||
commit_granularity,
|
||||
priority,
|
||||
};
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.load_columns(request).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_columns(
|
||||
self_: PyRef<'_, Self>,
|
||||
definitions: Vec<(String, String)>,
|
||||
|
||||
33
python/tests/test_oauth.py
Normal file
33
python/tests/test_oauth.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_oauth_module():
    """Load ``lancedb/remote/oauth.py`` directly from its source file.

    The file is executed through ``importlib`` under the standalone module
    name ``lancedb_remote_oauth`` instead of being imported via the
    ``lancedb`` package.

    Returns:
        The executed module object.

    Raises:
        AssertionError: if no import spec or loader could be created for
            the source file.
    """
    oauth_path = (
        Path(__file__).parents[1] / "python" / "lancedb" / "remote" / "oauth.py"
    )
    spec = importlib.util.spec_from_file_location("lancedb_remote_oauth", oauth_path)
    # Check the spec before it is used: module_from_spec(None) would otherwise
    # fail with an unrelated AttributeError.
    assert spec is not None and spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Register before executing so code that looks the module up by name
    # while it is being executed can find it.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind for later lookups.
        sys.modules.pop(spec.name, None)
        raise
    return module
|
||||
|
||||
|
||||
def test_oauth_config_repr_redacts_client_secret():
    """``repr(OAuthConfig)`` must expose neither the secret nor its field name."""
    oauth_module = _load_oauth_module()

    settings = {
        "issuer_url": "https://issuer.example.com",
        "client_id": "client-id",
        "scopes": ["scope"],
        "client_secret": "super-secret",
    }
    rendered = repr(oauth_module.OAuthConfig(**settings))

    for forbidden in ("super-secret", "client_secret"):
        assert forbidden not in rendered
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.31.0-beta.1"
|
||||
version = "0.31.0-beta.4"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
@@ -50,7 +50,7 @@ lance-namespace = { workspace = true }
|
||||
lance-namespace-impls = { workspace = true }
|
||||
moka = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread", "sync"] }
|
||||
log.workspace = true
|
||||
async-trait = "0"
|
||||
bytes = "1"
|
||||
@@ -75,6 +75,7 @@ reqwest = { version = "0.12.0", default-features = false, features = [
|
||||
"stream",
|
||||
], optional = true }
|
||||
http = { version = "1", optional = true } # Matching what is in reqwest
|
||||
urlencoding = { version = "2", optional = true }
|
||||
uuid = { version = "1.7.0", features = ["v4", "v5"] }
|
||||
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
|
||||
polars = { version = ">=0.37,<0.40.0", optional = true }
|
||||
@@ -93,6 +94,7 @@ semver = { workspace = true }
|
||||
anyhow = "1"
|
||||
tempfile = "3.5.0"
|
||||
random_word = { version = "0.4.3", features = ["en"] }
|
||||
tokio = { version = "1.23", features = ["io-util", "macros", "net", "rt-multi-thread", "sync"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
walkdir = "2"
|
||||
aws-sdk-dynamodb = { version = "1.55.0" }
|
||||
@@ -129,7 +131,13 @@ huggingface = [
|
||||
"lance-namespace-impls/dir-huggingface",
|
||||
]
|
||||
dynamodb = ["lance/dynamodb", "aws"]
|
||||
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
|
||||
remote = [
|
||||
"dep:reqwest",
|
||||
"dep:http",
|
||||
"dep:urlencoding",
|
||||
"lance-namespace-impls/rest",
|
||||
"lance-namespace-impls/rest-adapter",
|
||||
]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
s3-test = []
|
||||
bedrock = ["dep:aws-sdk-bedrockruntime"]
|
||||
|
||||
@@ -23,8 +23,10 @@ use crate::connection::create_table::CreateTableBuilder;
|
||||
use crate::data::scannable::Scannable;
|
||||
use crate::database::listing::ListingDatabase;
|
||||
use crate::database::{
|
||||
CloneTableRequest, Database, DatabaseOptions, OpenTableRequest, ReadConsistency,
|
||||
TableNamesRequest,
|
||||
CloneTableRequest, CreateFunctionRequest, CreateMaterializedViewRequest, Database,
|
||||
DatabaseOptions, FunctionInfo, JobErrorInfo, JobHistoryInfo, JobInfo, MaterializedViewInfo,
|
||||
MvRefreshPlan, OpenTableRequest, ReadConsistency, RefreshMaterializedViewRequest,
|
||||
TableLineageRequest, TableNamesRequest,
|
||||
};
|
||||
use crate::embeddings::{EmbeddingRegistry, MemoryRegistry};
|
||||
use crate::error::{Error, Result};
|
||||
@@ -488,6 +490,113 @@ impl Connection {
|
||||
)
|
||||
}
|
||||
|
||||
// -- Derived compute: functions, materialized views, jobs -------------
|
||||
// Server-backed features (LanceDB Enterprise / Cloud); local
|
||||
// databases return NotSupported for now.
|
||||
|
||||
/// Register a UDF (CREATE FUNCTION).
|
||||
pub async fn create_function(&self, request: CreateFunctionRequest) -> Result<()> {
|
||||
self.internal.create_function(request).await
|
||||
}
|
||||
|
||||
/// List registered functions (SHOW FUNCTIONS).
|
||||
pub async fn list_functions(&self) -> Result<Vec<FunctionInfo>> {
|
||||
self.internal.list_functions().await
|
||||
}
|
||||
|
||||
/// Drop a registered function (DROP FUNCTION).
|
||||
pub async fn drop_function(&self, name: &str) -> Result<()> {
|
||||
self.internal.drop_function(name).await
|
||||
}
|
||||
|
||||
/// Create a materialized view (CREATE MATERIALIZED VIEW). Returns
|
||||
/// the initial-population job id, absent when `with_no_data`.
|
||||
pub async fn create_materialized_view(
|
||||
&self,
|
||||
request: CreateMaterializedViewRequest,
|
||||
) -> Result<Option<String>> {
|
||||
self.internal.create_materialized_view(request).await
|
||||
}
|
||||
|
||||
/// Refresh a materialized view; returns the refresh job id.
|
||||
pub async fn refresh_materialized_view(
|
||||
&self,
|
||||
request: RefreshMaterializedViewRequest,
|
||||
) -> Result<String> {
|
||||
self.internal.refresh_materialized_view(request).await
|
||||
}
|
||||
|
||||
/// Derived-compute lineage of a table/view (or column), as server-defined
|
||||
/// JSON. Read-only.
|
||||
pub async fn table_lineage(&self, request: TableLineageRequest) -> Result<String> {
|
||||
self.internal.table_lineage(request).await
|
||||
}
|
||||
|
||||
/// Plan a materialized-view refresh without submitting work
|
||||
/// (EXPLAIN REFRESH).
|
||||
pub async fn explain_refresh_materialized_view(
|
||||
&self,
|
||||
name: &str,
|
||||
full: bool,
|
||||
src_version: Option<u64>,
|
||||
) -> Result<MvRefreshPlan> {
|
||||
self.internal
|
||||
.explain_refresh_materialized_view(name, full, src_version)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Update a materialized view's options (ALTER MATERIALIZED VIEW).
|
||||
pub async fn alter_materialized_view(&self, name: &str, auto_refresh: bool) -> Result<()> {
|
||||
self.internal
|
||||
.alter_materialized_view(name, auto_refresh)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Drop a materialized view definition (DROP MATERIALIZED VIEW).
|
||||
pub async fn drop_materialized_view(&self, name: &str) -> Result<()> {
|
||||
self.internal.drop_materialized_view(name).await
|
||||
}
|
||||
|
||||
/// List registered materialized view definitions.
|
||||
pub async fn list_materialized_views(&self) -> Result<Vec<MaterializedViewInfo>> {
|
||||
self.internal.list_materialized_views().await
|
||||
}
|
||||
|
||||
/// List inflight server-side jobs across the database's tables.
|
||||
pub async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
|
||||
self.internal.list_jobs().await
|
||||
}
|
||||
|
||||
/// Cancel an inflight server-side job by id. Returns true if a
|
||||
/// matching inflight job was flagged for cancellation.
|
||||
pub async fn cancel_job(&self, job_id: &str) -> Result<bool> {
|
||||
self.internal.cancel_job(job_id).await
|
||||
}
|
||||
|
||||
/// Look up a single server-side job by id -- the `wait()`/status poll path.
|
||||
/// `table_hint` (the job's table) enables an O(1) server-side lookup; `None`
|
||||
/// scans the database's active jobs. A `None` result means unknown / not
|
||||
/// active.
|
||||
pub async fn get_job(&self, job_id: &str, table_hint: Option<&str>) -> Result<Option<JobInfo>> {
|
||||
self.internal.get_job(job_id, table_hint).await
|
||||
}
|
||||
|
||||
/// Durable job history (SHOW JOB HISTORY) across the database's tables.
|
||||
/// Pass `job_id` to narrow to a single job.
|
||||
pub async fn job_history(&self, job_id: Option<&str>) -> Result<Vec<JobHistoryInfo>> {
|
||||
self.internal.job_history(job_id).await
|
||||
}
|
||||
|
||||
/// Per-row UDF errors (SHOW ERRORS) across the database's tables, optionally
|
||||
/// filtered by `job_id` and/or `table`.
|
||||
pub async fn errors(
|
||||
&self,
|
||||
job_id: Option<&str>,
|
||||
table: Option<&str>,
|
||||
) -> Result<Vec<JobErrorInfo>> {
|
||||
self.internal.errors(job_id, table).await
|
||||
}
|
||||
|
||||
/// Rename a table in the database.
|
||||
///
|
||||
/// This is only supported in LanceDB Cloud.
|
||||
@@ -576,6 +685,9 @@ impl Connection {
|
||||
/// For LanceNamespaceDatabase, it is the underlying LanceNamespace.
|
||||
/// For ListingDatabase, it is the equivalent DirectoryNamespace.
|
||||
/// For RemoteDatabase, it is the equivalent RestNamespace.
|
||||
///
|
||||
/// Remote connections using dynamic headers forward them through the
|
||||
/// namespace client's per-request context provider.
|
||||
pub async fn namespace_client(&self) -> Result<Arc<dyn lance_namespace::LanceNamespace>> {
|
||||
self.internal.namespace_client().await
|
||||
}
|
||||
@@ -584,6 +696,9 @@ impl Connection {
|
||||
/// Returns (impl_type, properties) where:
|
||||
/// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace
|
||||
/// - properties: configuration properties for the namespace
|
||||
///
|
||||
/// Remote connections using dynamic headers cannot be exported because the
|
||||
/// namespace client config only carries static headers.
|
||||
pub async fn namespace_client_config(
|
||||
&self,
|
||||
) -> Result<(String, std::collections::HashMap<String, String>)> {
|
||||
@@ -661,6 +776,8 @@ pub struct ConnectRequest {
|
||||
pub struct ConnectBuilder {
|
||||
request: ConnectRequest,
|
||||
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
|
||||
#[cfg(feature = "remote")]
|
||||
oauth_config: Option<crate::remote::OAuthConfig>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
@@ -682,6 +799,8 @@ impl ConnectBuilder {
|
||||
session: None,
|
||||
},
|
||||
embedding_registry: None,
|
||||
#[cfg(feature = "remote")]
|
||||
oauth_config: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -770,6 +889,19 @@ impl ConnectBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Enables OAuth authentication for LanceDB Cloud/Enterprise.
///
/// The config is stored on the builder. When the connection is executed, an
/// [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider) is created
/// from it and installed as the header provider. OAuth cannot be combined
/// with an API key or with another header provider.
///
/// Token acquisition and refresh are handled in Rust.
#[cfg(feature = "remote")]
pub fn oauth_config(mut self, config: crate::remote::OAuthConfig) -> Self {
    let configured = Some(config);
    self.oauth_config = configured;
    self
}
|
||||
|
||||
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
|
||||
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
|
||||
self.embedding_registry = Some(registry);
|
||||
@@ -915,9 +1047,40 @@ impl ConnectBuilder {
|
||||
let region = options.region.ok_or_else(|| Error::InvalidInput {
|
||||
message: "A region is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
let api_key = options.api_key.ok_or_else(|| Error::InvalidInput {
|
||||
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
let api_key = match (&self.oauth_config, &options.api_key) {
|
||||
(Some(_), None) => String::new(),
|
||||
(Some(_), Some(_)) => {
|
||||
return Err(Error::InvalidInput {
|
||||
message:
|
||||
"api_key and oauth_config cannot both be set when connecting to LanceDb Cloud"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
(None, Some(key)) => key.clone(),
|
||||
(None, None) => {
|
||||
return Err(Error::InvalidInput {
|
||||
message:
|
||||
"An api_key or oauth_config is required when connecting to LanceDb Cloud"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if self.oauth_config.is_some() && self.request.client_config.header_provider.is_some() {
|
||||
return Err(Error::InvalidInput {
|
||||
message:
|
||||
"oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut client_config = self.request.client_config;
|
||||
|
||||
if let Some(oauth_config) = self.oauth_config {
|
||||
let provider = crate::remote::OAuthHeaderProvider::new(oauth_config)?;
|
||||
client_config.header_provider =
|
||||
Some(Arc::new(provider) as Arc<dyn crate::remote::HeaderProvider>);
|
||||
}
|
||||
|
||||
let storage_options = StorageOptions(options.storage_options.clone());
|
||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
|
||||
@@ -925,7 +1088,7 @@ impl ConnectBuilder {
|
||||
&api_key,
|
||||
®ion,
|
||||
options.host_override,
|
||||
self.request.client_config,
|
||||
client_config,
|
||||
storage_options.into(),
|
||||
self.request.read_consistency_interval,
|
||||
)?);
|
||||
@@ -1234,6 +1397,83 @@ mod tests {
|
||||
assert_eq!(Some(&"EXPLICIT-VALUE".to_string()), options.get(opts_key));
|
||||
}
|
||||
|
||||
/// Supplying both an API key and an OAuth config must be rejected with
/// `InvalidInput`.
#[cfg(feature = "remote")]
#[tokio::test]
async fn test_connect_rejects_api_key_with_oauth_config() {
    let builder = ConnectBuilder::new("db://my-container/my-prefix")
        .region("us-east-1")
        .api_key("my-api-key")
        .oauth_config(crate::remote::OAuthConfig {
            issuer_url: "https://issuer.example.com".to_string(),
            client_id: "client-id".to_string(),
            client_secret: Some("secret".to_string()),
            scopes: vec!["scope".to_string()],
            flow: crate::remote::OAuthFlow::ClientCredentials,
            refresh_buffer_secs: None,
        });

    // A successful connection is a failure of this test.
    let err = match builder.execute().await {
        Ok(_) => panic!("expected api_key and oauth_config to be rejected"),
        Err(err) => err,
    };

    match err {
        Error::InvalidInput { message }
            if message
                == "api_key and oauth_config cannot both be set when connecting to LanceDb Cloud" => {}
        other => panic!("expected InvalidInput, got {other:?}"),
    }
}
|
||||
|
||||
/// Supplying both a custom header provider and an OAuth config must be
/// rejected with `InvalidInput`.
#[cfg(feature = "remote")]
#[tokio::test]
async fn test_connect_rejects_header_provider_with_oauth_config() {
    /// Minimal provider that always yields one static bearer header.
    #[derive(Debug)]
    struct TestHeaderProvider;

    #[async_trait::async_trait]
    impl crate::remote::HeaderProvider for TestHeaderProvider {
        async fn get_headers(&self) -> Result<HashMap<String, String>> {
            let mut headers = HashMap::new();
            headers.insert("authorization".to_string(), "Bearer token".to_string());
            Ok(headers)
        }
    }

    let provider: Arc<dyn crate::remote::HeaderProvider> = Arc::new(TestHeaderProvider);
    let client_config = crate::remote::ClientConfig {
        header_provider: Some(provider),
        ..Default::default()
    };
    let oauth_config = crate::remote::OAuthConfig {
        issuer_url: "https://issuer.example.com".to_string(),
        client_id: "client-id".to_string(),
        client_secret: Some("secret".to_string()),
        scopes: vec!["scope".to_string()],
        flow: crate::remote::OAuthFlow::ClientCredentials,
        refresh_buffer_secs: None,
    };

    let builder = ConnectBuilder::new("db://my-container/my-prefix")
        .region("us-east-1")
        .client_config(client_config)
        .oauth_config(oauth_config);

    // A successful connection is a failure of this test.
    let err = match builder.execute().await {
        Ok(_) => panic!("expected header_provider and oauth_config to be rejected"),
        Err(err) => err,
    };

    match err {
        Error::InvalidInput { message }
            if message
                == "oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud" => {}
        other => panic!("expected InvalidInput, got {other:?}"),
    }
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn test_connect_relative() {
|
||||
|
||||
@@ -27,7 +27,7 @@ use lance_namespace::models::{
|
||||
};
|
||||
|
||||
use crate::data::scannable::Scannable;
|
||||
use crate::error::Result;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::table::{BaseTable, WriteOptions};
|
||||
|
||||
pub mod listing;
|
||||
@@ -200,6 +200,205 @@ pub enum ReadConsistency {
|
||||
Strong,
|
||||
}
|
||||
|
||||
/// Request payload for registering a UDF (CREATE FUNCTION).
///
/// Functions are first-class database objects and are not tied to any
/// column; computed columns and materialized views refer to them by name.
/// Server-backed feature (LanceDB Enterprise / Cloud).
#[derive(Debug, Clone)]
pub struct CreateFunctionRequest {
    /// Name of the function.
    pub name: String,
    /// Language the function is implemented in (currently "python").
    pub language: String,
    /// SQL return type, for example `FLOAT`, `FLOAT[1536]`,
    /// `STRUCT(a FLOAT, b VARCHAR)` or `TABLE(chunk VARCHAR, idx INT)`.
    pub return_type: String,
    /// Body of the function: either source text, or base64 cloudpickle
    /// bytes when `options["body_format"] = "cloudpickle"`.
    pub body: String,
    /// Free-form options such as input_columns, pip, num_gpus, batch_size,
    /// timeout, error_policy, docker_image, body_format, ...
    pub options: HashMap<String, String>,
}
|
||||
|
||||
/// One registered function, in the shape `list_functions` returns it.
#[derive(Debug, Clone)]
pub struct FunctionInfo {
    /// Name of the function.
    pub name: String,
    /// Language the function is implemented in.
    pub language: String,
    /// Declared return type.
    pub return_type: String,
    /// Description text for the function.
    pub description: String,
}
|
||||
|
||||
/// Request payload for creating a materialized view
/// (CREATE MATERIALIZED VIEW).
#[derive(Debug, Clone)]
pub struct CreateMaterializedViewRequest {
    /// Name of the view.
    pub name: String,
    /// SELECT statement that defines the view, for example
    /// `SELECT id, embed(body) AS vec FROM articles WHERE id > 1`.
    /// Bare columns are projected through, while function-call columns are
    /// computed via registered UDFs (a RETURNS TABLE function makes a
    /// row-expanding chunker view).
    pub query: String,
    /// Refresh automatically whenever the source table changes.
    pub auto_refresh: bool,
    /// Only register the definition and skip the initial population.
    pub with_no_data: bool,
    /// Optional source column to partition the view's table function on. If
    /// the column has an IVF vector index the server partitions by its
    /// clusters (image-dedup style); otherwise it groups by distinct value.
    pub partition_by: Option<String>,
}

impl CreateMaterializedViewRequest {
    /// Builds a request for `name` defined by `query`, with `auto_refresh`
    /// and `with_no_data` switched off and no `partition_by` column.
    pub fn new(name: impl Into<String>, query: impl Into<String>) -> Self {
        let name = name.into();
        let query = query.into();
        Self {
            name,
            query,
            auto_refresh: false,
            with_no_data: false,
            partition_by: None,
        }
    }
}
|
||||
|
||||
/// Request payload for refreshing a materialized view.
#[derive(Debug, Clone)]
pub struct RefreshMaterializedViewRequest {
    /// Name of the view.
    pub name: String,
    /// Force a full rebuild (recompute and replace every row) rather than
    /// the default incremental refresh.
    pub full: bool,
    /// Source-table version to pin the refresh to; latest when absent.
    pub src_version: Option<u64>,
    /// Initial worker count.
    pub num_workers: Option<u32>,
    /// Elastic worker ceiling.
    pub max_workers: Option<u32>,
}

impl RefreshMaterializedViewRequest {
    /// Builds an incremental refresh request for `name` with no pinned
    /// source version and no worker settings.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            full: false,
            src_version: None,
            num_workers: None,
            max_workers: None,
        }
    }
}

/// Request payload for the derived-compute lineage of a table/view (or of
/// one of its columns). The response is server-defined lineage JSON and is
/// returned opaque, so this client does not need to model the server's
/// lineage schema.
#[derive(Debug, Clone, Default)]
pub struct TableLineageRequest {
    /// Name of the table or view.
    pub name: String,
    /// Column for column-level lineage; the whole table/view when absent.
    pub column: Option<String>,
    /// "upstream" | "downstream" | "both" (server default when absent).
    pub direction: Option<String>,
    /// Number of column-hops to walk; transitive when absent.
    pub depth: Option<u32>,
}
|
||||
|
||||
/// One registered materialized view definition, in the shape
/// `list_materialized_views` returns it.
#[derive(Debug, Clone)]
pub struct MaterializedViewInfo {
    /// Name of the view.
    pub name: String,
    /// Name of the source table.
    pub source_table: String,
    /// Source columns that are projected through.
    pub projection: Vec<String>,
    /// One `alias=expression` entry per UDF-computed column.
    pub udf_columns: Vec<String>,
    /// Filter of the view, if any.
    pub filter: Option<String>,
    /// Value of the view's `auto_refresh` option.
    pub auto_refresh: bool,
}
|
||||
|
||||
/// One row of `list_jobs`: a single inflight server-side job (index build,
/// compaction, column refresh, view refresh, ...).
#[derive(Debug, Clone)]
pub struct JobInfo {
    /// Table the job belongs to.
    pub table: String,
    /// Identifier of the job.
    pub job_id: String,
    /// Kind of job.
    pub job_type: String,
    /// Lifecycle state: "running", "cancelling", or "stale".
    pub state: String,
    /// Column associated with the job, when there is one.
    pub column: Option<String>,
    /// Age of the job in seconds, when reported.
    pub age_seconds: Option<i64>,
    /// Command associated with the job, when reported.
    pub command: Option<String>,
    /// Work units completed so far, when reported.
    pub units_done: Option<i64>,
    /// Total work units, when reported.
    pub units_total: Option<i64>,
    /// Whether the job's final commit has completed (output visible).
    pub committed: bool,
    /// Number of rows skipped.
    pub rows_skipped: u64,
    /// Error text, when the job reported one.
    pub error: Option<String>,
}
|
||||
|
||||
/// One row of `job_history`: a durable, completed/terminal server-side job
/// record (SHOW JOB HISTORY) read from a table's `_job_history` store.
///
/// In contrast to `JobInfo` (live, inflight jobs), this record carries
/// created/updated/completed timestamps and the lifecycle event log.
#[derive(Debug, Clone)]
pub struct JobHistoryInfo {
    /// Table the job belongs to.
    pub table: String,
    /// Identifier of the job.
    pub job_id: String,
    /// Kind of job.
    pub job_type: String,
    /// State recorded for the job.
    pub state: String,
    /// Column associated with the job, when there is one.
    pub column: Option<String>,
    /// Creation timestamp.
    pub created_ms: i64,
    /// Last-update timestamp.
    pub updated_ms: i64,
    /// Completion timestamp, when recorded.
    pub completed_ms: Option<i64>,
    /// Rows processed, when recorded.
    pub rows_processed: Option<i64>,
    /// Rows skipped, when recorded.
    pub rows_skipped: Option<i64>,
    /// Error text, when recorded.
    pub error: Option<String>,
    /// Newline-joined lifecycle event log, oldest first.
    pub events: Option<String>,
}
|
||||
|
||||
/// A row from `errors`: one per-row UDF failure recorded by `error_policy=skip`
/// (SHOW ERRORS).
///
/// Implements `PartialEq`/`Eq` so values can be compared directly (for
/// example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JobErrorInfo {
    /// Identifier of the job that recorded the failure.
    pub job_id: String,
    /// Table the failure was recorded for.
    pub table: String,
    /// Column the failure was recorded for.
    pub column: String,
    /// Error category, as reported by the server.
    pub error_type: String,
    /// Error message text.
    pub error_message: String,
    /// Fragment id associated with the failing row, when reported.
    pub fragment_id: Option<i64>,
    /// Row id of the failing source row, when reported.
    pub source_row_id: Option<i64>,
    /// Table version associated with the failure, when reported.
    pub table_version: Option<i64>,
    /// Age of the error record in seconds, when reported.
    pub age_seconds: Option<i64>,
}
|
||||
|
||||
/// The plan a `REFRESH MATERIALIZED VIEW` would execute, as returned by
/// `explain_refresh_materialized_view` (EXPLAIN REFRESH). No work is run.
///
/// Implements `PartialEq`/`Eq` so values can be compared directly (for
/// example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MvRefreshPlan {
    /// Table name reported by the server for this plan.
    pub table_name: String,
    /// Whether a refresh would do anything (rebuild or non-empty units).
    pub has_work: bool,
    /// Source version the plan was computed against.
    pub source_version: u64,
    /// Source version of the last refresh, if one is recorded.
    pub last_refreshed_version: Option<u64>,
    /// Whether the plan is for a full refresh.
    pub full_refresh: bool,
    /// Source changed non-append-only since the last refresh -> rebuild.
    pub rebuild: bool,
    /// Number of row-range work units the refresh would process.
    pub units_total: u64,
}
|
||||
|
||||
fn not_supported<T>(what: &str) -> Result<T> {
|
||||
Err(Error::NotSupported {
|
||||
message: format!("{} is not supported by this database", what),
|
||||
})
|
||||
}
|
||||
|
||||
/// The `Database` trait defines the interface for database implementations.
|
||||
///
|
||||
/// A database is responsible for managing tables and their metadata.
|
||||
@@ -245,6 +444,99 @@ pub trait Database:
|
||||
///
|
||||
/// See [`CloneTableRequest`] for detailed documentation and examples.
|
||||
async fn clone_table(&self, request: CloneTableRequest) -> Result<Arc<dyn BaseTable>>;
|
||||
|
||||
// -- Derived compute: functions, materialized views, jobs -------------
|
||||
//
|
||||
// Server-backed features (LanceDB Enterprise / Cloud). The defaults
|
||||
// return NotSupported; the remote database overrides them. Local
|
||||
// single-node implementations are planned.
|
||||
|
||||
/// Register a UDF (CREATE FUNCTION).
///
/// The default implementation returns `Error::NotSupported`; backends that
/// support functions override it.
async fn create_function(&self, _request: CreateFunctionRequest) -> Result<()> {
    not_supported("create_function")
}
|
||||
/// List registered functions (SHOW FUNCTIONS).
///
/// The default implementation returns `Error::NotSupported`.
async fn list_functions(&self) -> Result<Vec<FunctionInfo>> {
    not_supported("list_functions")
}
|
||||
/// Drop a registered function (DROP FUNCTION).
///
/// The default implementation returns `Error::NotSupported`.
async fn drop_function(&self, _name: &str) -> Result<()> {
    not_supported("drop_function")
}
|
||||
/// Create a materialized view (CREATE MATERIALIZED VIEW). Returns
/// the initial-population job id, absent when `with_no_data`.
///
/// The default implementation returns `Error::NotSupported`.
async fn create_materialized_view(
    &self,
    _request: CreateMaterializedViewRequest,
) -> Result<Option<String>> {
    not_supported("create_materialized_view")
}
|
||||
/// Refresh a materialized view; returns the refresh job id.
///
/// The default implementation returns `Error::NotSupported`.
async fn refresh_materialized_view(
    &self,
    _request: RefreshMaterializedViewRequest,
) -> Result<String> {
    not_supported("refresh_materialized_view")
}
|
||||
/// Derived-compute lineage of a table/view (or column), as server-defined
/// JSON. Read-only.
///
/// The default implementation returns `Error::NotSupported`.
async fn table_lineage(&self, _request: TableLineageRequest) -> Result<String> {
    not_supported("table_lineage")
}
|
||||
/// Plan a materialized-view refresh without submitting work
/// (EXPLAIN REFRESH). `full` plans a full rebuild (incremental
/// planning requires stable row IDs on the source).
///
/// The default implementation returns `Error::NotSupported`.
async fn explain_refresh_materialized_view(
    &self,
    _name: &str,
    _full: bool,
    _src_version: Option<u64>,
) -> Result<MvRefreshPlan> {
    not_supported("explain_refresh_materialized_view")
}
|
||||
/// Update a materialized view's options (ALTER MATERIALIZED VIEW).
///
/// `_auto_refresh` is the only option this signature carries. The default
/// implementation returns `Error::NotSupported`.
async fn alter_materialized_view(&self, _name: &str, _auto_refresh: bool) -> Result<()> {
    not_supported("alter_materialized_view")
}
|
||||
/// Drop a materialized view definition (DROP MATERIALIZED VIEW).
///
/// The default implementation returns `Error::NotSupported`.
async fn drop_materialized_view(&self, _name: &str) -> Result<()> {
    not_supported("drop_materialized_view")
}
|
||||
/// List registered materialized view definitions.
///
/// The default implementation returns `Error::NotSupported`.
async fn list_materialized_views(&self) -> Result<Vec<MaterializedViewInfo>> {
    not_supported("list_materialized_views")
}
|
||||
/// List inflight server-side jobs across the database's tables.
///
/// The default implementation returns `Error::NotSupported`.
async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
    not_supported("list_jobs")
}
|
||||
/// Cancel an inflight server-side job by id. Returns true if a
/// matching inflight job was found and flagged for cancellation,
/// false if none was inflight (best-effort, like SQL `CANCEL JOB`).
///
/// The default implementation returns `Error::NotSupported`.
async fn cancel_job(&self, _job_id: &str) -> Result<bool> {
    not_supported("cancel_job")
}
|
||||
/// Point-access for a single job by id -- the `wait()`/status poll path.
/// `table_hint` (the job's table, which `wait()` callers know) enables an
/// O(1) server-side lookup. `None` if the job is unknown or not active.
///
/// The default implementation returns `Error::NotSupported`.
async fn get_job(&self, _job_id: &str, _table_hint: Option<&str>) -> Result<Option<JobInfo>> {
    not_supported("get_job")
}
|
||||
/// Durable job history (SHOW JOB HISTORY) across the database's tables,
/// optionally narrowed to a single `job_id`.
///
/// The default implementation returns `Error::NotSupported`.
async fn job_history(&self, _job_id: Option<&str>) -> Result<Vec<JobHistoryInfo>> {
    not_supported("job_history")
}
|
||||
/// Per-row UDF errors (SHOW ERRORS) recorded by `error_policy=skip` across
/// the database's tables, optionally filtered by `job_id` and/or `table`.
///
/// The default implementation returns `Error::NotSupported`.
async fn errors(
    &self,
    _job_id: Option<&str>,
    _table: Option<&str>,
) -> Result<Vec<JobErrorInfo>> {
    not_supported("errors")
}
|
||||
|
||||
/// Open a table in the database
|
||||
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>>;
|
||||
/// Rename a table in the database
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
pub(crate) mod client;
|
||||
pub(crate) mod db;
|
||||
pub mod oauth;
|
||||
mod retry;
|
||||
pub(crate) mod table;
|
||||
pub(crate) mod util;
|
||||
@@ -20,3 +21,4 @@ const JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
|
||||
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
|
||||
pub use oauth::{OAuthConfig, OAuthFlow, OAuthHeaderProvider};
|
||||
|
||||
@@ -459,12 +459,14 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
config: &ClientConfig,
|
||||
) -> Result<HeaderMap> {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-api-key"),
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
||||
message: "non-ascii api key provided".to_string(),
|
||||
})?,
|
||||
);
|
||||
if !api_key.is_empty() {
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-api-key"),
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
||||
message: "non-ascii api key provided".to_string(),
|
||||
})?,
|
||||
);
|
||||
}
|
||||
if region == "local" {
|
||||
let host = format!("{}.local.api.lancedb.com", db_name);
|
||||
headers.insert(
|
||||
@@ -1005,6 +1007,33 @@ mod tests {
|
||||
assert!(!config_tls.assert_hostname);
|
||||
}
|
||||
|
||||
#[test]
fn test_default_headers_skip_empty_api_key() {
    // Builds the default header map for the given API key; every other
    // argument is held constant across the two cases below.
    let headers_for = |api_key: &str| {
        RestfulLanceDbClient::<Sender>::default_headers(
            api_key,
            "us-east-1",
            "db-name",
            false,
            &RemoteOptions::default(),
            None,
            &ClientConfig::default(),
        )
        .unwrap()
    };

    // An empty key must not produce an `x-api-key` header at all.
    let without_key = headers_for("");
    assert!(!without_key.contains_key("x-api-key"));

    // A non-empty key is forwarded verbatim.
    let with_key = headers_for("api-key");
    assert_eq!(with_key.get("x-api-key").unwrap(), "api-key");
}
|
||||
|
||||
// Test implementation of HeaderProvider
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestHeaderProvider {
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use http::StatusCode;
|
||||
use lance_io::object_store::StorageOptions;
|
||||
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
|
||||
use moka::future::Cache;
|
||||
use reqwest::header::CONTENT_TYPE;
|
||||
|
||||
@@ -18,18 +19,264 @@ use lance_namespace::models::{
|
||||
|
||||
use crate::Error;
|
||||
use crate::database::{
|
||||
CloneTableRequest, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
|
||||
OpenTableRequest, ReadConsistency, TableNamesRequest,
|
||||
CloneTableRequest, CreateFunctionRequest, CreateMaterializedViewRequest, CreateTableMode,
|
||||
CreateTableRequest, Database, DatabaseOptions, FunctionInfo, JobErrorInfo, JobHistoryInfo,
|
||||
JobInfo, MaterializedViewInfo, MvRefreshPlan, OpenTableRequest, ReadConsistency,
|
||||
RefreshMaterializedViewRequest, TableLineageRequest, TableNamesRequest,
|
||||
};
|
||||
use crate::error::Result;
|
||||
use crate::remote::util::stream_as_body;
|
||||
use crate::table::BaseTable;
|
||||
|
||||
use super::ARROW_STREAM_CONTENT_TYPE;
|
||||
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
|
||||
use super::client::{
|
||||
ClientConfig, HeaderProvider, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender,
|
||||
};
|
||||
use super::table::RemoteTable;
|
||||
use super::util::parse_server_version;
|
||||
|
||||
// Wire types for the derived-compute routes (functions, materialized
|
||||
// views, jobs). Field shapes mirror the server's REST contract.
|
||||
/// JSON body sent by `create_function` to `/v1/function/{name}/create`.
/// The function name travels in the URL path, not in the body.
#[derive(serde::Serialize)]
struct RemoteCreateFunctionRequest {
    /// Copied from `CreateFunctionRequest::language`.
    language: String,
    /// Copied from `CreateFunctionRequest::return_type`.
    return_type: String,
    /// Copied from `CreateFunctionRequest::body`.
    body: String,
    /// Copied from `CreateFunctionRequest::options`.
    options: std::collections::HashMap<String, String>,
}
|
||||
|
||||
/// One element of the `functions` array in the `/v1/function/list` response;
/// converted field-for-field into `FunctionInfo`.
#[derive(serde::Deserialize)]
struct RemoteFunctionEntry {
    name: String,
    language: String,
    return_type: String,
    // Falls back to an empty string when the server omits the field.
    #[serde(default)]
    description: String,
}
|
||||
|
||||
/// Response body of `/v1/function/list`.
#[derive(serde::Deserialize)]
struct RemoteListFunctionsResponse {
    // Required: deserialization fails if the key is missing.
    functions: Vec<RemoteFunctionEntry>,
}
|
||||
|
||||
/// JSON body sent by `create_materialized_view` to
/// `/v1/materialized_view/{name}/create`. The view name travels in the URL
/// path, not in the body.
#[derive(serde::Serialize)]
struct RemoteCreateMaterializedViewRequest {
    query: String,
    auto_refresh: bool,
    with_no_data: bool,
    // Left out of the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    partition_by: Option<String>,
}
|
||||
|
||||
/// Response body of `/v1/materialized_view/{name}/create`.
#[derive(serde::Deserialize)]
struct RemoteCreateMaterializedViewResponse {
    // `None` when the server omits the field; surfaced to callers as the
    // optional initial-population job id.
    #[serde(default)]
    job_id: Option<String>,
}
|
||||
|
||||
/// JSON body sent by `refresh_materialized_view` to
/// `/v1/materialized_view/{name}/refresh`. Only non-default settings are
/// serialized.
#[derive(serde::Serialize)]
struct RemoteRefreshMaterializedViewRequest {
    // `Not::not` yields true for `false`, so the field is skipped unless a
    // full refresh is requested.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    full: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    src_version: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_workers: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_workers: Option<u32>,
}
|
||||
|
||||
/// Response body of `/v1/materialized_view/{name}/refresh`.
#[derive(serde::Deserialize)]
struct RemoteRefreshMaterializedViewResponse {
    // Required: the refresh job id returned to the caller.
    job_id: String,
}
|
||||
|
||||
/// JSON body sent by `explain_refresh_materialized_view` to
/// `/v1/materialized_view/{name}/explain_refresh`. `None` fields are left out
/// of the JSON.
#[derive(serde::Serialize)]
struct RemoteExplainRefreshRequest {
    #[serde(skip_serializing_if = "Option::is_none")]
    full: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    src_version: Option<u64>,
}
|
||||
|
||||
/// Response body of `/v1/materialized_view/{name}/explain_refresh`; copied
/// field-for-field into `MvRefreshPlan`.
#[derive(serde::Deserialize)]
struct RemoteExplainRefreshResponse {
    table_name: String,
    has_work: bool,
    source_version: u64,
    last_refreshed_version: Option<u64>,
    full_refresh: bool,
    rebuild: bool,
    units_total: u64,
}
|
||||
|
||||
/// JSON body sent by `alter_materialized_view` to
/// `/v1/materialized_view/{name}/alter`.
#[derive(serde::Serialize)]
struct RemoteAlterMaterializedViewRequest {
    auto_refresh: bool,
}
|
||||
|
||||
/// One element of the `views` array in the `/v1/materialized_view/list`
/// response; converted field-for-field into `MaterializedViewInfo`.
/// Everything except `name` and `source_table` falls back to its default
/// value (empty list, `None`, `false`) when absent.
#[derive(serde::Deserialize)]
struct RemoteMaterializedViewEntry {
    name: String,
    source_table: String,
    #[serde(default)]
    projection: Vec<String>,
    #[serde(default)]
    udf_columns: Vec<String>,
    #[serde(default)]
    filter: Option<String>,
    #[serde(default)]
    auto_refresh: bool,
}
|
||||
|
||||
/// Response body of `/v1/materialized_view/list`.
#[derive(serde::Deserialize)]
struct RemoteListMaterializedViewsResponse {
    // Required: deserialization fails if the key is missing.
    views: Vec<RemoteMaterializedViewEntry>,
}
|
||||
|
||||
/// Wire form of one job, used by the job list and single-job responses and
/// converted into `JobInfo`. `table`, `job_id`, `job_type` and `state` are
/// required; every other field falls back to its default value (`None`,
/// `false`, `0`) when absent.
#[derive(serde::Deserialize)]
struct RemoteJobEntry {
    table: String,
    job_id: String,
    job_type: String,
    state: String,
    #[serde(default)]
    column: Option<String>,
    #[serde(default)]
    age_seconds: Option<i64>,
    #[serde(default)]
    command: Option<String>,
    #[serde(default)]
    units_done: Option<i64>,
    #[serde(default)]
    units_total: Option<i64>,
    #[serde(default)]
    committed: bool,
    #[serde(default)]
    rows_skipped: u64,
    #[serde(default)]
    error: Option<String>,
}
|
||||
|
||||
/// Response body of `/v1/job/list`.
#[derive(serde::Deserialize)]
struct RemoteListJobsResponse {
    // Required: deserialization fails if the key is missing.
    jobs: Vec<RemoteJobEntry>,
}
|
||||
|
||||
/// Response body of `/v1/job/{id}`.
#[derive(serde::Deserialize)]
struct RemoteGetJobResponse {
    // `None` when the server omits the field; surfaced to callers as
    // `Ok(None)` from `get_job`.
    #[serde(default)]
    job: Option<RemoteJobEntry>,
}
|
||||
|
||||
/// Response body of `/v1/job/{id}/cancel`.
#[derive(serde::Deserialize)]
struct RemoteCancelJobResponse {
    // Required: returned as-is from `cancel_job`.
    cancelled: bool,
}
|
||||
|
||||
impl From<RemoteJobEntry> for JobInfo {
|
||||
fn from(j: RemoteJobEntry) -> Self {
|
||||
JobInfo {
|
||||
table: j.table,
|
||||
job_id: j.job_id,
|
||||
job_type: j.job_type,
|
||||
state: j.state,
|
||||
column: j.column,
|
||||
age_seconds: j.age_seconds,
|
||||
command: j.command,
|
||||
units_done: j.units_done,
|
||||
units_total: j.units_total,
|
||||
committed: j.committed,
|
||||
rows_skipped: j.rows_skipped,
|
||||
error: j.error,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wire form of one row of the `/v1/job/history` response; converted into
/// `JobHistoryInfo`. Fields marked `#[serde(default)]` become `None` when
/// absent; the rest are required.
#[derive(serde::Deserialize)]
struct RemoteJobHistoryEntry {
    table: String,
    job_id: String,
    job_type: String,
    state: String,
    #[serde(default)]
    column: Option<String>,
    created_ms: i64,
    updated_ms: i64,
    #[serde(default)]
    completed_ms: Option<i64>,
    #[serde(default)]
    rows_processed: Option<i64>,
    #[serde(default)]
    rows_skipped: Option<i64>,
    #[serde(default)]
    error: Option<String>,
    #[serde(default)]
    events: Option<String>,
}
|
||||
|
||||
/// Response body of `/v1/job/history`.
#[derive(serde::Deserialize)]
struct RemoteJobHistoryResponse {
    // Required: deserialization fails if the key is missing.
    jobs: Vec<RemoteJobHistoryEntry>,
}
|
||||
|
||||
impl From<RemoteJobHistoryEntry> for JobHistoryInfo {
|
||||
fn from(j: RemoteJobHistoryEntry) -> Self {
|
||||
JobHistoryInfo {
|
||||
table: j.table,
|
||||
job_id: j.job_id,
|
||||
job_type: j.job_type,
|
||||
state: j.state,
|
||||
column: j.column,
|
||||
created_ms: j.created_ms,
|
||||
updated_ms: j.updated_ms,
|
||||
completed_ms: j.completed_ms,
|
||||
rows_processed: j.rows_processed,
|
||||
rows_skipped: j.rows_skipped,
|
||||
error: j.error,
|
||||
events: j.events,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wire form of one row of the `/v1/job/errors` response; converted into
/// `JobErrorInfo`. The string fields are required; the numeric fields become
/// `None` when absent.
#[derive(serde::Deserialize)]
struct RemoteErrorEntry {
    job_id: String,
    table: String,
    column: String,
    error_type: String,
    error_message: String,
    #[serde(default)]
    fragment_id: Option<i64>,
    #[serde(default)]
    source_row_id: Option<i64>,
    #[serde(default)]
    table_version: Option<i64>,
    #[serde(default)]
    age_seconds: Option<i64>,
}
|
||||
|
||||
/// Response body of `/v1/job/errors`.
#[derive(serde::Deserialize)]
struct RemoteErrorsResponse {
    // Required: deserialization fails if the key is missing.
    errors: Vec<RemoteErrorEntry>,
}
|
||||
|
||||
impl From<RemoteErrorEntry> for JobErrorInfo {
|
||||
fn from(e: RemoteErrorEntry) -> Self {
|
||||
JobErrorInfo {
|
||||
job_id: e.job_id,
|
||||
table: e.table,
|
||||
column: e.column,
|
||||
error_type: e.error_type,
|
||||
error_message: e.error_message,
|
||||
fragment_id: e.fragment_id,
|
||||
source_row_id: e.source_row_id,
|
||||
table_version: e.table_version,
|
||||
age_seconds: e.age_seconds,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Request structure for the remote clone table API
|
||||
#[derive(serde::Serialize)]
|
||||
struct RemoteCloneTableRequest {
|
||||
@@ -194,10 +441,66 @@ pub struct RemoteDatabase<S: HttpSend = Sender> {
|
||||
uri: String,
|
||||
/// Headers to pass to the namespace client for authentication
|
||||
namespace_headers: HashMap<String, String>,
|
||||
namespace_context_provider: Option<Arc<dyn DynamicContextProvider>>,
|
||||
/// TLS configuration for mTLS support
|
||||
tls_config: Option<super::client::TlsConfig>,
|
||||
}
|
||||
|
||||
/// Adapter that exposes a `HeaderProvider` as a namespace
/// `DynamicContextProvider`, so headers produced by the provider can be
/// handed to the namespace client as context entries.
#[derive(Clone)]
struct NamespaceHeaderProviderContext {
    /// Provider whose `get_headers` output is turned into context entries.
    header_provider: Arc<dyn HeaderProvider>,
}
|
||||
|
||||
impl std::fmt::Debug for NamespaceHeaderProviderContext {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("NamespaceHeaderProviderContext")
|
||||
.field("header_provider", &"Some(...)")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl DynamicContextProvider for NamespaceHeaderProviderContext {
|
||||
fn provide_context(&self, _info: &OperationInfo) -> HashMap<String, String> {
|
||||
let header_provider = Arc::clone(&self.header_provider);
|
||||
let handle = match std::thread::Builder::new()
|
||||
.name("lancedb-namespace-headers".to_string())
|
||||
.spawn(move || {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!(
|
||||
"Failed to create runtime for namespace header provider: {e}"
|
||||
),
|
||||
})?
|
||||
.block_on(header_provider.get_headers())
|
||||
}) {
|
||||
Ok(handle) => handle,
|
||||
Err(err) => {
|
||||
log::warn!("Failed to spawn dynamic namespace header provider thread: {err}");
|
||||
return HashMap::new();
|
||||
}
|
||||
};
|
||||
|
||||
let headers = handle.join();
|
||||
|
||||
match headers {
|
||||
Ok(Ok(headers)) => headers
|
||||
.into_iter()
|
||||
.map(|(key, value)| (format!("headers.{key}"), value))
|
||||
.collect(),
|
||||
Ok(Err(err)) => {
|
||||
log::warn!("Failed to get dynamic namespace headers: {err}");
|
||||
HashMap::new()
|
||||
}
|
||||
Err(_) => {
|
||||
log::warn!("Dynamic namespace header provider panicked");
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteDatabase {
|
||||
pub fn try_new(
|
||||
uri: &str,
|
||||
@@ -228,6 +531,16 @@ impl RemoteDatabase {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let namespace_context_provider =
|
||||
client_config
|
||||
.header_provider
|
||||
.as_ref()
|
||||
.map(|header_provider| {
|
||||
Arc::new(NamespaceHeaderProviderContext {
|
||||
header_provider: Arc::clone(header_provider),
|
||||
}) as Arc<dyn DynamicContextProvider>
|
||||
});
|
||||
|
||||
let client = RestfulLanceDbClient::try_new(
|
||||
&parsed,
|
||||
region,
|
||||
@@ -247,6 +560,7 @@ impl RemoteDatabase {
|
||||
table_cache,
|
||||
uri: uri.to_owned(),
|
||||
namespace_headers,
|
||||
namespace_context_provider,
|
||||
tls_config: client_config.tls_config,
|
||||
})
|
||||
}
|
||||
@@ -271,6 +585,7 @@ mod test_utils {
|
||||
table_cache: Cache::new(0),
|
||||
uri: "http://localhost".to_string(),
|
||||
namespace_headers: HashMap::new(),
|
||||
namespace_context_provider: None,
|
||||
tls_config: None,
|
||||
}
|
||||
}
|
||||
@@ -281,11 +596,18 @@ mod test_utils {
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let client = client_with_handler_and_config(handler, config.clone());
|
||||
let namespace_context_provider =
|
||||
config.header_provider.as_ref().map(|header_provider| {
|
||||
Arc::new(NamespaceHeaderProviderContext {
|
||||
header_provider: Arc::clone(header_provider),
|
||||
}) as Arc<dyn DynamicContextProvider>
|
||||
});
|
||||
Self {
|
||||
client,
|
||||
table_cache: Cache::new(0),
|
||||
uri: "http://localhost".to_string(),
|
||||
namespace_headers: config.extra_headers.clone(),
|
||||
namespace_context_provider,
|
||||
tls_config: config.tls_config.clone(),
|
||||
}
|
||||
}
|
||||
@@ -563,6 +885,228 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
Ok(table)
|
||||
}
|
||||
|
||||
async fn create_function(&self, request: CreateFunctionRequest) -> Result<()> {
|
||||
let body = RemoteCreateFunctionRequest {
|
||||
language: request.language,
|
||||
return_type: request.return_type,
|
||||
body: request.body,
|
||||
options: request.options,
|
||||
};
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/function/{}/create", request.name))
|
||||
.json(&body);
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
self.client.check_response(&request_id, rsp).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_functions(&self) -> Result<Vec<FunctionInfo>> {
|
||||
let req = self.client.get("/v1/function/list");
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteListFunctionsResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|f| FunctionInfo {
|
||||
name: f.name,
|
||||
language: f.language,
|
||||
return_type: f.return_type,
|
||||
description: f.description,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Drops a registered function via POST `/v1/function/{name}/drop` (no
/// request body).
async fn drop_function(&self, name: &str) -> Result<()> {
    let route = format!("/v1/function/{}/drop", name);
    let http_request = self.client.post(&route);
    let (request_id, response) = self.client.send(http_request).await?;
    self.client.check_response(&request_id, response).await?;
    Ok(())
}
|
||||
|
||||
async fn create_materialized_view(
|
||||
&self,
|
||||
request: CreateMaterializedViewRequest,
|
||||
) -> Result<Option<String>> {
|
||||
let body = RemoteCreateMaterializedViewRequest {
|
||||
query: request.query,
|
||||
auto_refresh: request.auto_refresh,
|
||||
with_no_data: request.with_no_data,
|
||||
partition_by: request.partition_by,
|
||||
};
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/materialized_view/{}/create", request.name))
|
||||
.json(&body);
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteCreateMaterializedViewResponse =
|
||||
rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.job_id)
|
||||
}
|
||||
|
||||
async fn refresh_materialized_view(
|
||||
&self,
|
||||
request: RefreshMaterializedViewRequest,
|
||||
) -> Result<String> {
|
||||
let body = RemoteRefreshMaterializedViewRequest {
|
||||
full: request.full,
|
||||
src_version: request.src_version,
|
||||
num_workers: request.num_workers,
|
||||
max_workers: request.max_workers,
|
||||
};
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/materialized_view/{}/refresh", request.name))
|
||||
.json(&body);
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteRefreshMaterializedViewResponse =
|
||||
rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.job_id)
|
||||
}
|
||||
|
||||
/// Fetches lineage via GET `/v1/table/{name}/lineage`, adding `column`,
/// `direction` and `depth` query parameters only for the options that are
/// set, and returns the raw response text.
async fn table_lineage(&self, request: TableLineageRequest) -> Result<String> {
    let route = format!("/v1/table/{}/lineage", request.name);
    let mut builder = self.client.get(&route);
    if let Some(column) = request.column.as_ref() {
        builder = builder.query(&[("column", column)]);
    }
    if let Some(direction) = request.direction.as_ref() {
        builder = builder.query(&[("direction", direction)]);
    }
    if let Some(depth) = request.depth {
        let depth = depth.to_string();
        builder = builder.query(&[("depth", depth)]);
    }
    let (request_id, response) = self.client.send(builder).await?;
    let response = self.client.check_response(&request_id, response).await?;
    // The lineage JSON is server-defined and passed through opaque: this
    // client does not model its schema (the Python layer deserializes it).
    response.text().await.err_to_http(request_id)
}
|
||||
|
||||
async fn explain_refresh_materialized_view(
|
||||
&self,
|
||||
name: &str,
|
||||
full: bool,
|
||||
src_version: Option<u64>,
|
||||
) -> Result<MvRefreshPlan> {
|
||||
let body = RemoteExplainRefreshRequest {
|
||||
full: Some(full),
|
||||
src_version,
|
||||
};
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/materialized_view/{}/explain_refresh", name))
|
||||
.json(&body);
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteExplainRefreshResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(MvRefreshPlan {
|
||||
table_name: body.table_name,
|
||||
has_work: body.has_work,
|
||||
source_version: body.source_version,
|
||||
last_refreshed_version: body.last_refreshed_version,
|
||||
full_refresh: body.full_refresh,
|
||||
rebuild: body.rebuild,
|
||||
units_total: body.units_total,
|
||||
})
|
||||
}
|
||||
|
||||
async fn alter_materialized_view(&self, name: &str, auto_refresh: bool) -> Result<()> {
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/materialized_view/{}/alter", name))
|
||||
.json(&RemoteAlterMaterializedViewRequest { auto_refresh });
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
self.client.check_response(&request_id, rsp).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Drops a view definition via POST `/v1/materialized_view/{name}/drop` (no
/// request body).
async fn drop_materialized_view(&self, name: &str) -> Result<()> {
    let route = format!("/v1/materialized_view/{}/drop", name);
    let http_request = self.client.post(&route);
    let (request_id, response) = self.client.send(http_request).await?;
    self.client.check_response(&request_id, response).await?;
    Ok(())
}
|
||||
|
||||
async fn list_materialized_views(&self) -> Result<Vec<MaterializedViewInfo>> {
|
||||
let req = self.client.get("/v1/materialized_view/list");
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteListMaterializedViewsResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body
|
||||
.views
|
||||
.into_iter()
|
||||
.map(|v| MaterializedViewInfo {
|
||||
name: v.name,
|
||||
source_table: v.source_table,
|
||||
projection: v.projection,
|
||||
udf_columns: v.udf_columns,
|
||||
filter: v.filter,
|
||||
auto_refresh: v.auto_refresh,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
|
||||
let req = self.client.get("/v1/job/list");
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteListJobsResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.jobs.into_iter().map(JobInfo::from).collect())
|
||||
}
|
||||
|
||||
async fn get_job(&self, job_id: &str, table: Option<&str>) -> Result<Option<JobInfo>> {
|
||||
// Point-access poll path: GET /v1/job/{id}, with the table as the O(1)
|
||||
// hint when known. `query` handles URL-encoding the table name.
|
||||
let mut req = self.client.get(&format!("/v1/job/{job_id}"));
|
||||
if let Some(t) = table {
|
||||
req = req.query(&[("table", t)]);
|
||||
}
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteGetJobResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.job.map(JobInfo::from))
|
||||
}
|
||||
|
||||
async fn cancel_job(&self, job_id: &str) -> Result<bool> {
|
||||
let req = self.client.post(&format!("/v1/job/{}/cancel", job_id));
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteCancelJobResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.cancelled)
|
||||
}
|
||||
|
||||
async fn job_history(&self, job_id: Option<&str>) -> Result<Vec<JobHistoryInfo>> {
|
||||
let mut req = self.client.get("/v1/job/history");
|
||||
if let Some(j) = job_id {
|
||||
req = req.query(&[("job", j)]);
|
||||
}
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteJobHistoryResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.jobs.into_iter().map(JobHistoryInfo::from).collect())
|
||||
}
|
||||
|
||||
async fn errors(&self, job_id: Option<&str>, table: Option<&str>) -> Result<Vec<JobErrorInfo>> {
|
||||
let mut req = self.client.get("/v1/job/errors");
|
||||
if let Some(j) = job_id {
|
||||
req = req.query(&[("job", j)]);
|
||||
}
|
||||
if let Some(t) = table {
|
||||
req = req.query(&[("table", t)]);
|
||||
}
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteErrorsResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.errors.into_iter().map(JobErrorInfo::from).collect())
|
||||
}
|
||||
|
||||
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
|
||||
let identifier = build_table_identifier(
|
||||
&request.name,
|
||||
@@ -759,9 +1303,12 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
// Create a RestNamespace pointing to the same remote host with the same authentication headers
|
||||
let mut builder = lance_namespace_impls::RestNamespaceBuilder::new(self.client.host())
|
||||
.delimiter(&self.client.id_delimiter)
|
||||
// TODO: support header provider
|
||||
.headers(self.namespace_headers.clone());
|
||||
|
||||
if let Some(context_provider) = &self.namespace_context_provider {
|
||||
builder = builder.context_provider(Arc::clone(context_provider));
|
||||
}
|
||||
|
||||
// Apply mTLS configuration if present
|
||||
if let Some(tls_config) = &self.tls_config {
|
||||
if let Some(cert_file) = &tls_config.cert_file {
|
||||
@@ -781,6 +1328,14 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
}
|
||||
|
||||
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)> {
|
||||
if self.namespace_context_provider.is_some() {
|
||||
return Err(Error::NotSupported {
|
||||
message:
|
||||
"Cannot export a namespace client config when dynamic headers are configured; use LanceDB connection namespace methods instead"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut properties = HashMap::new();
|
||||
properties.insert("uri".to_string(), self.client.host().to_string());
|
||||
properties.insert("delimiter".to_string(), self.client.id_delimiter.clone());
|
||||
@@ -832,12 +1387,13 @@ impl From<StorageOptions> for RemoteOptions {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::build_cache_key;
|
||||
use super::{NamespaceHeaderProviderContext, build_cache_key};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use arrow_array::{Int32Array, RecordBatch};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
|
||||
|
||||
use crate::connection::ConnectBuilder;
|
||||
use crate::{
|
||||
@@ -1490,6 +2046,223 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_derived_compute_routes() {
|
||||
// create_function
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/function/embed/create");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(body["language"], "python");
|
||||
assert_eq!(body["return_type"], "FLOAT[4]");
|
||||
assert_eq!(body["body"], "def embed(x): ...");
|
||||
assert_eq!(body["options"]["pip"], "torch");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"name":"embed","status":"OK"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
conn.create_function(crate::database::CreateFunctionRequest {
|
||||
name: "embed".into(),
|
||||
language: "python".into(),
|
||||
return_type: "FLOAT[4]".into(),
|
||||
body: "def embed(x): ...".into(),
|
||||
options: [("pip".to_string(), "torch".to_string())].into(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// list_functions
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/function/list");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"functions":[{"name":"embed","language":"python","return_type":"Float32","description":""}]}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let functions = conn.list_functions().await.unwrap();
|
||||
assert_eq!(functions.len(), 1);
|
||||
assert_eq!(functions[0].name, "embed");
|
||||
|
||||
// drop_function
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/function/embed/drop");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"name":"embed","status":"OK"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
conn.drop_function("embed").await.unwrap();
|
||||
|
||||
// create_materialized_view
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/create");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(body["query"], "SELECT id, embed(body) AS vec FROM docs");
|
||||
assert_eq!(body["auto_refresh"], true);
|
||||
assert_eq!(body["with_no_data"], false);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"name":"mv1","job_id":"j-1"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
let mut request = crate::database::CreateMaterializedViewRequest::new(
|
||||
"mv1",
|
||||
"SELECT id, embed(body) AS vec FROM docs",
|
||||
);
|
||||
request.auto_refresh = true;
|
||||
let job_id = conn.create_materialized_view(request).await.unwrap();
|
||||
assert_eq!(job_id.as_deref(), Some("j-1"));
|
||||
|
||||
// refresh_materialized_view
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/refresh");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(body["num_workers"], 2);
|
||||
assert!(body.get("src_version").is_none());
|
||||
http::Response::builder()
|
||||
.status(202)
|
||||
.body(r#"{"job_id":"j-2"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
let mut request = crate::database::RefreshMaterializedViewRequest::new("mv1");
|
||||
request.num_workers = Some(2);
|
||||
let job_id = conn.refresh_materialized_view(request).await.unwrap();
|
||||
assert_eq!(job_id, "j-2");
|
||||
|
||||
// alter_materialized_view
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/alter");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(body["auto_refresh"], false);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"name":"mv1","status":"OK"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
conn.alter_materialized_view("mv1", false).await.unwrap();
|
||||
|
||||
// drop_materialized_view
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/drop");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"name":"mv1","status":"OK"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
conn.drop_materialized_view("mv1").await.unwrap();
|
||||
|
||||
// list_materialized_views
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/materialized_view/list");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"views":[{"name":"mv1","source_table":"docs","projection":["id"],"udf_columns":["vec=embed(body)"],"filter":null,"auto_refresh":true}]}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let views = conn.list_materialized_views().await.unwrap();
|
||||
assert_eq!(views.len(), 1);
|
||||
assert_eq!(views[0].source_table, "docs");
|
||||
assert!(views[0].auto_refresh);
|
||||
|
||||
// list_jobs
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/job/list");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"jobs":[{"table":"docs","job_id":"j-3","job_type":"udf_virtual_column_backfill","state":"running","column":"vec","age_seconds":4,"command":null,"units_done":1,"units_total":2,"committed":false,"rows_skipped":0,"error":null}]}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let jobs = conn.list_jobs().await.unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(jobs[0].state, "running");
|
||||
assert_eq!(jobs[0].units_total, Some(2));
|
||||
|
||||
// cancel_job
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/job/j-3/cancel");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"cancelled":true}"#)
|
||||
.unwrap()
|
||||
});
|
||||
assert!(conn.cancel_job("j-3").await.unwrap());
|
||||
|
||||
// cancel_job: no such inflight job -> false, not an error
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.url().path(), "/v1/job/gone/cancel");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"cancelled":false}"#)
|
||||
.unwrap()
|
||||
});
|
||||
assert!(!conn.cancel_job("gone").await.unwrap());
|
||||
|
||||
// job_history: GET /v1/job/history, no filter
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/job/history");
|
||||
assert!(request.url().query().is_none());
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"jobs":[{"table":"docs","job_id":"j-1","job_type":"udf_virtual_column_backfill","state":"done","column":"vec","created_ms":1000,"updated_ms":2000,"completed_ms":2000,"rows_processed":42,"rows_skipped":3,"error":null,"events":"created\ndone"}]}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let hist = conn.job_history(None).await.unwrap();
|
||||
assert_eq!(hist.len(), 1);
|
||||
assert_eq!(hist[0].state, "done");
|
||||
assert_eq!(hist[0].rows_processed, Some(42));
|
||||
assert_eq!(hist[0].events.as_deref(), Some("created\ndone"));
|
||||
|
||||
// job_history: ?job= narrows to one job
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.url().path(), "/v1/job/history");
|
||||
assert_eq!(request.url().query(), Some("job=j-1"));
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"jobs":[]}"#)
|
||||
.unwrap()
|
||||
});
|
||||
assert!(conn.job_history(Some("j-1")).await.unwrap().is_empty());
|
||||
|
||||
// errors: GET /v1/job/errors with job + table filters
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/job/errors");
|
||||
assert_eq!(request.url().query(), Some("job=j-1&table=docs"));
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"errors":[{"job_id":"j-1","table":"docs","column":"vec","error_type":"ValueError","error_message":"boom","fragment_id":0,"source_row_id":42,"table_version":7,"age_seconds":5}]}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let errs = conn.errors(Some("j-1"), Some("docs")).await.unwrap();
|
||||
assert_eq!(errs.len(), 1);
|
||||
assert_eq!(errs[0].error_type, "ValueError");
|
||||
assert_eq!(errs[0].source_row_id, Some(42));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_clone_table() {
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
@@ -1702,6 +2475,75 @@ mod tests {
|
||||
assert!(namespace_client.is_ok());
|
||||
}
|
||||
|
||||
/// Checks that a header returned by the wrapped `HeaderProvider` shows up in the
/// namespace context under the `headers.`-prefixed key.
#[test]
fn test_namespace_header_provider_context_maps_headers() {
    #[derive(Debug)]
    struct TestHeaderProvider;

    #[async_trait::async_trait]
    impl HeaderProvider for TestHeaderProvider {
        async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
            let mut headers = HashMap::new();
            headers.insert("authorization".to_string(), "Bearer token".to_string());
            Ok(headers)
        }
    }

    let header_provider: Arc<dyn HeaderProvider> = Arc::new(TestHeaderProvider);
    let context_provider = NamespaceHeaderProviderContext { header_provider };

    let operation = OperationInfo::new("list_tables", "namespace");
    let context = context_provider.provide_context(&operation);

    let expected = "Bearer token".to_string();
    assert_eq!(context.get("headers.authorization"), Some(&expected));
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_namespace_client_supports_dynamic_headers() {
|
||||
#[derive(Debug)]
|
||||
struct TestHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
"Bearer token".to_string(),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
let client_config = ClientConfig {
|
||||
header_provider: Some(Arc::new(TestHeaderProvider) as Arc<dyn HeaderProvider>),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let conn = Connection::new_with_handler_and_config(
|
||||
|_| {
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"tables": []}"#)
|
||||
.unwrap()
|
||||
},
|
||||
client_config,
|
||||
);
|
||||
|
||||
let namespace_client = conn.namespace_client().await;
|
||||
assert!(namespace_client.is_ok());
|
||||
|
||||
match conn.namespace_client_config().await {
|
||||
Err(Error::NotSupported { message })
|
||||
if message.contains("dynamic headers are configured") => {}
|
||||
Err(err) => panic!("expected NotSupported, got {err:?}"),
|
||||
Ok(_) => panic!("expected namespace_client_config to reject dynamic headers"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Integration tests using RestAdapter to run RemoteDatabase against a real namespace server
|
||||
mod rest_adapter_integration {
|
||||
use super::*;
|
||||
|
||||
907
rust/lancedb/src/remote/oauth.rs
Normal file
907
rust/lancedb/src/remote/oauth.rs
Normal file
@@ -0,0 +1,907 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use log::debug;
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::remote::client::HeaderProvider;
|
||||
|
||||
/// Refresh lead time, in seconds, used when `refresh_buffer_secs` is not configured.
const DEFAULT_REFRESH_BUFFER_SECS: u64 = 5 * 60;
/// Token lifetime, in seconds, assumed when a token response carries no `expires_in`.
const DEFAULT_TOKEN_TTL_SECS: u64 = 60 * 60;
/// Azure Instance Metadata Service URL that token requests for managed identities are sent to.
const AZURE_IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token";
/// Value of the `api-version` query parameter sent with each IMDS token request.
const AZURE_IMDS_API_VERSION: &str = "2018-02-01";
|
||||
|
||||
/// Selects which OAuth flow is used to obtain access tokens.
#[derive(Debug, Clone)]
pub enum OAuthFlow {
    /// Client Credentials grant (service-to-service / M2M).
    ///
    /// Requires `client_secret` to be set in [`OAuthConfig`].
    ClientCredentials,

    /// Azure Managed Identity via IMDS.
    ///
    /// Works on Azure VMs, AKS, App Service, and Azure Functions.
    /// IMDS requests bypass proxy settings because the endpoint is link-local.
    AzureManagedIdentity {
        /// Client ID of a user-assigned managed identity.
        /// Leave as `None` for the system-assigned managed identity.
        client_id: Option<String>,
    },
}

/// OAuth settings used to authenticate against LanceDB.
///
/// Token acquisition and refresh are handled entirely in the Rust layer;
/// the Python and TypeScript bindings expose this as a plain config object.
#[derive(Clone)]
pub struct OAuthConfig {
    /// OIDC issuer URL or OAuth authority URL.
    /// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
    pub issuer_url: String,

    /// Application / Client ID.
    pub client_id: String,

    /// Client secret (required for `ClientCredentials`, optional for others).
    pub client_secret: Option<String>,

    /// OAuth scopes to request.
    /// For Azure managed identity, exactly one scope or resource is required,
    /// for example `["api://{app_id}/.default"]`.
    pub scopes: Vec<String>,

    /// Authentication flow to use.
    pub flow: OAuthFlow,

    /// Seconds before token expiry to trigger proactive refresh (default: 300).
    /// Keep this well below the token TTL; if it is greater than or equal to
    /// the TTL, each request refreshes the token.
    pub refresh_buffer_secs: Option<u64>,
}

/// Hand-written so that the client secret never appears in debug output.
impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            issuer_url,
            client_id,
            client_secret,
            scopes,
            flow,
            refresh_buffer_secs,
        } = self;
        // Keep the Some/None distinction visible, but never the secret itself.
        let secret_marker: Option<&str> = client_secret.as_ref().map(|_| "<redacted>");

        let mut out = f.debug_struct("OAuthConfig");
        out.field("issuer_url", issuer_url);
        out.field("client_id", client_id);
        out.field("client_secret", &secret_marker);
        out.field("scopes", scopes);
        out.field("flow", flow);
        out.field("refresh_buffer_secs", refresh_buffer_secs);
        out.finish()
    }
}
|
||||
|
||||
// -- OIDC Discovery --
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
struct OidcDiscovery {
|
||||
token_endpoint: String,
|
||||
}
|
||||
|
||||
// -- Token Response --
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
/// Token lifetime in seconds.
|
||||
/// Some providers (Azure IMDS) return this as a string, so we accept both.
|
||||
#[serde(default, deserialize_with = "deserialize_optional_u64_or_string")]
|
||||
expires_in: Option<u64>,
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
token_type: Option<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for TokenResponse {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TokenResponse")
|
||||
.field("access_token", &"<redacted>")
|
||||
.field("expires_in", &self.expires_in)
|
||||
.field("token_type", &self.token_type)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_optional_u64_or_string<'de, D>(
|
||||
deserializer: D,
|
||||
) -> std::result::Result<Option<u64>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de;
|
||||
|
||||
struct U64OrString;
|
||||
impl<'de> de::Visitor<'de> for U64OrString {
|
||||
type Value = Option<u64>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("an integer, an integer-valued float, a numeric string, or null")
|
||||
}
|
||||
|
||||
fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
|
||||
Ok(Some(v))
|
||||
}
|
||||
|
||||
fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
|
||||
if v < 0 {
|
||||
return Err(E::custom(format!("invalid expires_in value: {v}")));
|
||||
}
|
||||
Ok(Some(v as u64))
|
||||
}
|
||||
|
||||
fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
|
||||
if !v.is_finite() || v < 0.0 || v.fract() != 0.0 || v > u64::MAX as f64 {
|
||||
return Err(E::custom(format!("invalid expires_in value: {v}")));
|
||||
}
|
||||
Ok(Some(v as u64))
|
||||
}
|
||||
|
||||
fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Self::Value, E> {
|
||||
v.parse::<u64>().map(Some).map_err(de::Error::custom)
|
||||
}
|
||||
|
||||
fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(U64OrString)
|
||||
}
|
||||
|
||||
// -- Internal Token State --
|
||||
|
||||
struct TokenState {
|
||||
access_token: Option<String>,
|
||||
expires_at: Option<Instant>,
|
||||
}
|
||||
|
||||
impl TokenState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
access_token: None,
|
||||
expires_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_expired(&self, buffer: Duration) -> bool {
|
||||
match (self.access_token.as_ref(), self.expires_at) {
|
||||
(Some(_), Some(expires_at)) => Instant::now() + buffer >= expires_at,
|
||||
(None, _) => true,
|
||||
(Some(_), None) => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, resp: &TokenResponse) {
|
||||
self.access_token = Some(resp.access_token.clone());
|
||||
let expires_in = resp.expires_in.unwrap_or(DEFAULT_TOKEN_TTL_SECS);
|
||||
self.expires_at = Some(Instant::now() + Duration::from_secs(expires_in));
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
trait TokenSource: Send + Sync + std::fmt::Debug {
|
||||
async fn fetch_token(&self) -> Result<TokenResponse>;
|
||||
}
|
||||
|
||||
struct ClientCredentialsSource {
|
||||
issuer_url: String,
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
scopes: Vec<String>,
|
||||
http_client: Client,
|
||||
discovery: RwLock<Option<OidcDiscovery>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ClientCredentialsSource {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ClientCredentialsSource")
|
||||
.field("issuer_url", &self.issuer_url)
|
||||
.field("client_id", &self.client_id)
|
||||
.field("client_secret", &"<redacted>")
|
||||
.field("scopes", &self.scopes)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentialsSource {
|
||||
fn new(
|
||||
issuer_url: String,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
scopes: Vec<String>,
|
||||
) -> Result<Self> {
|
||||
let client_secret = client_secret.ok_or(Error::InvalidInput {
|
||||
message: "client_secret is required for ClientCredentials flow".to_string(),
|
||||
})?;
|
||||
Self::validate_issuer_transport(&issuer_url)?;
|
||||
|
||||
let http_client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to create HTTP client for OAuth: {e}"),
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
issuer_url,
|
||||
client_id,
|
||||
client_secret,
|
||||
scopes,
|
||||
http_client,
|
||||
discovery: RwLock::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
fn validate_issuer_transport(issuer_url: &str) -> Result<()> {
|
||||
let issuer = url::Url::parse(issuer_url).map_err(|e| Error::InvalidInput {
|
||||
message: format!("Invalid OAuth issuer_url: {e}"),
|
||||
})?;
|
||||
|
||||
match issuer.scheme() {
|
||||
"https" => Ok(()),
|
||||
"http" if Self::is_loopback_issuer(&issuer) => Ok(()),
|
||||
_ => Err(Error::InvalidInput {
|
||||
message:
|
||||
"ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
|
||||
.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_loopback_issuer(issuer: &url::Url) -> bool {
|
||||
let Some(host) = issuer.host_str() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
host.eq_ignore_ascii_case("localhost")
|
||||
|| host
|
||||
.parse::<IpAddr>()
|
||||
.map(|addr| addr.is_loopback())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
async fn get_discovery(&self) -> Result<OidcDiscovery> {
|
||||
{
|
||||
let cached = self.discovery.read().await;
|
||||
if let Some(ref disc) = *cached {
|
||||
return Ok(disc.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut cache = self.discovery.write().await;
|
||||
// Double-check
|
||||
if let Some(ref disc) = *cache {
|
||||
return Ok(disc.clone());
|
||||
}
|
||||
|
||||
let discovery_url = format!(
|
||||
"{}/.well-known/openid-configuration",
|
||||
self.issuer_url.trim_end_matches('/')
|
||||
);
|
||||
|
||||
debug!("Fetching OIDC discovery from {}", discovery_url);
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.get(&discovery_url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to fetch OIDC discovery document: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"OIDC discovery failed with status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let disc: OidcDiscovery = resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse OIDC discovery document: {e}"),
|
||||
})?;
|
||||
|
||||
let result = disc.clone();
|
||||
|
||||
*cache = Some(disc);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn get_token_endpoint(&self) -> Result<String> {
|
||||
self.get_discovery().await.map(|disc| disc.token_endpoint)
|
||||
}
|
||||
|
||||
fn scopes_string(&self) -> String {
|
||||
self.scopes.join(" ")
|
||||
}
|
||||
|
||||
async fn post_token_request(
|
||||
&self,
|
||||
endpoint: &str,
|
||||
params: &[(&str, &str)],
|
||||
) -> Result<TokenResponse> {
|
||||
let resp = self
|
||||
.http_client
|
||||
.post(endpoint)
|
||||
.form(params)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Token request to {endpoint} failed: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Token request failed with status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse token response: {e}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TokenSource for ClientCredentialsSource {
|
||||
async fn fetch_token(&self) -> Result<TokenResponse> {
|
||||
let token_endpoint = self.get_token_endpoint().await?;
|
||||
let scope = self.scopes_string();
|
||||
let params = [
|
||||
("grant_type", "client_credentials"),
|
||||
("client_id", self.client_id.as_str()),
|
||||
("client_secret", self.client_secret.as_str()),
|
||||
("scope", scope.as_str()),
|
||||
];
|
||||
|
||||
self.post_token_request(&token_endpoint, ¶ms).await
|
||||
}
|
||||
}
|
||||
|
||||
struct AzureImdsSource {
|
||||
client_id: Option<String>,
|
||||
resource: String,
|
||||
http_client: Client,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AzureImdsSource {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AzureImdsSource")
|
||||
.field("client_id", &self.client_id)
|
||||
.field("resource", &self.resource)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AzureImdsSource {
|
||||
fn new(scopes: Vec<String>, client_id: Option<String>) -> Result<Self> {
|
||||
let resource = Self::resource_from_scopes(&scopes)?;
|
||||
let http_client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.no_proxy()
|
||||
.build()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to create HTTP client for Azure IMDS OAuth: {e}"),
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
client_id,
|
||||
resource,
|
||||
http_client,
|
||||
})
|
||||
}
|
||||
|
||||
fn resource_from_scopes(scopes: &[String]) -> Result<String> {
|
||||
let [scope] = scopes else {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource"
|
||||
.to_string(),
|
||||
});
|
||||
};
|
||||
|
||||
Ok(scope.strip_suffix("/.default").unwrap_or(scope).to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TokenSource for AzureImdsSource {
|
||||
async fn fetch_token(&self) -> Result<TokenResponse> {
|
||||
let mut url = format!(
|
||||
"{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}",
|
||||
urlencoding::encode(&self.resource),
|
||||
);
|
||||
if let Some(cid) = self.client_id.as_deref() {
|
||||
url.push_str(&format!("&client_id={}", urlencoding::encode(cid)));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Metadata", "true")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Azure IMDS request failed: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Azure IMDS returned status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse IMDS token response: {e}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// OAuth header provider that manages the full token lifecycle.
|
||||
///
|
||||
/// Implements [`HeaderProvider`] to inject `Authorization: Bearer <token>`
|
||||
/// headers into every LanceDB request, with automatic token refresh.
|
||||
pub struct OAuthHeaderProvider {
|
||||
token_source: Box<dyn TokenSource>,
|
||||
token_state: Arc<RwLock<TokenState>>,
|
||||
refresh_buffer: Duration,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OAuthHeaderProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OAuthHeaderProvider")
|
||||
.field("token_source", &self.token_source)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl OAuthHeaderProvider {
|
||||
/// Create a new OAuth header provider from configuration.
|
||||
pub fn new(config: OAuthConfig) -> Result<Self> {
|
||||
let OAuthConfig {
|
||||
issuer_url,
|
||||
client_id,
|
||||
client_secret,
|
||||
scopes,
|
||||
flow,
|
||||
refresh_buffer_secs,
|
||||
} = config;
|
||||
|
||||
if scopes.is_empty() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "At least one OAuth scope is required".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let refresh_buffer =
|
||||
Duration::from_secs(refresh_buffer_secs.unwrap_or(DEFAULT_REFRESH_BUFFER_SECS));
|
||||
let token_source: Box<dyn TokenSource> = match flow {
|
||||
OAuthFlow::ClientCredentials => Box::new(ClientCredentialsSource::new(
|
||||
issuer_url,
|
||||
client_id,
|
||||
client_secret,
|
||||
scopes,
|
||||
)?),
|
||||
OAuthFlow::AzureManagedIdentity { client_id } => {
|
||||
Box::new(AzureImdsSource::new(scopes, client_id)?)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
token_source,
|
||||
token_state: Arc::new(RwLock::new(TokenState::new())),
|
||||
refresh_buffer,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a valid access token, refreshing if necessary.
|
||||
async fn get_valid_token(&self) -> Result<String> {
|
||||
// Fast path: check if current token is still valid
|
||||
{
|
||||
let state = self.token_state.read().await;
|
||||
if !state.is_expired(self.refresh_buffer)
|
||||
&& let Some(ref token) = state.access_token
|
||||
{
|
||||
return Ok(token.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: acquire or refresh token
|
||||
let mut state = self.token_state.write().await;
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if !state.is_expired(self.refresh_buffer)
|
||||
&& let Some(ref token) = state.access_token
|
||||
{
|
||||
return Ok(token.clone());
|
||||
}
|
||||
|
||||
debug!("Acquiring new OAuth token via {:?}", self.token_source);
|
||||
let resp = self.token_source.fetch_token().await?;
|
||||
|
||||
state.update(&resp);
|
||||
Ok(resp.access_token)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HeaderProvider for OAuthHeaderProvider {
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>> {
|
||||
let token = self.get_valid_token().await?;
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
format!("Bearer {token}"),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[test]
|
||||
fn test_token_state_expiry() {
|
||||
let mut state = TokenState::new();
|
||||
assert!(state.is_expired(Duration::from_secs(0)));
|
||||
|
||||
state.access_token = Some("tok".to_string());
|
||||
state.expires_at = Some(Instant::now() + Duration::from_secs(600));
|
||||
assert!(!state.is_expired(Duration::from_secs(300)));
|
||||
assert!(state.is_expired(Duration::from_secs(601)));
|
||||
|
||||
state.expires_at = None;
|
||||
assert!(state.is_expired(Duration::from_secs(0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_state_uses_default_expiry() {
|
||||
let mut state = TokenState::new();
|
||||
let response = TokenResponse {
|
||||
access_token: "tok".to_string(),
|
||||
expires_in: None,
|
||||
token_type: None,
|
||||
};
|
||||
|
||||
state.update(&response);
|
||||
|
||||
assert!(!state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS - 1)));
|
||||
assert!(state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS + 1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_response_accepts_float_expires_in() {
|
||||
let response: TokenResponse =
|
||||
serde_json::from_str(r#"{"access_token":"tok","expires_in":3600.0}"#).unwrap();
|
||||
|
||||
assert_eq!(response.expires_in, Some(3600));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_response_rejects_negative_expires_in() {
|
||||
let err =
|
||||
serde_json::from_str::<TokenResponse>(r#"{"access_token":"tok","expires_in":-1}"#)
|
||||
.unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("invalid expires_in value: -1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_response_debug_redacts_access_token() {
|
||||
let response = TokenResponse {
|
||||
access_token: "secret-token".to_string(),
|
||||
expires_in: Some(3600),
|
||||
token_type: Some("Bearer".to_string()),
|
||||
};
|
||||
|
||||
let debug = format!("{response:?}");
|
||||
assert!(!debug.contains("secret-token"));
|
||||
assert!(debug.contains("access_token: \"<redacted>\""));
|
||||
}
|
||||
|
||||
/// `scopes_string` joins the configured scopes with a single space.
#[test]
fn test_scopes_string() {
    let issuer = String::from("https://login.microsoftonline.com/tenant/v2.0");
    let client_id = String::from("app-id");
    let client_secret = Some(String::from("secret"));
    let scopes = vec![String::from("scope1"), String::from("scope2")];

    let credentials =
        ClientCredentialsSource::new(issuer, client_id, client_secret, scopes).unwrap();

    assert_eq!(credentials.scopes_string(), "scope1 scope2");
}
|
||||
|
||||
/// The `Debug` output of `OAuthConfig` hides the client secret behind a
/// `<redacted>` placeholder.
#[test]
fn test_oauth_config_debug_redacts_client_secret() {
    let oauth_config = OAuthConfig {
        issuer_url: String::from("https://issuer.example.com"),
        client_id: String::from("client-id"),
        client_secret: Some(String::from("super-secret")),
        scopes: vec![String::from("scope")],
        flow: OAuthFlow::ClientCredentials,
        refresh_buffer_secs: None,
    };

    let rendered = format!("{oauth_config:?}");
    let leaks_secret = rendered.contains("super-secret");
    let shows_placeholder = rendered.contains("client_secret: Some(\"<redacted>\")");

    assert!(!leaks_secret);
    assert!(shows_placeholder);
}
|
||||
|
||||
/// The `Debug` output of an `OAuthHeaderProvider` built from a config with a
/// client secret must show a `<redacted>` placeholder instead of the secret.
#[test]
fn test_oauth_header_provider_debug_redacts_client_secret() {
    let oauth_config = OAuthConfig {
        issuer_url: String::from("https://issuer.example.com"),
        client_id: String::from("client-id"),
        client_secret: Some(String::from("super-secret")),
        scopes: vec![String::from("scope")],
        flow: OAuthFlow::ClientCredentials,
        refresh_buffer_secs: None,
    };
    let header_provider = OAuthHeaderProvider::new(oauth_config).unwrap();

    let rendered = format!("{header_provider:?}");
    let leaks_secret = rendered.contains("super-secret");
    let shows_placeholder = rendered.contains("client_secret: \"<redacted>\"");

    assert!(!leaks_secret);
    assert!(shows_placeholder);
}
|
||||
|
||||
/// A single scope ending in `/.default` yields the resource without that suffix.
#[test]
fn test_managed_identity_resource_from_default_scope() {
    let scopes = ["api://test/.default".to_string()];
    let resource = AzureImdsSource::resource_from_scopes(&scopes).unwrap();

    assert_eq!(resource, "api://test");
}
|
||||
|
||||
/// A single scope without the `/.default` suffix is returned unchanged.
#[test]
fn test_managed_identity_resource_without_default_suffix() {
    let scopes = ["api://test".to_string()];
    let resource = AzureImdsSource::resource_from_scopes(&scopes).unwrap();

    assert_eq!(resource, "api://test");
}
|
||||
|
||||
/// Building a provider for the Azure managed-identity flow with two scopes
/// must fail.
#[test]
fn test_managed_identity_rejects_multiple_scopes() {
    let two_scopes = vec![
        String::from("api://test-a/.default"),
        String::from("api://test-b/.default"),
    ];
    let oauth_config = OAuthConfig {
        issuer_url: String::from("https://login.microsoftonline.com/tenant/v2.0"),
        client_id: String::from("app-id"),
        client_secret: None,
        scopes: two_scopes,
        flow: OAuthFlow::AzureManagedIdentity { client_id: None },
        refresh_buffer_secs: None,
    };

    let outcome = OAuthHeaderProvider::new(oauth_config);
    assert!(outcome.is_err());
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_token_endpoint_requires_discovery_success() {
|
||||
let (issuer_url, server) = spawn_discovery_error_server().await;
|
||||
let source = ClientCredentialsSource::new(
|
||||
issuer_url,
|
||||
"client-id".to_string(),
|
||||
Some("secret".to_string()),
|
||||
vec!["scope".to_string()],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let err = source.get_token_endpoint().await.unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::Runtime { message }
|
||||
if message.contains("OIDC discovery failed with status 503")
|
||||
));
|
||||
server.await.unwrap();
|
||||
}
|
||||
|
||||
/// The client-credentials flow cannot be configured without a client secret.
#[test]
fn test_client_credentials_requires_secret() {
    let oauth_config = OAuthConfig {
        issuer_url: String::from("https://login.microsoftonline.com/tenant/v2.0"),
        client_id: String::from("app-id"),
        client_secret: None,
        scopes: vec![String::from("scope")],
        flow: OAuthFlow::ClientCredentials,
        refresh_buffer_secs: None,
    };

    let outcome = OAuthHeaderProvider::new(oauth_config);
    assert!(outcome.is_err());
}
|
||||
|
||||
/// A plain-`http` issuer on a non-loopback host is rejected as invalid input
/// with a fixed message.
#[test]
fn test_client_credentials_rejects_insecure_non_loopback_issuer() {
    let oauth_config = OAuthConfig {
        issuer_url: String::from("http://issuer.example.com"),
        client_id: String::from("app-id"),
        client_secret: Some(String::from("secret")),
        scopes: vec![String::from("scope")],
        flow: OAuthFlow::ClientCredentials,
        refresh_buffer_secs: None,
    };

    let rejection = OAuthHeaderProvider::new(oauth_config).unwrap_err();
    let is_expected_rejection = matches!(
        rejection,
        Error::InvalidInput { message }
            if message == "ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
    );
    assert!(is_expected_rejection);
}
|
||||
|
||||
/// A configuration with no scopes at all is rejected.
#[test]
fn test_empty_scopes_rejected() {
    let oauth_config = OAuthConfig {
        issuer_url: String::from("https://login.microsoftonline.com/tenant/v2.0"),
        client_id: String::from("app-id"),
        client_secret: None,
        scopes: Vec::new(),
        flow: OAuthFlow::AzureManagedIdentity { client_id: None },
        refresh_buffer_secs: None,
    };

    let outcome = OAuthHeaderProvider::new(oauth_config);
    assert!(outcome.is_err());
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_credentials_token_lifecycle() {
|
||||
let (issuer_url, token_requests, server) = spawn_oauth_server().await;
|
||||
let config = OAuthConfig {
|
||||
issuer_url,
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: Some(0),
|
||||
};
|
||||
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||
|
||||
let headers = provider.get_headers().await.unwrap();
|
||||
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
|
||||
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
|
||||
|
||||
let headers = provider.get_headers().await.unwrap();
|
||||
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
|
||||
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
|
||||
|
||||
provider.token_state.write().await.expires_at =
|
||||
Some(Instant::now() - Duration::from_secs(1));
|
||||
|
||||
let headers = provider.get_headers().await.unwrap();
|
||||
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-2");
|
||||
assert_eq!(token_requests.load(Ordering::SeqCst), 2);
|
||||
|
||||
server.await.unwrap();
|
||||
}
|
||||
|
||||
async fn spawn_oauth_server() -> (String, Arc<AtomicUsize>, JoinHandle<()>) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let issuer_url = format!("http://{addr}");
|
||||
let token_requests = Arc::new(AtomicUsize::new(0));
|
||||
let server_token_requests = Arc::clone(&token_requests);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
for _ in 0..3 {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let (request_line, body) = read_http_request(&mut stream).await;
|
||||
|
||||
if request_line.starts_with("GET /.well-known/openid-configuration ") {
|
||||
let discovery = format!(r#"{{"token_endpoint":"http://{addr}/token"}}"#);
|
||||
write_json_response(&mut stream, "200 OK", &discovery).await;
|
||||
} else if request_line.starts_with("POST /token ") {
|
||||
assert!(body.contains("grant_type=client_credentials"));
|
||||
assert!(body.contains("client_id=client-id"));
|
||||
assert!(body.contains("client_secret=secret"));
|
||||
assert!(body.contains("scope=scope"));
|
||||
|
||||
let token_num = server_token_requests.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
let token = format!(
|
||||
r#"{{"access_token":"token-{token_num}","expires_in":3600,"token_type":"Bearer"}}"#
|
||||
);
|
||||
write_json_response(&mut stream, "200 OK", &token).await;
|
||||
} else {
|
||||
write_json_response(&mut stream, "404 Not Found", "{}").await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(issuer_url, token_requests, server)
|
||||
}
|
||||
|
||||
async fn spawn_discovery_error_server() -> (String, JoinHandle<()>) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let issuer_url = format!("http://{addr}");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let (request_line, _) = read_http_request(&mut stream).await;
|
||||
assert!(request_line.starts_with("GET /.well-known/openid-configuration "));
|
||||
write_json_response(&mut stream, "503 Service Unavailable", "{}").await;
|
||||
});
|
||||
|
||||
(issuer_url, server)
|
||||
}
|
||||
|
||||
async fn read_http_request(stream: &mut TcpStream) -> (String, String) {
|
||||
let mut buffer = Vec::new();
|
||||
let mut header_end = None;
|
||||
|
||||
while header_end.is_none() {
|
||||
let mut chunk = [0; 1024];
|
||||
let read = stream.read(&mut chunk).await.unwrap();
|
||||
assert_ne!(read, 0, "connection closed before request headers");
|
||||
buffer.extend_from_slice(&chunk[..read]);
|
||||
header_end = find_subsequence(&buffer, b"\r\n\r\n").map(|pos| pos + 4);
|
||||
}
|
||||
|
||||
let header_end = header_end.unwrap();
|
||||
let headers = String::from_utf8_lossy(&buffer[..header_end]).to_string();
|
||||
let request_line = headers.lines().next().unwrap_or_default().to_string();
|
||||
let content_length = headers
|
||||
.lines()
|
||||
.find_map(|line| {
|
||||
let (name, value) = line.split_once(':')?;
|
||||
name.eq_ignore_ascii_case("content-length")
|
||||
.then(|| value.trim().parse::<usize>().ok())
|
||||
.flatten()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
|
||||
while buffer.len() < header_end + content_length {
|
||||
let mut chunk = [0; 1024];
|
||||
let read = stream.read(&mut chunk).await.unwrap();
|
||||
assert_ne!(read, 0, "connection closed before request body");
|
||||
buffer.extend_from_slice(&chunk[..read]);
|
||||
}
|
||||
|
||||
let body =
|
||||
String::from_utf8_lossy(&buffer[header_end..header_end + content_length]).to_string();
|
||||
|
||||
(request_line, body)
|
||||
}
|
||||
|
||||
/// Returns the index of the first occurrence of `needle` in `haystack`, or
/// `None` when it does not occur.
///
/// An empty `needle` matches at index 0. Without that guard the call would
/// panic, because `slice::windows` rejects a window size of zero.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return Some(0);
    }

    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
|
||||
|
||||
async fn write_json_response(stream: &mut TcpStream, status: &str, body: &str) {
|
||||
let response = format!(
|
||||
"HTTP/1.1 {status}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||
body.len()
|
||||
);
|
||||
stream.write_all(response.as_bytes()).await.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -2309,6 +2309,126 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
message: "optimize is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
}
|
||||
async fn add_computed_columns(
|
||||
&self,
|
||||
columns: &[(String, String)],
|
||||
expression: &str,
|
||||
) -> Result<()> {
|
||||
let new_columns: Vec<serde_json::Value> = columns
|
||||
.iter()
|
||||
.map(|(name, data_type)| {
|
||||
serde_json::json!({
|
||||
"name": name,
|
||||
"computed": { "data_type": data_type, "expression": expression },
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/add_columns/", self.identifier))
|
||||
.json(&serde_json::json!({ "new_columns": new_columns }));
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn refresh_column(
|
||||
&self,
|
||||
columns: &[String],
|
||||
where_clause: Option<String>,
|
||||
num_workers: Option<u32>,
|
||||
max_workers: Option<u32>,
|
||||
batch_size: Option<u32>,
|
||||
priority: Option<String>,
|
||||
) -> Result<String> {
|
||||
let mut body = serde_json::json!({ "columns": columns });
|
||||
if let Some(w) = where_clause {
|
||||
body["where_clause"] = serde_json::Value::String(w);
|
||||
}
|
||||
if let Some(n) = num_workers {
|
||||
body["num_workers"] = n.into();
|
||||
}
|
||||
if let Some(n) = max_workers {
|
||||
body["max_workers"] = n.into();
|
||||
}
|
||||
if let Some(n) = batch_size {
|
||||
body["batch_size"] = n.into();
|
||||
}
|
||||
if let Some(p) = priority {
|
||||
body["priority"] = serde_json::Value::String(p);
|
||||
}
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/refresh_column", self.identifier))
|
||||
.json(&body);
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RefreshColumnResponse {
|
||||
job_id: String,
|
||||
}
|
||||
let body: RefreshColumnResponse = response.json().await.err_to_http(request_id)?;
|
||||
Ok(body.job_id)
|
||||
}
|
||||
|
||||
async fn load_columns(&self, request: crate::table::LoadColumnsRequest) -> Result<String> {
|
||||
let columns: Vec<serde_json::Value> = request
|
||||
.columns
|
||||
.iter()
|
||||
.map(|(target, source)| {
|
||||
serde_json::json!({
|
||||
"target": target,
|
||||
"source": source.clone().unwrap_or_else(|| target.clone()),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let mut source = serde_json::json!({
|
||||
"uris": request.source_uris,
|
||||
"format": request.source_format,
|
||||
});
|
||||
if let Some(opts) = request.source_storage_options {
|
||||
source["storage_options"] = serde_json::to_value(opts).unwrap_or_default();
|
||||
}
|
||||
let mut body = serde_json::json!({
|
||||
"columns": columns,
|
||||
"source": source,
|
||||
"target_key": request.target_key,
|
||||
});
|
||||
if let Some(k) = request.source_key {
|
||||
body["source_key"] = serde_json::Value::String(k);
|
||||
}
|
||||
if let Some(m) = request.on_missing {
|
||||
body["on_missing"] = serde_json::Value::String(m);
|
||||
}
|
||||
if let Some(n) = request.num_workers {
|
||||
body["num_workers"] = n.into();
|
||||
}
|
||||
if let Some(n) = request.max_workers {
|
||||
body["max_workers"] = n.into();
|
||||
}
|
||||
if let Some(n) = request.batch_size {
|
||||
body["batch_size"] = n.into();
|
||||
}
|
||||
if let Some(n) = request.commit_granularity {
|
||||
body["commit_granularity"] = n.into();
|
||||
}
|
||||
if let Some(p) = request.priority {
|
||||
body["priority"] = serde_json::Value::String(p);
|
||||
}
|
||||
let http_request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/load_columns", self.identifier))
|
||||
.json(&body);
|
||||
let (request_id, response) = self.send(http_request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
#[derive(serde::Deserialize)]
|
||||
struct LoadColumnsResponse {
|
||||
job_id: String,
|
||||
}
|
||||
let body: LoadColumnsResponse = response.json().await.err_to_http(request_id)?;
|
||||
Ok(body.job_id)
|
||||
}
|
||||
|
||||
async fn add_columns(
|
||||
&self,
|
||||
transforms: NewColumnTransform,
|
||||
@@ -2801,6 +2921,75 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_refresh_column() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/refresh_column");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(body["columns"], serde_json::json!(["vec"]));
|
||||
assert_eq!(body["num_workers"], 2);
|
||||
assert!(body.get("where_clause").is_none());
|
||||
|
||||
http::Response::builder()
|
||||
.status(202)
|
||||
.body(r#"{"job_id":"j-9"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let job_id = table
|
||||
.refresh_column(&["vec".to_string()], None, Some(2), None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(job_id, "j-9");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_load_columns() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/load_columns");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(
|
||||
body["columns"],
|
||||
serde_json::json!([{"target": "embedding", "source": "emb"}])
|
||||
);
|
||||
assert_eq!(body["source"]["format"], "parquet");
|
||||
assert_eq!(
|
||||
body["source"]["uris"],
|
||||
serde_json::json!(["s3://b/x.parquet"])
|
||||
);
|
||||
assert_eq!(body["target_key"], "document_id");
|
||||
assert_eq!(body["source_key"], "doc_id");
|
||||
assert_eq!(body["on_missing"], "null");
|
||||
assert_eq!(body["num_workers"], 4);
|
||||
|
||||
http::Response::builder()
|
||||
.status(202)
|
||||
.body(r#"{"job_id":"lc-7"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let request = crate::table::LoadColumnsRequest {
|
||||
source_uris: vec!["s3://b/x.parquet".to_string()],
|
||||
source_format: "parquet".to_string(),
|
||||
source_storage_options: None,
|
||||
target_key: "document_id".to_string(),
|
||||
source_key: Some("doc_id".to_string()),
|
||||
columns: vec![("embedding".to_string(), Some("emb".to_string()))],
|
||||
on_missing: Some("null".to_string()),
|
||||
num_workers: Some(4),
|
||||
max_workers: None,
|
||||
batch_size: None,
|
||||
commit_granularity: None,
|
||||
priority: None,
|
||||
};
|
||||
let job_id = table.load_columns(request).await.unwrap();
|
||||
assert_eq!(job_id, "lc-7");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_version() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
|
||||
@@ -471,6 +471,33 @@ impl LsmWriteSpec {
|
||||
}
|
||||
}
|
||||
|
||||
/// Request to fill existing table columns from an external source by
/// primary-key join (Geneva `Table.load_columns()` parity). Server-backed
/// feature (LanceDB Enterprise / Cloud).
///
/// Most fields are optional. `Default` is derived so callers can set only the
/// fields they need with struct-update syntax (`..Default::default()`). The
/// all-default value is only a starting point: `source_uris`,
/// `source_format`, `target_key` and `columns` are empty in it.
#[derive(Debug, Clone, Default)]
pub struct LoadColumnsRequest {
    /// External source URIs.
    pub source_uris: Vec<String>,
    /// Source format: "parquet" | "lance" | "ipc".
    pub source_format: String,
    /// Source-only storage options (e.g. cloud credentials).
    pub source_storage_options: Option<HashMap<String, String>>,
    /// Destination primary-key column.
    pub target_key: String,
    /// Source primary-key column. Defaults to `target_key` when None.
    pub source_key: Option<String>,
    /// Value column mappings as `(target, source)`; a None source defaults to
    /// the target name.
    pub columns: Vec<(String, Option<String>)>,
    /// Missing-row policy: "carry" (default) | "null" | "error".
    pub on_missing: Option<String>,
    /// Optional `num_workers` setting for the load job; omitted from the
    /// request when None.
    pub num_workers: Option<u32>,
    /// Optional `max_workers` setting for the load job; omitted from the
    /// request when None.
    pub max_workers: Option<u32>,
    /// Optional `batch_size` setting for the load job; omitted from the
    /// request when None.
    pub batch_size: Option<u32>,
    /// Optional `commit_granularity` setting for the load job; omitted from
    /// the request when None.
    pub commit_granularity: Option<u32>,
    /// Optional `priority` setting for the load job; omitted from the request
    /// when None.
    pub priority: Option<String>,
}
|
||||
|
||||
/// A trait for anything "table-like". This is used for both native tables (which target
|
||||
/// Lance datasets) and remote tables (which target LanceDB cloud)
|
||||
///
|
||||
@@ -620,6 +647,47 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
transforms: NewColumnTransform,
|
||||
read_columns: Option<Vec<String>>,
|
||||
) -> Result<AddColumnsResult>;
|
||||
/// Declare computed columns bound to a registered function: each
|
||||
/// `(name, sql_type)` is added all-null with the expression stored
|
||||
/// as its binding; no compute happens here (the server's lazy
|
||||
/// detector or refresh_column fills them). Several columns map a
|
||||
/// struct-returning function's fields positionally. Server-backed
|
||||
/// feature; the default returns NotSupported.
|
||||
async fn add_computed_columns(
|
||||
&self,
|
||||
_columns: &[(String, String)],
|
||||
_expression: &str,
|
||||
) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "computed columns are not supported by this table".into(),
|
||||
})
|
||||
}
|
||||
/// Trigger recompute of computed columns. The expression is
|
||||
/// resolved server-side from each column's stored binding; columns
|
||||
/// bound to the same struct-returning function refresh together.
|
||||
/// Returns the refresh job id. Server-backed feature (LanceDB
|
||||
/// Enterprise / Cloud); the default returns NotSupported.
|
||||
async fn refresh_column(
|
||||
&self,
|
||||
_columns: &[String],
|
||||
_where_clause: Option<String>,
|
||||
_num_workers: Option<u32>,
|
||||
_max_workers: Option<u32>,
|
||||
_batch_size: Option<u32>,
|
||||
_priority: Option<String>,
|
||||
) -> Result<String> {
|
||||
Err(Error::NotSupported {
|
||||
message: "refresh_column is not supported by this table".into(),
|
||||
})
|
||||
}
|
||||
/// Fill existing columns from an external source by primary-key join
|
||||
/// (Geneva `load_columns`). Returns the load job id. Server-backed feature;
|
||||
/// the default returns NotSupported.
|
||||
async fn load_columns(&self, _request: LoadColumnsRequest) -> Result<String> {
|
||||
Err(Error::NotSupported {
|
||||
message: "load_columns is not supported by this table".into(),
|
||||
})
|
||||
}
|
||||
/// Alter columns in the table.
|
||||
async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<AlterColumnsResult>;
|
||||
/// Drop columns from the table.
|
||||
@@ -1461,6 +1529,48 @@ impl Table {
|
||||
self.inner.add_columns(transforms, read_columns).await
|
||||
}
|
||||
|
||||
/// Declare computed columns bound to a registered function
|
||||
/// (`(name, sql_type)` pairs + a `f(args)` expression). No compute
|
||||
/// happens here. Server-backed feature.
|
||||
pub async fn add_computed_columns(
|
||||
&self,
|
||||
columns: &[(String, String)],
|
||||
expression: &str,
|
||||
) -> Result<()> {
|
||||
self.inner.add_computed_columns(columns, expression).await
|
||||
}
|
||||
|
||||
/// Trigger recompute of computed columns (REFRESH COLUMN). The
|
||||
/// expression comes from each column's stored binding; columns
|
||||
/// bound to the same struct-returning function refresh together.
|
||||
/// Returns the refresh job id. Server-backed feature.
|
||||
pub async fn refresh_column(
|
||||
&self,
|
||||
columns: &[String],
|
||||
where_clause: Option<String>,
|
||||
num_workers: Option<u32>,
|
||||
max_workers: Option<u32>,
|
||||
batch_size: Option<u32>,
|
||||
priority: Option<String>,
|
||||
) -> Result<String> {
|
||||
self.inner
|
||||
.refresh_column(
|
||||
columns,
|
||||
where_clause,
|
||||
num_workers,
|
||||
max_workers,
|
||||
batch_size,
|
||||
priority,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Fill existing columns from an external Parquet/Lance/IPC source by
|
||||
/// primary-key join (Geneva `Table.load_columns()`). Returns the job id.
|
||||
pub async fn load_columns(&self, request: LoadColumnsRequest) -> Result<String> {
|
||||
self.inner.load_columns(request).await
|
||||
}
|
||||
|
||||
/// Change a column's name or nullability.
|
||||
pub async fn alter_columns(
|
||||
&self,
|
||||
|
||||
@@ -579,24 +579,45 @@ fn array_to_f32_vec(arr: &Arc<dyn arrow_array::Array>) -> Result<Vec<f32>> {
|
||||
})
|
||||
}
|
||||
|
||||
/// Magic bytes that prefix (and suffix) the Arrow IPC *file* format.
|
||||
const ARROW_IPC_FILE_MAGIC: &[u8] = b"ARROW1";
|
||||
|
||||
/// Parse Arrow IPC response from the namespace server.
|
||||
///
|
||||
/// The server may return either the Arrow IPC *file* format or the *stream*
|
||||
/// format. REST/phalanx returns the file format (it begins with the `ARROW1`
|
||||
/// magic); reading that with a `StreamReader` fails with "failed to fill whole
|
||||
/// buffer". Detect the magic and pick the matching reader so both are handled.
|
||||
async fn parse_arrow_ipc_response(bytes: bytes::Bytes) -> Result<DatasetRecordBatchStream> {
|
||||
use arrow_ipc::reader::StreamReader;
|
||||
use arrow_ipc::reader::{FileReader, StreamReader};
|
||||
use std::io::Cursor;
|
||||
|
||||
let cursor = Cursor::new(bytes);
|
||||
let reader = StreamReader::try_new(cursor, None).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse Arrow IPC response: {}", e),
|
||||
})?;
|
||||
|
||||
// Collect all record batches
|
||||
let schema = reader.schema();
|
||||
let batches: Vec<_> = reader
|
||||
.into_iter()
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to read Arrow IPC batches: {}", e),
|
||||
let (schema, batches) = if bytes.starts_with(ARROW_IPC_FILE_MAGIC) {
|
||||
let reader = FileReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse Arrow IPC file response: {}", e),
|
||||
})?;
|
||||
let schema = reader.schema();
|
||||
let batches = reader
|
||||
.into_iter()
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to read Arrow IPC file batches: {}", e),
|
||||
})?;
|
||||
(schema, batches)
|
||||
} else {
|
||||
let reader =
|
||||
StreamReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse Arrow IPC response: {}", e),
|
||||
})?;
|
||||
let schema = reader.schema();
|
||||
let batches = reader
|
||||
.into_iter()
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to read Arrow IPC batches: {}", e),
|
||||
})?;
|
||||
(schema, batches)
|
||||
};
|
||||
|
||||
// Create a stream from the batches
|
||||
let stream = futures::stream::iter(batches.into_iter().map(Ok));
|
||||
@@ -624,6 +645,59 @@ mod tests {
|
||||
FixedSizeListArray::try_new_from_values(Float32Array::from(values), dimension).unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_arrow_ipc_response_handles_file_and_stream() {
|
||||
use arrow_array::{Int32Array, RecordBatch};
|
||||
use arrow_ipc::writer::{FileWriter, StreamWriter};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
|
||||
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Arrow IPC *file* format -- what REST/phalanx returns. Previously this
|
||||
// failed with "failed to fill whole buffer" because we used a StreamReader.
|
||||
let mut file_buf = Vec::new();
|
||||
{
|
||||
let mut writer = FileWriter::try_new(&mut file_buf, &schema).unwrap();
|
||||
writer.write(&batch).unwrap();
|
||||
writer.finish().unwrap();
|
||||
}
|
||||
assert!(file_buf.starts_with(ARROW_IPC_FILE_MAGIC));
|
||||
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(file_buf))
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|b| b.num_rows())
|
||||
.sum();
|
||||
assert_eq!(rows, 3);
|
||||
|
||||
// Arrow IPC *stream* format must still parse.
|
||||
let mut stream_buf = Vec::new();
|
||||
{
|
||||
let mut writer = StreamWriter::try_new(&mut stream_buf, &schema).unwrap();
|
||||
writer.write(&batch).unwrap();
|
||||
writer.finish().unwrap();
|
||||
}
|
||||
assert!(!stream_buf.starts_with(ARROW_IPC_FILE_MAGIC));
|
||||
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(stream_buf))
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|b| b.num_rows())
|
||||
.sum();
|
||||
assert_eq!(rows, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_to_namespace_query_vector() {
|
||||
let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]));
|
||||
|
||||
Reference in New Issue
Block a user