mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-28 08:30:39 +00:00
Compare commits
11 Commits
python-v0.
...
jack/pytho
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
956a8ee714 | ||
|
|
3df3043563 | ||
|
|
8a5cd74e48 | ||
|
|
448d5ec20f | ||
|
|
8718345229 | ||
|
|
026fedc286 | ||
|
|
fe287dc98c | ||
|
|
411568b72c | ||
|
|
ebf8d55ede | ||
|
|
0ba70d96c3 | ||
|
|
0749532c3c |
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.31.0-beta.1"
|
current_version = "0.31.0-beta.3"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
|
|||||||
97
Cargo.lock
generated
97
Cargo.lock
generated
@@ -3432,8 +3432,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fsst"
|
name = "fsst"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
"rand 0.9.4",
|
"rand 0.9.4",
|
||||||
@@ -4735,8 +4735,8 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance"
|
name = "lance"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arc-swap",
|
"arc-swap",
|
||||||
"arrow",
|
"arrow",
|
||||||
@@ -4771,7 +4771,7 @@ dependencies = [
|
|||||||
"futures",
|
"futures",
|
||||||
"half",
|
"half",
|
||||||
"humantime",
|
"humantime",
|
||||||
"itertools 0.13.0",
|
"itertools 0.14.0",
|
||||||
"lance-arrow",
|
"lance-arrow",
|
||||||
"lance-core",
|
"lance-core",
|
||||||
"lance-datafusion",
|
"lance-datafusion",
|
||||||
@@ -4810,8 +4810,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-arrow"
|
name = "lance-arrow"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
"arrow-buffer",
|
"arrow-buffer",
|
||||||
@@ -4832,7 +4832,7 @@ dependencies = [
|
|||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-arrow-scalar"
|
name = "lance-arrow-scalar"
|
||||||
version = "58.0.0"
|
version = "58.0.0"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
"arrow-buffer",
|
"arrow-buffer",
|
||||||
@@ -4846,7 +4846,7 @@ dependencies = [
|
|||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-arrow-stats"
|
name = "lance-arrow-stats"
|
||||||
version = "58.0.0"
|
version = "58.0.0"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
"arrow-schema",
|
"arrow-schema",
|
||||||
@@ -4855,8 +4855,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-bitpacking"
|
name = "lance-bitpacking"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrayref",
|
"arrayref",
|
||||||
"paste",
|
"paste",
|
||||||
@@ -4865,8 +4865,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-core"
|
name = "lance-core"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
"arrow-buffer",
|
"arrow-buffer",
|
||||||
@@ -4878,7 +4878,7 @@ dependencies = [
|
|||||||
"datafusion-common",
|
"datafusion-common",
|
||||||
"datafusion-sql",
|
"datafusion-sql",
|
||||||
"futures",
|
"futures",
|
||||||
"itertools 0.13.0",
|
"itertools 0.14.0",
|
||||||
"lance-arrow",
|
"lance-arrow",
|
||||||
"lance-derive",
|
"lance-derive",
|
||||||
"libc",
|
"libc",
|
||||||
@@ -4904,8 +4904,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-datafusion"
|
name = "lance-datafusion"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow",
|
"arrow",
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
@@ -4935,8 +4935,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-datagen"
|
name = "lance-datagen"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow",
|
"arrow",
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
@@ -4953,8 +4953,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-derive"
|
name = "lance-derive"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
@@ -4963,8 +4963,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-encoding"
|
name = "lance-encoding"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-arith",
|
"arrow-arith",
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
@@ -4980,7 +4980,7 @@ dependencies = [
|
|||||||
"futures",
|
"futures",
|
||||||
"hex",
|
"hex",
|
||||||
"hyperloglogplus",
|
"hyperloglogplus",
|
||||||
"itertools 0.13.0",
|
"itertools 0.14.0",
|
||||||
"lance-arrow",
|
"lance-arrow",
|
||||||
"lance-bitpacking",
|
"lance-bitpacking",
|
||||||
"lance-core",
|
"lance-core",
|
||||||
@@ -4999,8 +4999,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-file"
|
name = "lance-file"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-arith",
|
"arrow-arith",
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
@@ -5030,8 +5030,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-index"
|
name = "lance-index"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arc-swap",
|
"arc-swap",
|
||||||
"arrow",
|
"arrow",
|
||||||
@@ -5056,7 +5056,7 @@ dependencies = [
|
|||||||
"fst",
|
"fst",
|
||||||
"futures",
|
"futures",
|
||||||
"half",
|
"half",
|
||||||
"itertools 0.13.0",
|
"itertools 0.14.0",
|
||||||
"jieba-rs",
|
"jieba-rs",
|
||||||
"jsonb",
|
"jsonb",
|
||||||
"lance-arrow",
|
"lance-arrow",
|
||||||
@@ -5096,8 +5096,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-io"
|
name = "lance-io"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow",
|
"arrow",
|
||||||
"arrow-arith",
|
"arrow-arith",
|
||||||
@@ -5138,8 +5138,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-linalg"
|
name = "lance-linalg"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
"arrow-buffer",
|
"arrow-buffer",
|
||||||
@@ -5155,8 +5155,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-namespace"
|
name = "lance-namespace"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow",
|
"arrow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
@@ -5168,8 +5168,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-namespace-impls"
|
name = "lance-namespace-impls"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow",
|
"arrow",
|
||||||
"arrow-ipc",
|
"arrow-ipc",
|
||||||
@@ -5223,15 +5223,15 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-select"
|
name = "lance-select"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
"arrow-buffer",
|
"arrow-buffer",
|
||||||
"arrow-schema",
|
"arrow-schema",
|
||||||
"byteorder",
|
"byteorder",
|
||||||
"bytes",
|
"bytes",
|
||||||
"itertools 0.13.0",
|
"itertools 0.14.0",
|
||||||
"lance-core",
|
"lance-core",
|
||||||
"roaring",
|
"roaring",
|
||||||
"tracing",
|
"tracing",
|
||||||
@@ -5239,8 +5239,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-table"
|
name = "lance-table"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow",
|
"arrow",
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
@@ -5279,8 +5279,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-testing"
|
name = "lance-testing"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
"arrow-schema",
|
"arrow-schema",
|
||||||
@@ -5293,8 +5293,8 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lance-tokenizer"
|
name = "lance-tokenizer"
|
||||||
version = "9.0.0-beta.2"
|
version = "9.0.0-beta.8"
|
||||||
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
|
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"icu_segmenter",
|
"icu_segmenter",
|
||||||
"jieba-rs",
|
"jieba-rs",
|
||||||
@@ -5307,7 +5307,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.31.0-beta.1"
|
version = "0.31.0-beta.3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ahash",
|
"ahash",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
@@ -5384,13 +5384,14 @@ dependencies = [
|
|||||||
"tokenizers",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"url",
|
"url",
|
||||||
|
"urlencoding",
|
||||||
"uuid",
|
"uuid",
|
||||||
"walkdir",
|
"walkdir",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lancedb-nodejs"
|
name = "lancedb-nodejs"
|
||||||
version = "0.31.0-beta.1"
|
version = "0.31.0-beta.3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
"arrow-buffer",
|
"arrow-buffer",
|
||||||
@@ -5415,7 +5416,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lancedb-python"
|
name = "lancedb-python"
|
||||||
version = "0.34.0-beta.1"
|
version = "0.34.0-beta.3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow",
|
"arrow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
|||||||
28
Cargo.toml
28
Cargo.toml
@@ -13,20 +13,20 @@ categories = ["database-implementations"]
|
|||||||
rust-version = "1.91.0"
|
rust-version = "1.91.0"
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
lance = { "version" = "=9.0.0-beta.2", default-features = false, "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance = { "version" = "=9.0.0-beta.8", default-features = false, "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-core = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-core = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-datagen = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-datagen = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-file = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-file = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-io = { "version" = "=9.0.0-beta.2", default-features = false, "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-io = { "version" = "=9.0.0-beta.8", default-features = false, "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-index = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-index = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-linalg = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-linalg = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-namespace = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-namespace = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-namespace-impls = { "version" = "=9.0.0-beta.2", default-features = false, "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-namespace-impls = { "version" = "=9.0.0-beta.8", default-features = false, "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-table = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-table = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-testing = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-testing = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-datafusion = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-datafusion = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-encoding = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-encoding = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-arrow = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
|
lance-arrow = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
ahash = "0.8"
|
ahash = "0.8"
|
||||||
# Note that this one does not include pyarrow
|
# Note that this one does not include pyarrow
|
||||||
arrow = { version = "58.0.0", optional = false }
|
arrow = { version = "58.0.0", optional = false }
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-core</artifactId>
|
<artifactId>lancedb-core</artifactId>
|
||||||
<version>0.31.0-beta.1</version>
|
<version>0.31.0-beta.3</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
<parent>
|
<parent>
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.31.0-beta.1</version>
|
<version>0.31.0-beta.3</version>
|
||||||
<relativePath>../pom.xml</relativePath>
|
<relativePath>../pom.xml</relativePath>
|
||||||
</parent>
|
</parent>
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.31.0-beta.1</version>
|
<version>0.31.0-beta.3</version>
|
||||||
<packaging>pom</packaging>
|
<packaging>pom</packaging>
|
||||||
<name>${project.artifactId}</name>
|
<name>${project.artifactId}</name>
|
||||||
<description>LanceDB Java SDK Parent POM</description>
|
<description>LanceDB Java SDK Parent POM</description>
|
||||||
@@ -28,7 +28,7 @@
|
|||||||
<properties>
|
<properties>
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
<arrow.version>15.0.0</arrow.version>
|
<arrow.version>15.0.0</arrow.version>
|
||||||
<lance-core.version>9.0.0-beta.2</lance-core.version>
|
<lance-core.version>9.0.0-beta.8</lance-core.version>
|
||||||
<spotless.skip>false</spotless.skip>
|
<spotless.skip>false</spotless.skip>
|
||||||
<spotless.version>2.30.0</spotless.version>
|
<spotless.version>2.30.0</spotless.version>
|
||||||
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-nodejs"
|
name = "lancedb-nodejs"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
version = "0.31.0-beta.1"
|
version = "0.31.0-beta.3"
|
||||||
publish = false
|
publish = false
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
description.workspace = true
|
description.workspace = true
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-darwin-arm64",
|
"name": "@lancedb/lancedb-darwin-arm64",
|
||||||
"version": "0.31.0-beta.1",
|
"version": "0.31.0-beta.3",
|
||||||
"os": ["darwin"],
|
"os": ["darwin"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.darwin-arm64.node",
|
"main": "lancedb.darwin-arm64.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||||
"version": "0.31.0-beta.1",
|
"version": "0.31.0-beta.3",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.linux-arm64-gnu.node",
|
"main": "lancedb.linux-arm64-gnu.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||||
"version": "0.31.0-beta.1",
|
"version": "0.31.0-beta.3",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.linux-arm64-musl.node",
|
"main": "lancedb.linux-arm64-musl.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||||
"version": "0.31.0-beta.1",
|
"version": "0.31.0-beta.3",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.linux-x64-gnu.node",
|
"main": "lancedb.linux-x64-gnu.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||||
"version": "0.31.0-beta.1",
|
"version": "0.31.0-beta.3",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.linux-x64-musl.node",
|
"main": "lancedb.linux-x64-musl.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||||
"version": "0.31.0-beta.1",
|
"version": "0.31.0-beta.3",
|
||||||
"os": [
|
"os": [
|
||||||
"win32"
|
"win32"
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||||
"version": "0.31.0-beta.1",
|
"version": "0.31.0-beta.3",
|
||||||
"os": ["win32"],
|
"os": ["win32"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.win32-x64-msvc.node",
|
"main": "lancedb.win32-x64-msvc.node",
|
||||||
|
|||||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb",
|
"name": "@lancedb/lancedb",
|
||||||
"version": "0.31.0-beta.1",
|
"version": "0.31.0-beta.3",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "@lancedb/lancedb",
|
"name": "@lancedb/lancedb",
|
||||||
"version": "0.31.0-beta.1",
|
"version": "0.31.0-beta.3",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64",
|
"x64",
|
||||||
"arm64"
|
"arm64"
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
"ann"
|
"ann"
|
||||||
],
|
],
|
||||||
"private": false,
|
"private": false,
|
||||||
"version": "0.31.0-beta.1",
|
"version": "0.31.0-beta.3",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"exports": {
|
"exports": {
|
||||||
".": "./dist/index.js",
|
".": "./dist/index.js",
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.34.0-beta.2"
|
current_version = "0.34.0-beta.3"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-python"
|
name = "lancedb-python"
|
||||||
version = "0.34.0-beta.2"
|
version = "0.34.0-beta.3"
|
||||||
publish = false
|
publish = false
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
description = "Python bindings for LanceDB"
|
description = "Python bindings for LanceDB"
|
||||||
|
|||||||
@@ -89,6 +89,8 @@ def connect(
|
|||||||
If presented, connect to LanceDB cloud.
|
If presented, connect to LanceDB cloud.
|
||||||
Otherwise, connect to a database on file system or cloud storage.
|
Otherwise, connect to a database on file system or cloud storage.
|
||||||
Can be set via environment variable `LANCEDB_API_KEY`.
|
Can be set via environment variable `LANCEDB_API_KEY`.
|
||||||
|
OAuth configuration is currently supported only by ``connect_async``;
|
||||||
|
synchronous LanceDB Cloud connections require an API key.
|
||||||
region: str, default "us-east-1"
|
region: str, default "us-east-1"
|
||||||
The region to use for LanceDB Cloud.
|
The region to use for LanceDB Cloud.
|
||||||
host_override: str, optional
|
host_override: str, optional
|
||||||
@@ -340,6 +342,7 @@ async def connect_async(
|
|||||||
session: Optional[Session] = None,
|
session: Optional[Session] = None,
|
||||||
manifest_enabled: bool = False,
|
manifest_enabled: bool = False,
|
||||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||||
|
oauth_config=None,
|
||||||
) -> AsyncConnection:
|
) -> AsyncConnection:
|
||||||
"""Connect to a LanceDB database.
|
"""Connect to a LanceDB database.
|
||||||
|
|
||||||
@@ -389,6 +392,10 @@ async def connect_async(
|
|||||||
namespace_client_properties : dict, optional
|
namespace_client_properties : dict, optional
|
||||||
Additional directory namespace client properties to use with
|
Additional directory namespace client properties to use with
|
||||||
``manifest_enabled=True``.
|
``manifest_enabled=True``.
|
||||||
|
oauth_config : OAuthConfig, optional
|
||||||
|
OAuth configuration for LanceDB Cloud/Enterprise. This is supported by
|
||||||
|
``connect_async`` only; synchronous ``connect`` uses API key
|
||||||
|
authentication for ``db://`` URIs.
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
@@ -435,6 +442,7 @@ async def connect_async(
|
|||||||
session,
|
session,
|
||||||
manifest_enabled,
|
manifest_enabled,
|
||||||
namespace_client_properties,
|
namespace_client_properties,
|
||||||
|
oauth_config,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -280,6 +280,7 @@ async def connect(
|
|||||||
session: Optional[Session],
|
session: Optional[Session],
|
||||||
manifest_enabled: bool = False,
|
manifest_enabled: bool = False,
|
||||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||||
|
oauth_config: Optional[Any] = None,
|
||||||
) -> Connection: ...
|
) -> Connection: ...
|
||||||
|
|
||||||
class RecordBatchStream:
|
class RecordBatchStream:
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
"use_token_pooling is deprecated, use pooling_strategy=None instead",
|
"use_token_pooling is deprecated, use pooling_strategy=None instead",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
self.pooling_strategy = None
|
self.pooling_strategy = None
|
||||||
|
|
||||||
|
|||||||
@@ -373,6 +373,19 @@ def _convert_pyarrow_schema_to_json(schema: pa.Schema) -> JsonArrowSchema:
|
|||||||
return JsonArrowSchema(fields=fields, metadata=meta)
|
return JsonArrowSchema(fields=fields, metadata=meta)
|
||||||
|
|
||||||
|
|
||||||
|
def _builds_namespace_natively(
|
||||||
|
namespace_client_impl: Optional[str],
|
||||||
|
namespace_client_properties: Optional[Dict[str, str]],
|
||||||
|
) -> bool:
|
||||||
|
"""Whether ``connect_namespace_client`` builds the namespace client natively
|
||||||
|
in Rust (installing the read-freshness context provider) rather than wrapping
|
||||||
|
the pre-built Python client.
|
||||||
|
|
||||||
|
Must mirror Rust ``build_namespace_natively`` in ``python/src/connection.rs``.
|
||||||
|
"""
|
||||||
|
return namespace_client_impl == "rest" and bool(namespace_client_properties)
|
||||||
|
|
||||||
|
|
||||||
class LanceNamespaceDBConnection(DBConnection):
|
class LanceNamespaceDBConnection(DBConnection):
|
||||||
"""
|
"""
|
||||||
A LanceDB connection that uses a namespace for table management.
|
A LanceDB connection that uses a namespace for table management.
|
||||||
@@ -432,6 +445,13 @@ class LanceNamespaceDBConnection(DBConnection):
|
|||||||
)
|
)
|
||||||
self._namespace_client_impl = namespace_client_impl
|
self._namespace_client_impl = namespace_client_impl
|
||||||
self._namespace_client_properties = namespace_client_properties
|
self._namespace_client_properties = namespace_client_properties
|
||||||
|
# When the namespace client is built natively (see Rust
|
||||||
|
# ``build_namespace_natively``), the underlying Rust table performs
|
||||||
|
# QueryTable pushdown through the read-freshness context provider, which
|
||||||
|
# the pure-Python ``query_table`` path bypasses.
|
||||||
|
self._route_pushdown_to_rust = _builds_namespace_natively(
|
||||||
|
namespace_client_impl, namespace_client_properties
|
||||||
|
)
|
||||||
self._inner = AsyncConnection(
|
self._inner = AsyncConnection(
|
||||||
_connect_namespace_client(
|
_connect_namespace_client(
|
||||||
namespace_client,
|
namespace_client,
|
||||||
@@ -543,6 +563,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
|||||||
namespace_path=namespace_path,
|
namespace_path=namespace_path,
|
||||||
namespace_client=self._namespace_client,
|
namespace_client=self._namespace_client,
|
||||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||||
|
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||||
_async=async_table,
|
_async=async_table,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -580,6 +601,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
|||||||
namespace_path=namespace_path,
|
namespace_path=namespace_path,
|
||||||
namespace_client=self._namespace_client,
|
namespace_client=self._namespace_client,
|
||||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||||
|
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||||
_async=async_table,
|
_async=async_table,
|
||||||
)
|
)
|
||||||
if branch is not None:
|
if branch is not None:
|
||||||
@@ -875,6 +897,8 @@ class AsyncLanceNamespaceDBConnection:
|
|||||||
storage_options: Optional[Dict[str, str]] = None,
|
storage_options: Optional[Dict[str, str]] = None,
|
||||||
session: Optional[Session] = None,
|
session: Optional[Session] = None,
|
||||||
namespace_client_pushdown_operations: Optional[List[str]] = None,
|
namespace_client_pushdown_operations: Optional[List[str]] = None,
|
||||||
|
namespace_client_impl: Optional[str] = None,
|
||||||
|
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize an async namespace-based LanceDB connection.
|
Initialize an async namespace-based LanceDB connection.
|
||||||
@@ -900,6 +924,12 @@ class AsyncLanceNamespaceDBConnection:
|
|||||||
namespace.create_table() instead of using declare_table + local write.
|
namespace.create_table() instead of using declare_table + local write.
|
||||||
|
|
||||||
Default is None (no pushdown, all operations run locally).
|
Default is None (no pushdown, all operations run locally).
|
||||||
|
namespace_client_impl : Optional[str]
|
||||||
|
The namespace implementation name used to create this connection.
|
||||||
|
Required (with ``namespace_client_properties``) for the Rust client to
|
||||||
|
be built natively and install the read-freshness provider.
|
||||||
|
namespace_client_properties : Optional[Dict[str, str]]
|
||||||
|
The namespace properties used to create this connection.
|
||||||
"""
|
"""
|
||||||
self._namespace_client = namespace_client
|
self._namespace_client = namespace_client
|
||||||
self.read_consistency_interval = read_consistency_interval
|
self.read_consistency_interval = read_consistency_interval
|
||||||
@@ -908,6 +938,14 @@ class AsyncLanceNamespaceDBConnection:
|
|||||||
self._namespace_client_pushdown_operations = set(
|
self._namespace_client_pushdown_operations = set(
|
||||||
namespace_client_pushdown_operations or []
|
namespace_client_pushdown_operations or []
|
||||||
)
|
)
|
||||||
|
self._namespace_client_impl = namespace_client_impl
|
||||||
|
self._namespace_client_properties = namespace_client_properties
|
||||||
|
# See LanceNamespaceDBConnection: when built natively the Rust table runs
|
||||||
|
# QueryTable pushdown through the read-freshness provider, so defer to it
|
||||||
|
# rather than the urllib3 client (which omits x-lancedb-min-timestamp).
|
||||||
|
self._route_pushdown_to_rust = _builds_namespace_natively(
|
||||||
|
namespace_client_impl, namespace_client_properties
|
||||||
|
)
|
||||||
self._inner = AsyncConnection(
|
self._inner = AsyncConnection(
|
||||||
_connect_namespace_client(
|
_connect_namespace_client(
|
||||||
namespace_client,
|
namespace_client,
|
||||||
@@ -921,8 +959,8 @@ class AsyncLanceNamespaceDBConnection:
|
|||||||
namespace_client_pushdown_operations=(
|
namespace_client_pushdown_operations=(
|
||||||
list(self._namespace_client_pushdown_operations)
|
list(self._namespace_client_pushdown_operations)
|
||||||
),
|
),
|
||||||
namespace_client_impl=None,
|
namespace_client_impl=namespace_client_impl,
|
||||||
namespace_client_properties=None,
|
namespace_client_properties=namespace_client_properties,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -992,6 +1030,7 @@ class AsyncLanceNamespaceDBConnection:
|
|||||||
namespace_path=namespace_path,
|
namespace_path=namespace_path,
|
||||||
namespace_client=self._namespace_client,
|
namespace_client=self._namespace_client,
|
||||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||||
|
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def open_table(
|
async def open_table(
|
||||||
@@ -1029,6 +1068,7 @@ class AsyncLanceNamespaceDBConnection:
|
|||||||
namespace_path=namespace_path,
|
namespace_path=namespace_path,
|
||||||
namespace_client=self._namespace_client,
|
namespace_client=self._namespace_client,
|
||||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||||
|
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
|
async def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
|
||||||
@@ -1387,4 +1427,6 @@ def connect_namespace_async(
|
|||||||
storage_options=storage_options,
|
storage_options=storage_options,
|
||||||
session=session,
|
session=session,
|
||||||
namespace_client_pushdown_operations=namespace_client_pushdown_operations,
|
namespace_client_pushdown_operations=namespace_client_pushdown_operations,
|
||||||
|
namespace_client_impl=namespace_client_impl,
|
||||||
|
namespace_client_properties=namespace_client_properties,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from typing import List, Optional
|
|||||||
from lancedb import __version__
|
from lancedb import __version__
|
||||||
|
|
||||||
from .header import HeaderProvider
|
from .header import HeaderProvider
|
||||||
|
from .oauth import OAuthConfig, OAuthFlowType
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TimeoutConfig",
|
"TimeoutConfig",
|
||||||
@@ -16,6 +17,8 @@ __all__ = [
|
|||||||
"TlsConfig",
|
"TlsConfig",
|
||||||
"ClientConfig",
|
"ClientConfig",
|
||||||
"HeaderProvider",
|
"HeaderProvider",
|
||||||
|
"OAuthConfig",
|
||||||
|
"OAuthFlowType",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -124,6 +124,7 @@ class RemoteDBConnection(DBConnection):
|
|||||||
"request_thread_pool is no longer used and will be removed in "
|
"request_thread_pool is no longer used and will be removed in "
|
||||||
"a future release.",
|
"a future release.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if connection_timeout is not None:
|
if connection_timeout is not None:
|
||||||
@@ -132,6 +133,7 @@ class RemoteDBConnection(DBConnection):
|
|||||||
"release. Please use client_config.timeout_config.connect_timeout "
|
"release. Please use client_config.timeout_config.connect_timeout "
|
||||||
"instead.",
|
"instead.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
client_config.timeout_config.connect_timeout = timedelta(
|
client_config.timeout_config.connect_timeout = timedelta(
|
||||||
seconds=connection_timeout
|
seconds=connection_timeout
|
||||||
@@ -142,6 +144,7 @@ class RemoteDBConnection(DBConnection):
|
|||||||
"read_timeout is deprecated and will be removed in a future release. "
|
"read_timeout is deprecated and will be removed in a future release. "
|
||||||
"Please use client_config.timeout_config.read_timeout instead.",
|
"Please use client_config.timeout_config.read_timeout instead.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)
|
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)
|
||||||
|
|
||||||
|
|||||||
75
python/python/lancedb/remote/oauth.py
Normal file
75
python/python/lancedb/remote/oauth.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthFlowType(str, Enum):
    """OAuth authentication flow types.

    Members are ``str`` subclasses, so each member compares equal to (and can
    be passed wherever a plain string is expected as) its string value.
    """

    CLIENT_CREDENTIALS = "client_credentials"
    """Client Credentials grant (service-to-service / M2M)."""

    AZURE_MANAGED_IDENTITY = "azure_managed_identity"
    """Azure Managed Identity via IMDS."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class OAuthConfig:
    """OAuth configuration for LanceDB authentication.

    All token acquisition and refresh is handled in the Rust layer.
    This config is passed through to Rust via PyO3.

    Parameters
    ----------
    issuer_url : str
        OIDC issuer URL or OAuth authority URL.
        For Azure: ``https://login.microsoftonline.com/{tenant_id}/v2.0``
    client_id : str
        Application / Client ID.
    scopes : List[str]
        OAuth scopes to request.
        For Azure managed identity, exactly one scope or resource is required.
        For example: ``["api://{app_id}/.default"]``
    flow : OAuthFlowType
        Authentication flow to use. Default: CLIENT_CREDENTIALS.
    client_secret : Optional[str]
        Client secret (required for CLIENT_CREDENTIALS).
    managed_identity_client_id : Optional[str]
        Client ID for user-assigned managed identity (AZURE_MANAGED_IDENTITY).
    refresh_buffer_secs : Optional[int]
        Seconds before expiry to trigger proactive refresh (default: 300).
        Keep this well below the token TTL; if it is greater than or equal to
        the TTL, each request refreshes the token.

    Examples
    --------
    Client Credentials (service-to-service):

    >>> config = OAuthConfig(
    ...     issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
    ...     client_id="app-id",
    ...     client_secret="secret",
    ...     scopes=["api://lancedb-api/.default"],
    ... )

    Azure Managed Identity:

    >>> config = OAuthConfig(
    ...     issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
    ...     client_id="app-id",
    ...     scopes=["api://lancedb-api/.default"],
    ...     flow=OAuthFlowType.AZURE_MANAGED_IDENTITY,
    ... )
    """

    issuer_url: str
    client_id: str
    scopes: List[str]
    flow: OAuthFlowType = OAuthFlowType.CLIENT_CREDENTIALS
    # repr=False keeps the secret out of the generated ``__repr__`` output.
    client_secret: Optional[str] = field(default=None, repr=False)
    managed_identity_client_id: Optional[str] = None
    refresh_buffer_secs: Optional[int] = None
|
||||||
@@ -845,7 +845,8 @@ class RemoteTable(Table):
|
|||||||
"""
|
"""
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"cleanup_old_versions() is a no-op on LanceDB Cloud. "
|
"cleanup_old_versions() is a no-op on LanceDB Cloud. "
|
||||||
"Tables are automatically cleaned up and optimized."
|
"Tables are automatically cleaned up and optimized.",
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -857,7 +858,8 @@ class RemoteTable(Table):
|
|||||||
"""
|
"""
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"compact_files() is a no-op on LanceDB Cloud. "
|
"compact_files() is a no-op on LanceDB Cloud. "
|
||||||
"Tables are automatically compacted and optimized."
|
"Tables are automatically compacted and optimized.",
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -874,7 +876,8 @@ class RemoteTable(Table):
|
|||||||
"""
|
"""
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"optimize() is a no-op on LanceDB Cloud. "
|
"optimize() is a no-op on LanceDB Cloud. "
|
||||||
"Indices are optimized automatically."
|
"Indices are optimized automatically.",
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -2022,6 +2022,7 @@ class LanceTable(Table):
|
|||||||
namespace_client: Optional[Any] = None,
|
namespace_client: Optional[Any] = None,
|
||||||
managed_versioning: Optional[bool] = None,
|
managed_versioning: Optional[bool] = None,
|
||||||
pushdown_operations: Optional[set] = None,
|
pushdown_operations: Optional[set] = None,
|
||||||
|
route_pushdown_to_rust: bool = False,
|
||||||
_async: AsyncTable = None,
|
_async: AsyncTable = None,
|
||||||
):
|
):
|
||||||
if namespace_path is None:
|
if namespace_path is None:
|
||||||
@@ -2031,6 +2032,14 @@ class LanceTable(Table):
|
|||||||
self._location = location # Store location for use in _dataset_path
|
self._location = location # Store location for use in _dataset_path
|
||||||
self._namespace_client = namespace_client
|
self._namespace_client = namespace_client
|
||||||
self._pushdown_operations = pushdown_operations or set()
|
self._pushdown_operations = pushdown_operations or set()
|
||||||
|
# When the connection built the namespace client natively (e.g. an
|
||||||
|
# enterprise "rest" connection), the underlying Rust table already
|
||||||
|
# executes QueryTable pushdown itself -- and, unlike this Python urllib3
|
||||||
|
# path, it routes through the read-freshness context provider that emits
|
||||||
|
# the ``x-lancedb-min-timestamp`` header. So we must defer pushdown to
|
||||||
|
# Rust instead of calling the Python ``namespace_client.query_table``
|
||||||
|
# directly, or reads silently bypass read-freshness (stale results).
|
||||||
|
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||||
if _async is not None:
|
if _async is not None:
|
||||||
self._table = _async
|
self._table = _async
|
||||||
else:
|
else:
|
||||||
@@ -2241,6 +2250,7 @@ class LanceTable(Table):
|
|||||||
namespace_path=self._namespace_path,
|
namespace_path=self._namespace_path,
|
||||||
namespace_client=self._namespace_client,
|
namespace_client=self._namespace_client,
|
||||||
pushdown_operations=self._pushdown_operations,
|
pushdown_operations=self._pushdown_operations,
|
||||||
|
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||||
location=self._location,
|
location=self._location,
|
||||||
_async=async_table,
|
_async=async_table,
|
||||||
)
|
)
|
||||||
@@ -2391,8 +2401,11 @@ class LanceTable(Table):
|
|||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
pa.Table"""
|
pa.Table"""
|
||||||
if _should_push_down_query_table(
|
if (
|
||||||
self._namespace_client, self._pushdown_operations
|
_should_push_down_query_table(
|
||||||
|
self._namespace_client, self._pushdown_operations
|
||||||
|
)
|
||||||
|
and not self._route_pushdown_to_rust
|
||||||
):
|
):
|
||||||
return self._execute_query(Query()).read_all()
|
return self._execute_query(Query()).read_all()
|
||||||
|
|
||||||
@@ -3344,6 +3357,7 @@ class LanceTable(Table):
|
|||||||
location: Optional[str] = None,
|
location: Optional[str] = None,
|
||||||
namespace_client: Optional[Any] = None,
|
namespace_client: Optional[Any] = None,
|
||||||
pushdown_operations: Optional[set] = None,
|
pushdown_operations: Optional[set] = None,
|
||||||
|
route_pushdown_to_rust: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new table.
|
Create a new table.
|
||||||
@@ -3406,21 +3420,24 @@ class LanceTable(Table):
|
|||||||
self._location = location
|
self._location = location
|
||||||
self._namespace_client = namespace_client
|
self._namespace_client = namespace_client
|
||||||
self._pushdown_operations = pushdown_operations or set()
|
self._pushdown_operations = pushdown_operations or set()
|
||||||
|
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||||
|
|
||||||
if data_storage_version is not None:
|
if data_storage_version is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"setting data_storage_version directly on create_table is deprecated. ",
|
"setting data_storage_version directly on create_table is deprecated. "
|
||||||
"Use database_options instead.",
|
"Use database_options instead.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
if storage_options is None:
|
if storage_options is None:
|
||||||
storage_options = {}
|
storage_options = {}
|
||||||
storage_options["new_table_data_storage_version"] = data_storage_version
|
storage_options["new_table_data_storage_version"] = data_storage_version
|
||||||
if enable_v2_manifest_paths is not None:
|
if enable_v2_manifest_paths is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"setting enable_v2_manifest_paths directly on create_table is ",
|
"setting enable_v2_manifest_paths directly on create_table is "
|
||||||
"deprecated. Use database_options instead.",
|
"deprecated. Use database_options instead.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
if storage_options is None:
|
if storage_options is None:
|
||||||
storage_options = {}
|
storage_options = {}
|
||||||
@@ -3517,6 +3534,7 @@ class LanceTable(Table):
|
|||||||
_should_push_down_query_table(
|
_should_push_down_query_table(
|
||||||
self._namespace_client, self._pushdown_operations
|
self._namespace_client, self._pushdown_operations
|
||||||
)
|
)
|
||||||
|
and not self._route_pushdown_to_rust
|
||||||
and self.current_branch() is None
|
and self.current_branch() is None
|
||||||
):
|
):
|
||||||
from lancedb.namespace import _execute_server_side_query
|
from lancedb.namespace import _execute_server_side_query
|
||||||
@@ -4258,6 +4276,7 @@ class AsyncTable:
|
|||||||
namespace_path: Optional[List[str]] = None,
|
namespace_path: Optional[List[str]] = None,
|
||||||
namespace_client: Optional[Any] = None,
|
namespace_client: Optional[Any] = None,
|
||||||
pushdown_operations: Optional[set] = None,
|
pushdown_operations: Optional[set] = None,
|
||||||
|
route_pushdown_to_rust: bool = False,
|
||||||
):
|
):
|
||||||
"""Create a new AsyncTable object.
|
"""Create a new AsyncTable object.
|
||||||
|
|
||||||
@@ -4270,6 +4289,9 @@ class AsyncTable:
|
|||||||
self._namespace_path = namespace_path or []
|
self._namespace_path = namespace_path or []
|
||||||
self._namespace_client = namespace_client
|
self._namespace_client = namespace_client
|
||||||
self._pushdown_operations = pushdown_operations or set()
|
self._pushdown_operations = pushdown_operations or set()
|
||||||
|
# See LanceTable.__init__: defer QueryTable pushdown to Rust (which emits
|
||||||
|
# the read-freshness header) for natively-built namespace clients.
|
||||||
|
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||||
|
|
||||||
def _set_namespace_context(
|
def _set_namespace_context(
|
||||||
self,
|
self,
|
||||||
@@ -4277,10 +4299,12 @@ class AsyncTable:
|
|||||||
namespace_path: Optional[List[str]] = None,
|
namespace_path: Optional[List[str]] = None,
|
||||||
namespace_client: Optional[Any] = None,
|
namespace_client: Optional[Any] = None,
|
||||||
pushdown_operations: Optional[set] = None,
|
pushdown_operations: Optional[set] = None,
|
||||||
|
route_pushdown_to_rust: bool = False,
|
||||||
) -> "AsyncTable":
|
) -> "AsyncTable":
|
||||||
self._namespace_path = namespace_path or []
|
self._namespace_path = namespace_path or []
|
||||||
self._namespace_client = namespace_client
|
self._namespace_client = namespace_client
|
||||||
self._pushdown_operations = pushdown_operations or set()
|
self._pushdown_operations = pushdown_operations or set()
|
||||||
|
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -4490,8 +4514,11 @@ class AsyncTable:
|
|||||||
-------
|
-------
|
||||||
pa.Table
|
pa.Table
|
||||||
"""
|
"""
|
||||||
if _should_push_down_query_table(
|
if (
|
||||||
self._namespace_client, self._pushdown_operations
|
_should_push_down_query_table(
|
||||||
|
self._namespace_client, self._pushdown_operations
|
||||||
|
)
|
||||||
|
and not self._route_pushdown_to_rust
|
||||||
):
|
):
|
||||||
return (await self._execute_query(Query())).read_all()
|
return (await self._execute_query(Query())).read_all()
|
||||||
|
|
||||||
@@ -5175,8 +5202,11 @@ class AsyncTable:
|
|||||||
batch_size: Optional[int] = None,
|
batch_size: Optional[int] = None,
|
||||||
timeout: Optional[timedelta] = None,
|
timeout: Optional[timedelta] = None,
|
||||||
) -> pa.RecordBatchReader:
|
) -> pa.RecordBatchReader:
|
||||||
if _should_push_down_query_table(
|
if (
|
||||||
self._namespace_client, self._pushdown_operations
|
_should_push_down_query_table(
|
||||||
|
self._namespace_client, self._pushdown_operations
|
||||||
|
)
|
||||||
|
and not self._route_pushdown_to_rust
|
||||||
):
|
):
|
||||||
from lancedb.namespace import _execute_server_side_query
|
from lancedb.namespace import _execute_server_side_query
|
||||||
|
|
||||||
@@ -5662,6 +5692,7 @@ class AsyncTable:
|
|||||||
"The 'retrain' parameter is deprecated and will be removed in a "
|
"The 'retrain' parameter is deprecated and will be removed in a "
|
||||||
"future version.",
|
"future version.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self._inner.optimize(
|
return await self._inner.optimize(
|
||||||
|
|||||||
@@ -65,6 +65,9 @@ def _namespace_lance_table(namespace_client: _NamespaceClient) -> LanceTable:
|
|||||||
table._namespace_path = ["geneva"]
|
table._namespace_path = ["geneva"]
|
||||||
table._namespace_client = namespace_client
|
table._namespace_client = namespace_client
|
||||||
table._pushdown_operations = {"QueryTable"}
|
table._pushdown_operations = {"QueryTable"}
|
||||||
|
# This test exercises the Python-side pushdown path (non-native client), so
|
||||||
|
# pushdown is not routed to Rust.
|
||||||
|
table._route_pushdown_to_rust = False
|
||||||
return table
|
return table
|
||||||
|
|
||||||
|
|
||||||
@@ -805,6 +808,37 @@ class TestPushdownOperations:
|
|||||||
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
|
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
|
||||||
assert len(db._namespace_client_pushdown_operations) == 0
|
assert len(db._namespace_client_pushdown_operations) == 0
|
||||||
|
|
||||||
|
def test_route_pushdown_to_rust_for_native_rest(self):
    """A natively-built rest connection must defer QueryTable pushdown to
    Rust so reads carry the x-lancedb-min-timestamp read-freshness header."""
    # "rest" plus non-empty properties is the combination that is built
    # natively, so the routing flag is expected to be set.
    db = lancedb.connect_namespace(
        "rest",
        {"uri": "http://localhost:12345"},
        namespace_client_pushdown_operations=["QueryTable"],
    )
    assert db._route_pushdown_to_rust is True
|
||||||
|
|
||||||
|
def test_route_pushdown_to_rust_false_for_dir(self):
    """A non-native (dir) connection keeps the Python pushdown path."""
    # Only the "rest" implementation is routed to Rust; "dir" must not be.
    db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
    assert db._route_pushdown_to_rust is False
|
||||||
|
|
||||||
|
def test_async_route_pushdown_to_rust_for_native_rest(self):
    """The async connection must not silently bypass the read-freshness fix:
    a natively-built rest connection defers pushdown to Rust (regression test
    for the async path omitting the freshness header)."""
    # The connection object is used directly, without awaiting, and only its
    # routing flag is inspected.
    db = lancedb.connect_namespace_async(
        "rest",
        {"uri": "http://localhost:12345"},
        namespace_client_pushdown_operations=["QueryTable"],
    )
    assert db._route_pushdown_to_rust is True
|
||||||
|
|
||||||
|
def test_async_route_pushdown_to_rust_false_for_dir(self):
    """The async non-native (dir) connection keeps the Python pushdown path."""
    # Async counterpart of the sync "dir" check: the flag must stay off.
    db = lancedb.connect_namespace_async("dir", {"root": self.temp_dir})
    assert db._route_pushdown_to_rust is False
|
||||||
|
|
||||||
def test_lance_table_to_arrow_uses_query_pushdown(self):
|
def test_lance_table_to_arrow_uses_query_pushdown(self):
|
||||||
namespace_client = _NamespaceClient()
|
namespace_client = _NamespaceClient()
|
||||||
table = _namespace_lance_table(namespace_client)
|
table = _namespace_lance_table(namespace_client)
|
||||||
|
|||||||
@@ -539,7 +539,7 @@ impl Connection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None))]
|
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None, oauth_config=None))]
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn connect(
|
pub fn connect(
|
||||||
py: Python<'_>,
|
py: Python<'_>,
|
||||||
@@ -553,6 +553,7 @@ pub fn connect(
|
|||||||
session: Option<crate::session::Session>,
|
session: Option<crate::session::Session>,
|
||||||
manifest_enabled: bool,
|
manifest_enabled: bool,
|
||||||
namespace_client_properties: Option<HashMap<String, String>>,
|
namespace_client_properties: Option<HashMap<String, String>>,
|
||||||
|
oauth_config: Option<crate::oauth::PyOAuthConfig>,
|
||||||
) -> PyResult<Bound<'_, PyAny>> {
|
) -> PyResult<Bound<'_, PyAny>> {
|
||||||
future_into_py(py, async move {
|
future_into_py(py, async move {
|
||||||
let mut builder = lancedb::connect(&uri);
|
let mut builder = lancedb::connect(&uri);
|
||||||
@@ -582,6 +583,11 @@ pub fn connect(
|
|||||||
if let Some(client_config) = client_config {
|
if let Some(client_config) = client_config {
|
||||||
builder = builder.client_config(client_config.into());
|
builder = builder.client_config(client_config.into());
|
||||||
}
|
}
|
||||||
|
if let Some(oauth_config) = oauth_config {
|
||||||
|
let config: lancedb::remote::oauth::OAuthConfig =
|
||||||
|
oauth_config.try_into().infer_error()?;
|
||||||
|
builder = builder.oauth_config(config);
|
||||||
|
}
|
||||||
if let Some(session) = session {
|
if let Some(session) = session {
|
||||||
builder = builder.session(session.inner.clone());
|
builder = builder.session(session.inner.clone());
|
||||||
}
|
}
|
||||||
@@ -610,24 +616,38 @@ pub fn connect_namespace_client(
|
|||||||
namespace_client_impl: Option<String>,
|
namespace_client_impl: Option<String>,
|
||||||
namespace_client_properties: Option<HashMap<String, String>>,
|
namespace_client_properties: Option<HashMap<String, String>>,
|
||||||
) -> PyResult<Connection> {
|
) -> PyResult<Connection> {
|
||||||
let namespace_client = extract_namespace_arc(py, namespace_client)?;
|
|
||||||
let read_consistency_interval = read_consistency_interval.map(Duration::from_secs_f64);
|
let read_consistency_interval = read_consistency_interval.map(Duration::from_secs_f64);
|
||||||
let namespace_client_pushdown_operations =
|
let namespace_client_pushdown_operations =
|
||||||
parse_namespace_client_pushdown_operations(namespace_client_pushdown_operations)?;
|
parse_namespace_client_pushdown_operations(namespace_client_pushdown_operations)?;
|
||||||
let ns_impl = namespace_client_impl.unwrap_or_else(|| "python".to_string());
|
|
||||||
let ns_properties = namespace_client_properties.unwrap_or_default();
|
let ns_properties = namespace_client_properties.unwrap_or_default();
|
||||||
let storage_options = storage_options.unwrap_or_default();
|
let storage_options = storage_options.unwrap_or_default();
|
||||||
let session = session.map(|s| s.inner.clone());
|
let session = session.map(|s| s.inner.clone());
|
||||||
|
|
||||||
let database = LanceNamespaceDatabase::from_namespace_client(
|
// Prefer building the namespace natively from (impl, properties) so the
|
||||||
namespace_client,
|
// read-freshness provider is installed.
|
||||||
ns_impl,
|
let database = if build_namespace_natively(namespace_client_impl.as_deref(), &ns_properties) {
|
||||||
ns_properties,
|
let ns_impl = namespace_client_impl.expect("impl present per build_namespace_natively");
|
||||||
storage_options,
|
crate::runtime::block_on(LanceNamespaceDatabase::connect(
|
||||||
read_consistency_interval,
|
&ns_impl,
|
||||||
session,
|
ns_properties,
|
||||||
namespace_client_pushdown_operations,
|
storage_options,
|
||||||
);
|
read_consistency_interval,
|
||||||
|
session,
|
||||||
|
namespace_client_pushdown_operations,
|
||||||
|
))
|
||||||
|
.infer_error()?
|
||||||
|
} else {
|
||||||
|
let namespace_client = extract_namespace_arc(py, namespace_client)?;
|
||||||
|
LanceNamespaceDatabase::from_namespace_client(
|
||||||
|
namespace_client,
|
||||||
|
namespace_client_impl.unwrap_or_else(|| "python".to_string()),
|
||||||
|
ns_properties,
|
||||||
|
storage_options,
|
||||||
|
read_consistency_interval,
|
||||||
|
session,
|
||||||
|
namespace_client_pushdown_operations,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Connection::new(LanceConnection::new(
|
Ok(Connection::new(LanceConnection::new(
|
||||||
Arc::new(database),
|
Arc::new(database),
|
||||||
@@ -635,6 +655,16 @@ pub fn connect_namespace_client(
|
|||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Whether to build the namespace natively (from impl + properties) instead of
/// wrapping a pre-built client. Native construction is required for the
/// read-freshness provider to be installed.
///
/// Returns `true` only when the implementation is exactly `"rest"` and at
/// least one property is supplied; every other combination wraps the
/// pre-built client.
///
/// The Python helper `_builds_namespace_natively` is documented as mirroring
/// this predicate, so the two must be kept in sync.
fn build_namespace_natively(
    namespace_client_impl: Option<&str>,
    namespace_client_properties: &HashMap<String, String>,
) -> bool {
    matches!(namespace_client_impl, Some("rest")) && !namespace_client_properties.is_empty()
}
|
||||||
|
|
||||||
#[derive(FromPyObject)]
|
#[derive(FromPyObject)]
|
||||||
pub struct PyClientConfig {
|
pub struct PyClientConfig {
|
||||||
user_agent: String,
|
user_agent: String,
|
||||||
@@ -733,3 +763,36 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned property map from borrowed key/value pairs.
    fn props(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn native_build_only_for_rest_with_properties() {
        let rest = props(&[("uri", "http://localhost:10024")]);

        // rest + non-empty properties -> build natively (installs the
        // read-freshness provider so checkout_latest() busts the server cache).
        assert!(build_namespace_natively(Some("rest"), &rest));

        // dir is local (no server cache) -> wrap the pre-built client unchanged.
        assert!(!build_namespace_natively(
            Some("dir"),
            &props(&[("root", "/tmp")])
        ));

        // No impl: only a pre-built client was handed in -> wrap it as-is.
        assert!(!build_namespace_natively(None, &rest));

        // rest but no properties: nothing to build a connection from -> wrap.
        assert!(!build_namespace_natively(Some("rest"), &HashMap::new()));
    }
}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ pub mod expr;
|
|||||||
pub mod header;
|
pub mod header;
|
||||||
pub mod index;
|
pub mod index;
|
||||||
pub mod namespace;
|
pub mod namespace;
|
||||||
|
pub mod oauth;
|
||||||
pub mod permutation;
|
pub mod permutation;
|
||||||
pub mod query;
|
pub mod query;
|
||||||
pub mod runtime;
|
pub mod runtime;
|
||||||
|
|||||||
72
python/src/oauth.rs
Normal file
72
python/src/oauth.rs
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
|
use pyo3::FromPyObject;
|
||||||
|
|
||||||
|
use lancedb::error::Error;
|
||||||
|
use lancedb::remote::oauth::{OAuthConfig, OAuthFlow};
|
||||||
|
|
||||||
|
/// Python-side OAuth configuration, extracted via FromPyObject.
|
||||||
|
/// Maps to `lancedb.remote.oauth.OAuthConfig` Python dataclass.
|
||||||
|
#[derive(FromPyObject)]
|
||||||
|
pub struct PyOAuthConfig {
|
||||||
|
pub issuer_url: String,
|
||||||
|
pub client_id: String,
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
pub flow: String,
|
||||||
|
pub client_secret: Option<String>,
|
||||||
|
pub managed_identity_client_id: Option<String>,
|
||||||
|
pub refresh_buffer_secs: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<PyOAuthConfig> for OAuthConfig {
|
||||||
|
type Error = Error;
|
||||||
|
|
||||||
|
fn try_from(py: PyOAuthConfig) -> Result<Self, Self::Error> {
|
||||||
|
let flow = match py.flow.as_str() {
|
||||||
|
"client_credentials" => OAuthFlow::ClientCredentials,
|
||||||
|
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
|
||||||
|
client_id: py.managed_identity_client_id,
|
||||||
|
},
|
||||||
|
other => {
|
||||||
|
return Err(Error::InvalidInput {
|
||||||
|
message: format!("Unknown OAuth flow type: {other}"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
issuer_url: py.issuer_url,
|
||||||
|
client_id: py.client_id,
|
||||||
|
client_secret: py.client_secret,
|
||||||
|
scopes: py.scopes,
|
||||||
|
flow,
|
||||||
|
refresh_buffer_secs: py.refresh_buffer_secs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// An unrecognised flow name must surface as `InvalidInput` carrying the
    /// offending name in its message.
    #[test]
    fn test_unknown_oauth_flow_returns_invalid_input() {
        let config = PyOAuthConfig {
            issuer_url: String::from("https://issuer.example.com"),
            client_id: String::from("client-id"),
            scopes: vec![String::from("scope")],
            flow: String::from("typo"),
            client_secret: None,
            managed_identity_client_id: None,
            refresh_buffer_secs: None,
        };

        let converted = OAuthConfig::try_from(config);
        let err = converted.unwrap_err();

        let is_expected_error = matches!(
            err,
            Error::InvalidInput { message }
                if message == "Unknown OAuth flow type: typo"
        );
        assert!(is_expected_error);
    }
}
|
||||||
@@ -56,6 +56,15 @@ fn get_runtime() -> &'static runtime::Runtime {
|
|||||||
unsafe { &*new_ptr }
|
unsafe { &*new_ptr }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Block the current thread on a future using the shared runtime.
|
||||||
|
///
|
||||||
|
/// For sync `#[pyfunction]`s that need to drive an async operation (e.g.
|
||||||
|
/// building a namespace client). Must not be called from within the runtime's
|
||||||
|
/// own worker threads.
|
||||||
|
pub fn block_on<F: std::future::Future>(fut: F) -> F::Output {
|
||||||
|
get_runtime().block_on(fut)
|
||||||
|
}
|
||||||
|
|
||||||
/// Runs in async-signal context after `fork()` in the child. We can only
|
/// Runs in async-signal context after `fork()` in the child. We can only
|
||||||
/// touch atomics here; we deliberately leak the previous runtime because
|
/// touch atomics here; we deliberately leak the previous runtime because
|
||||||
/// dropping a tokio `Runtime` would try to join its (now-dead) worker
|
/// dropping a tokio `Runtime` would try to join its (now-dead) worker
|
||||||
|
|||||||
33
python/tests/test_oauth.py
Normal file
33
python/tests/test_oauth.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def _load_oauth_module():
    """Load ``lancedb/remote/oauth.py`` straight from its source file.

    The file is imported by path under the module name
    ``lancedb_remote_oauth`` rather than through the ``lancedb`` package.
    NOTE(review): presumably this avoids importing the compiled extension
    in this test -- confirm.

    Returns:
        The freshly executed module object.

    Raises:
        AssertionError: if no import spec or loader can be built for the path.
        Exception: whatever executing the module raises; in that case the
            partially initialised module is removed from ``sys.modules``.
    """
    oauth_path = (
        Path(__file__).parents[1] / "python" / "lancedb" / "remote" / "oauth.py"
    )
    spec = importlib.util.spec_from_file_location("lancedb_remote_oauth", oauth_path)
    # Validate the spec *before* it is dereferenced by module_from_spec().
    assert spec is not None and spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Register before executing so lookups by module name made while the
    # module body runs can resolve it.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind for later imports.
        sys.modules.pop(spec.name, None)
        raise
    return module
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_config_repr_redacts_client_secret():
    """``repr()`` of an ``OAuthConfig`` must expose neither the secret value
    nor the ``client_secret`` field name."""
    config_cls = _load_oauth_module().OAuthConfig
    config = config_cls(
        issuer_url="https://issuer.example.com",
        client_id="client-id",
        scopes=["scope"],
        client_secret="super-secret",
    )

    rendered = repr(config)

    for forbidden in ("super-secret", "client_secret"):
        assert forbidden not in rendered
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.31.0-beta.1"
|
version = "0.31.0-beta.3"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
@@ -50,7 +50,7 @@ lance-namespace = { workspace = true }
|
|||||||
lance-namespace-impls = { workspace = true }
|
lance-namespace-impls = { workspace = true }
|
||||||
moka = { workspace = true }
|
moka = { workspace = true }
|
||||||
pin-project = { workspace = true }
|
pin-project = { workspace = true }
|
||||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
tokio = { version = "1.23", features = ["rt-multi-thread", "sync"] }
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
async-trait = "0"
|
async-trait = "0"
|
||||||
bytes = "1"
|
bytes = "1"
|
||||||
@@ -75,6 +75,7 @@ reqwest = { version = "0.12.0", default-features = false, features = [
|
|||||||
"stream",
|
"stream",
|
||||||
], optional = true }
|
], optional = true }
|
||||||
http = { version = "1", optional = true } # Matching what is in reqwest
|
http = { version = "1", optional = true } # Matching what is in reqwest
|
||||||
|
urlencoding = { version = "2", optional = true }
|
||||||
uuid = { version = "1.7.0", features = ["v4", "v5"] }
|
uuid = { version = "1.7.0", features = ["v4", "v5"] }
|
||||||
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
|
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
|
||||||
polars = { version = ">=0.37,<0.40.0", optional = true }
|
polars = { version = ">=0.37,<0.40.0", optional = true }
|
||||||
@@ -93,6 +94,7 @@ semver = { workspace = true }
|
|||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
tempfile = "3.5.0"
|
tempfile = "3.5.0"
|
||||||
random_word = { version = "0.4.3", features = ["en"] }
|
random_word = { version = "0.4.3", features = ["en"] }
|
||||||
|
tokio = { version = "1.23", features = ["io-util", "macros", "net", "rt-multi-thread", "sync"] }
|
||||||
uuid = { version = "1.7.0", features = ["v4"] }
|
uuid = { version = "1.7.0", features = ["v4"] }
|
||||||
walkdir = "2"
|
walkdir = "2"
|
||||||
aws-sdk-dynamodb = { version = "1.55.0" }
|
aws-sdk-dynamodb = { version = "1.55.0" }
|
||||||
@@ -129,7 +131,13 @@ huggingface = [
|
|||||||
"lance-namespace-impls/dir-huggingface",
|
"lance-namespace-impls/dir-huggingface",
|
||||||
]
|
]
|
||||||
dynamodb = ["lance/dynamodb", "aws"]
|
dynamodb = ["lance/dynamodb", "aws"]
|
||||||
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
|
remote = [
|
||||||
|
"dep:reqwest",
|
||||||
|
"dep:http",
|
||||||
|
"dep:urlencoding",
|
||||||
|
"lance-namespace-impls/rest",
|
||||||
|
"lance-namespace-impls/rest-adapter",
|
||||||
|
]
|
||||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||||
s3-test = []
|
s3-test = []
|
||||||
bedrock = ["dep:aws-sdk-bedrockruntime"]
|
bedrock = ["dep:aws-sdk-bedrockruntime"]
|
||||||
|
|||||||
@@ -576,6 +576,9 @@ impl Connection {
|
|||||||
/// For LanceNamespaceDatabase, it is the underlying LanceNamespace.
|
/// For LanceNamespaceDatabase, it is the underlying LanceNamespace.
|
||||||
/// For ListingDatabase, it is the equivalent DirectoryNamespace.
|
/// For ListingDatabase, it is the equivalent DirectoryNamespace.
|
||||||
/// For RemoteDatabase, it is the equivalent RestNamespace.
|
/// For RemoteDatabase, it is the equivalent RestNamespace.
|
||||||
|
///
|
||||||
|
/// Remote connections using dynamic headers forward them through the
|
||||||
|
/// namespace client's per-request context provider.
|
||||||
pub async fn namespace_client(&self) -> Result<Arc<dyn lance_namespace::LanceNamespace>> {
|
pub async fn namespace_client(&self) -> Result<Arc<dyn lance_namespace::LanceNamespace>> {
|
||||||
self.internal.namespace_client().await
|
self.internal.namespace_client().await
|
||||||
}
|
}
|
||||||
@@ -584,6 +587,9 @@ impl Connection {
|
|||||||
/// Returns (impl_type, properties) where:
|
/// Returns (impl_type, properties) where:
|
||||||
/// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace
|
/// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace
|
||||||
/// - properties: configuration properties for the namespace
|
/// - properties: configuration properties for the namespace
|
||||||
|
///
|
||||||
|
/// Remote connections using dynamic headers cannot be exported because the
|
||||||
|
/// namespace client config only carries static headers.
|
||||||
pub async fn namespace_client_config(
|
pub async fn namespace_client_config(
|
||||||
&self,
|
&self,
|
||||||
) -> Result<(String, std::collections::HashMap<String, String>)> {
|
) -> Result<(String, std::collections::HashMap<String, String>)> {
|
||||||
@@ -661,6 +667,8 @@ pub struct ConnectRequest {
|
|||||||
pub struct ConnectBuilder {
|
pub struct ConnectBuilder {
|
||||||
request: ConnectRequest,
|
request: ConnectRequest,
|
||||||
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
|
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
oauth_config: Option<crate::remote::OAuthConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "remote")]
|
#[cfg(feature = "remote")]
|
||||||
@@ -682,6 +690,8 @@ impl ConnectBuilder {
|
|||||||
session: None,
|
session: None,
|
||||||
},
|
},
|
||||||
embedding_registry: None,
|
embedding_registry: None,
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
oauth_config: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -770,6 +780,19 @@ impl ConnectBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Use OAuth to authenticate against LanceDB Cloud/Enterprise.
///
/// The config is stored on the builder. When the remote connection is
/// established, an [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider)
/// is built from it and installed as the header provider. OAuth cannot be
/// combined with an API key or another header provider.
///
/// Token acquisition and refresh are handled in Rust.
#[cfg(feature = "remote")]
pub fn oauth_config(self, config: crate::remote::OAuthConfig) -> Self {
    let mut builder = self;
    builder.oauth_config = Some(config);
    builder
}
|
||||||
|
|
||||||
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
|
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
|
||||||
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
|
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
|
||||||
self.embedding_registry = Some(registry);
|
self.embedding_registry = Some(registry);
|
||||||
@@ -915,9 +938,40 @@ impl ConnectBuilder {
|
|||||||
let region = options.region.ok_or_else(|| Error::InvalidInput {
|
let region = options.region.ok_or_else(|| Error::InvalidInput {
|
||||||
message: "A region is required when connecting to LanceDb Cloud".to_string(),
|
message: "A region is required when connecting to LanceDb Cloud".to_string(),
|
||||||
})?;
|
})?;
|
||||||
let api_key = options.api_key.ok_or_else(|| Error::InvalidInput {
|
let api_key = match (&self.oauth_config, &options.api_key) {
|
||||||
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
|
(Some(_), None) => String::new(),
|
||||||
})?;
|
(Some(_), Some(_)) => {
|
||||||
|
return Err(Error::InvalidInput {
|
||||||
|
message:
|
||||||
|
"api_key and oauth_config cannot both be set when connecting to LanceDb Cloud"
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
(None, Some(key)) => key.clone(),
|
||||||
|
(None, None) => {
|
||||||
|
return Err(Error::InvalidInput {
|
||||||
|
message:
|
||||||
|
"An api_key or oauth_config is required when connecting to LanceDb Cloud"
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.oauth_config.is_some() && self.request.client_config.header_provider.is_some() {
|
||||||
|
return Err(Error::InvalidInput {
|
||||||
|
message:
|
||||||
|
"oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud"
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut client_config = self.request.client_config;
|
||||||
|
|
||||||
|
if let Some(oauth_config) = self.oauth_config {
|
||||||
|
let provider = crate::remote::OAuthHeaderProvider::new(oauth_config)?;
|
||||||
|
client_config.header_provider =
|
||||||
|
Some(Arc::new(provider) as Arc<dyn crate::remote::HeaderProvider>);
|
||||||
|
}
|
||||||
|
|
||||||
let storage_options = StorageOptions(options.storage_options.clone());
|
let storage_options = StorageOptions(options.storage_options.clone());
|
||||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
|
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
|
||||||
@@ -925,7 +979,7 @@ impl ConnectBuilder {
|
|||||||
&api_key,
|
&api_key,
|
||||||
®ion,
|
®ion,
|
||||||
options.host_override,
|
options.host_override,
|
||||||
self.request.client_config,
|
client_config,
|
||||||
storage_options.into(),
|
storage_options.into(),
|
||||||
self.request.read_consistency_interval,
|
self.request.read_consistency_interval,
|
||||||
)?);
|
)?);
|
||||||
@@ -1234,6 +1288,83 @@ mod tests {
|
|||||||
assert_eq!(Some(&"EXPLICIT-VALUE".to_string()), options.get(opts_key));
|
assert_eq!(Some(&"EXPLICIT-VALUE".to_string()), options.get(opts_key));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Supplying both an API key and an OAuth config must be rejected with
/// `InvalidInput` before any connection is made.
#[cfg(feature = "remote")]
#[tokio::test]
async fn test_connect_rejects_api_key_with_oauth_config() {
    const EXPECTED_MESSAGE: &str =
        "api_key and oauth_config cannot both be set when connecting to LanceDb Cloud";

    let oauth_config = crate::remote::OAuthConfig {
        issuer_url: String::from("https://issuer.example.com"),
        client_id: String::from("client-id"),
        client_secret: Some(String::from("secret")),
        scopes: vec![String::from("scope")],
        flow: crate::remote::OAuthFlow::ClientCredentials,
        refresh_buffer_secs: None,
    };

    let builder = ConnectBuilder::new("db://my-container/my-prefix")
        .region("us-east-1")
        .api_key("my-api-key")
        .oauth_config(oauth_config);

    match builder.execute().await {
        Err(Error::InvalidInput { message }) if message == EXPECTED_MESSAGE => {}
        Err(err) => panic!("expected InvalidInput, got {err:?}"),
        Ok(_) => panic!("expected api_key and oauth_config to be rejected"),
    }
}
|
||||||
|
|
||||||
|
/// Supplying both a custom header provider and an OAuth config must be
/// rejected with `InvalidInput`.
#[cfg(feature = "remote")]
#[tokio::test]
async fn test_connect_rejects_header_provider_with_oauth_config() {
    const EXPECTED_MESSAGE: &str = "oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud";

    /// Minimal provider that always returns one static bearer header.
    #[derive(Debug)]
    struct TestHeaderProvider;

    #[async_trait::async_trait]
    impl crate::remote::HeaderProvider for TestHeaderProvider {
        async fn get_headers(&self) -> Result<HashMap<String, String>> {
            let mut headers = HashMap::new();
            headers.insert("authorization".to_string(), "Bearer token".to_string());
            Ok(headers)
        }
    }

    let header_provider: Arc<dyn crate::remote::HeaderProvider> = Arc::new(TestHeaderProvider);
    let client_config = crate::remote::ClientConfig {
        header_provider: Some(header_provider),
        ..Default::default()
    };
    let oauth_config = crate::remote::OAuthConfig {
        issuer_url: String::from("https://issuer.example.com"),
        client_id: String::from("client-id"),
        client_secret: Some(String::from("secret")),
        scopes: vec![String::from("scope")],
        flow: crate::remote::OAuthFlow::ClientCredentials,
        refresh_buffer_secs: None,
    };

    let builder = ConnectBuilder::new("db://my-container/my-prefix")
        .region("us-east-1")
        .client_config(client_config)
        .oauth_config(oauth_config);

    match builder.execute().await {
        Err(Error::InvalidInput { message }) if message == EXPECTED_MESSAGE => {}
        Err(err) => panic!("expected InvalidInput, got {err:?}"),
        Ok(_) => panic!("expected header_provider and oauth_config to be rejected"),
    }
}
|
||||||
|
|
||||||
#[cfg(not(windows))]
|
#[cfg(not(windows))]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_connect_relative() {
|
async fn test_connect_relative() {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
pub(crate) mod client;
|
pub(crate) mod client;
|
||||||
pub(crate) mod db;
|
pub(crate) mod db;
|
||||||
|
pub mod oauth;
|
||||||
mod retry;
|
mod retry;
|
||||||
pub(crate) mod table;
|
pub(crate) mod table;
|
||||||
pub(crate) mod util;
|
pub(crate) mod util;
|
||||||
@@ -20,3 +21,4 @@ const JSON_CONTENT_TYPE: &str = "application/json";
|
|||||||
|
|
||||||
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
|
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
|
||||||
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
|
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
|
||||||
|
pub use oauth::{OAuthConfig, OAuthFlow, OAuthHeaderProvider};
|
||||||
|
|||||||
@@ -459,12 +459,14 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
|||||||
config: &ClientConfig,
|
config: &ClientConfig,
|
||||||
) -> Result<HeaderMap> {
|
) -> Result<HeaderMap> {
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(
|
if !api_key.is_empty() {
|
||||||
HeaderName::from_static("x-api-key"),
|
headers.insert(
|
||||||
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
HeaderName::from_static("x-api-key"),
|
||||||
message: "non-ascii api key provided".to_string(),
|
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
||||||
})?,
|
message: "non-ascii api key provided".to_string(),
|
||||||
);
|
})?,
|
||||||
|
);
|
||||||
|
}
|
||||||
if region == "local" {
|
if region == "local" {
|
||||||
let host = format!("{}.local.api.lancedb.com", db_name);
|
let host = format!("{}.local.api.lancedb.com", db_name);
|
||||||
headers.insert(
|
headers.insert(
|
||||||
@@ -1005,6 +1007,33 @@ mod tests {
|
|||||||
assert!(!config_tls.assert_hostname);
|
assert!(!config_tls.assert_hostname);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// An empty API key must not produce an `x-api-key` header, while a
/// non-empty key is sent verbatim.
#[test]
fn test_default_headers_skip_empty_api_key() {
    // Builds the default header map for the given key; every other input
    // is held constant.
    let headers_for = |api_key: &str| {
        RestfulLanceDbClient::<Sender>::default_headers(
            api_key,
            "us-east-1",
            "db-name",
            false,
            &RemoteOptions::default(),
            None,
            &ClientConfig::default(),
        )
        .unwrap()
    };

    let without_key = headers_for("");
    assert!(!without_key.contains_key("x-api-key"));

    let with_key = headers_for("api-key");
    assert_eq!(with_key.get("x-api-key").unwrap(), "api-key");
}
|
||||||
|
|
||||||
// Test implementation of HeaderProvider
|
// Test implementation of HeaderProvider
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct TestHeaderProvider {
|
struct TestHeaderProvider {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ use std::sync::Arc;
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use lance_io::object_store::StorageOptions;
|
use lance_io::object_store::StorageOptions;
|
||||||
|
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
|
||||||
use moka::future::Cache;
|
use moka::future::Cache;
|
||||||
use reqwest::header::CONTENT_TYPE;
|
use reqwest::header::CONTENT_TYPE;
|
||||||
|
|
||||||
@@ -26,7 +27,9 @@ use crate::remote::util::stream_as_body;
|
|||||||
use crate::table::BaseTable;
|
use crate::table::BaseTable;
|
||||||
|
|
||||||
use super::ARROW_STREAM_CONTENT_TYPE;
|
use super::ARROW_STREAM_CONTENT_TYPE;
|
||||||
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
|
use super::client::{
|
||||||
|
ClientConfig, HeaderProvider, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender,
|
||||||
|
};
|
||||||
use super::table::RemoteTable;
|
use super::table::RemoteTable;
|
||||||
use super::util::parse_server_version;
|
use super::util::parse_server_version;
|
||||||
|
|
||||||
@@ -194,10 +197,66 @@ pub struct RemoteDatabase<S: HttpSend = Sender> {
|
|||||||
uri: String,
|
uri: String,
|
||||||
/// Headers to pass to the namespace client for authentication
|
/// Headers to pass to the namespace client for authentication
|
||||||
namespace_headers: HashMap<String, String>,
|
namespace_headers: HashMap<String, String>,
|
||||||
|
namespace_context_provider: Option<Arc<dyn DynamicContextProvider>>,
|
||||||
/// TLS configuration for mTLS support
|
/// TLS configuration for mTLS support
|
||||||
tls_config: Option<super::client::TlsConfig>,
|
tls_config: Option<super::client::TlsConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct NamespaceHeaderProviderContext {
|
||||||
|
header_provider: Arc<dyn HeaderProvider>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for NamespaceHeaderProviderContext {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("NamespaceHeaderProviderContext")
|
||||||
|
.field("header_provider", &"Some(...)")
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DynamicContextProvider for NamespaceHeaderProviderContext {
|
||||||
|
fn provide_context(&self, _info: &OperationInfo) -> HashMap<String, String> {
|
||||||
|
let header_provider = Arc::clone(&self.header_provider);
|
||||||
|
let handle = match std::thread::Builder::new()
|
||||||
|
.name("lancedb-namespace-headers".to_string())
|
||||||
|
.spawn(move || {
|
||||||
|
tokio::runtime::Builder::new_current_thread()
|
||||||
|
.enable_all()
|
||||||
|
.build()
|
||||||
|
.map_err(|e| Error::Runtime {
|
||||||
|
message: format!(
|
||||||
|
"Failed to create runtime for namespace header provider: {e}"
|
||||||
|
),
|
||||||
|
})?
|
||||||
|
.block_on(header_provider.get_headers())
|
||||||
|
}) {
|
||||||
|
Ok(handle) => handle,
|
||||||
|
Err(err) => {
|
||||||
|
log::warn!("Failed to spawn dynamic namespace header provider thread: {err}");
|
||||||
|
return HashMap::new();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let headers = handle.join();
|
||||||
|
|
||||||
|
match headers {
|
||||||
|
Ok(Ok(headers)) => headers
|
||||||
|
.into_iter()
|
||||||
|
.map(|(key, value)| (format!("headers.{key}"), value))
|
||||||
|
.collect(),
|
||||||
|
Ok(Err(err)) => {
|
||||||
|
log::warn!("Failed to get dynamic namespace headers: {err}");
|
||||||
|
HashMap::new()
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
log::warn!("Dynamic namespace header provider panicked");
|
||||||
|
HashMap::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl RemoteDatabase {
|
impl RemoteDatabase {
|
||||||
pub fn try_new(
|
pub fn try_new(
|
||||||
uri: &str,
|
uri: &str,
|
||||||
@@ -228,6 +287,16 @@ impl RemoteDatabase {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
let namespace_context_provider =
|
||||||
|
client_config
|
||||||
|
.header_provider
|
||||||
|
.as_ref()
|
||||||
|
.map(|header_provider| {
|
||||||
|
Arc::new(NamespaceHeaderProviderContext {
|
||||||
|
header_provider: Arc::clone(header_provider),
|
||||||
|
}) as Arc<dyn DynamicContextProvider>
|
||||||
|
});
|
||||||
|
|
||||||
let client = RestfulLanceDbClient::try_new(
|
let client = RestfulLanceDbClient::try_new(
|
||||||
&parsed,
|
&parsed,
|
||||||
region,
|
region,
|
||||||
@@ -247,6 +316,7 @@ impl RemoteDatabase {
|
|||||||
table_cache,
|
table_cache,
|
||||||
uri: uri.to_owned(),
|
uri: uri.to_owned(),
|
||||||
namespace_headers,
|
namespace_headers,
|
||||||
|
namespace_context_provider,
|
||||||
tls_config: client_config.tls_config,
|
tls_config: client_config.tls_config,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -271,6 +341,7 @@ mod test_utils {
|
|||||||
table_cache: Cache::new(0),
|
table_cache: Cache::new(0),
|
||||||
uri: "http://localhost".to_string(),
|
uri: "http://localhost".to_string(),
|
||||||
namespace_headers: HashMap::new(),
|
namespace_headers: HashMap::new(),
|
||||||
|
namespace_context_provider: None,
|
||||||
tls_config: None,
|
tls_config: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -281,11 +352,18 @@ mod test_utils {
|
|||||||
T: Into<reqwest::Body>,
|
T: Into<reqwest::Body>,
|
||||||
{
|
{
|
||||||
let client = client_with_handler_and_config(handler, config.clone());
|
let client = client_with_handler_and_config(handler, config.clone());
|
||||||
|
let namespace_context_provider =
|
||||||
|
config.header_provider.as_ref().map(|header_provider| {
|
||||||
|
Arc::new(NamespaceHeaderProviderContext {
|
||||||
|
header_provider: Arc::clone(header_provider),
|
||||||
|
}) as Arc<dyn DynamicContextProvider>
|
||||||
|
});
|
||||||
Self {
|
Self {
|
||||||
client,
|
client,
|
||||||
table_cache: Cache::new(0),
|
table_cache: Cache::new(0),
|
||||||
uri: "http://localhost".to_string(),
|
uri: "http://localhost".to_string(),
|
||||||
namespace_headers: config.extra_headers.clone(),
|
namespace_headers: config.extra_headers.clone(),
|
||||||
|
namespace_context_provider,
|
||||||
tls_config: config.tls_config.clone(),
|
tls_config: config.tls_config.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -759,9 +837,12 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
|||||||
// Create a RestNamespace pointing to the same remote host with the same authentication headers
|
// Create a RestNamespace pointing to the same remote host with the same authentication headers
|
||||||
let mut builder = lance_namespace_impls::RestNamespaceBuilder::new(self.client.host())
|
let mut builder = lance_namespace_impls::RestNamespaceBuilder::new(self.client.host())
|
||||||
.delimiter(&self.client.id_delimiter)
|
.delimiter(&self.client.id_delimiter)
|
||||||
// TODO: support header provider
|
|
||||||
.headers(self.namespace_headers.clone());
|
.headers(self.namespace_headers.clone());
|
||||||
|
|
||||||
|
if let Some(context_provider) = &self.namespace_context_provider {
|
||||||
|
builder = builder.context_provider(Arc::clone(context_provider));
|
||||||
|
}
|
||||||
|
|
||||||
// Apply mTLS configuration if present
|
// Apply mTLS configuration if present
|
||||||
if let Some(tls_config) = &self.tls_config {
|
if let Some(tls_config) = &self.tls_config {
|
||||||
if let Some(cert_file) = &tls_config.cert_file {
|
if let Some(cert_file) = &tls_config.cert_file {
|
||||||
@@ -781,6 +862,14 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)> {
|
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)> {
|
||||||
|
if self.namespace_context_provider.is_some() {
|
||||||
|
return Err(Error::NotSupported {
|
||||||
|
message:
|
||||||
|
"Cannot export a namespace client config when dynamic headers are configured; use LanceDB connection namespace methods instead"
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let mut properties = HashMap::new();
|
let mut properties = HashMap::new();
|
||||||
properties.insert("uri".to_string(), self.client.host().to_string());
|
properties.insert("uri".to_string(), self.client.host().to_string());
|
||||||
properties.insert("delimiter".to_string(), self.client.id_delimiter.clone());
|
properties.insert("delimiter".to_string(), self.client.id_delimiter.clone());
|
||||||
@@ -832,12 +921,13 @@ impl From<StorageOptions> for RemoteOptions {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::build_cache_key;
|
use super::{NamespaceHeaderProviderContext, build_cache_key};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, OnceLock};
|
use std::sync::{Arc, OnceLock};
|
||||||
|
|
||||||
use arrow_array::{Int32Array, RecordBatch};
|
use arrow_array::{Int32Array, RecordBatch};
|
||||||
use arrow_schema::{DataType, Field, Schema};
|
use arrow_schema::{DataType, Field, Schema};
|
||||||
|
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
|
||||||
|
|
||||||
use crate::connection::ConnectBuilder;
|
use crate::connection::ConnectBuilder;
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -1702,6 +1792,75 @@ mod tests {
|
|||||||
assert!(namespace_client.is_ok());
|
assert!(namespace_client.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Headers returned by the provider must appear in the context under a
/// `headers.`-prefixed key.
#[test]
fn test_namespace_header_provider_context_maps_headers() {
    /// Provider stub that always yields a single bearer header.
    #[derive(Debug)]
    struct TestHeaderProvider;

    #[async_trait::async_trait]
    impl HeaderProvider for TestHeaderProvider {
        async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
            let mut headers = HashMap::new();
            headers.insert("authorization".to_string(), "Bearer token".to_string());
            Ok(headers)
        }
    }

    let header_provider: Arc<dyn HeaderProvider> = Arc::new(TestHeaderProvider);
    let context_provider = NamespaceHeaderProviderContext { header_provider };

    let operation = OperationInfo::new("list_tables", "namespace");
    let context = context_provider.provide_context(&operation);

    let expected = "Bearer token".to_string();
    assert_eq!(context.get("headers.authorization"), Some(&expected));
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_namespace_client_supports_dynamic_headers() {
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct TestHeaderProvider;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl HeaderProvider for TestHeaderProvider {
|
||||||
|
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
|
||||||
|
Ok(HashMap::from([(
|
||||||
|
"authorization".to_string(),
|
||||||
|
"Bearer token".to_string(),
|
||||||
|
)]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let client_config = ClientConfig {
|
||||||
|
header_provider: Some(Arc::new(TestHeaderProvider) as Arc<dyn HeaderProvider>),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let conn = Connection::new_with_handler_and_config(
|
||||||
|
|_| {
|
||||||
|
http::Response::builder()
|
||||||
|
.status(200)
|
||||||
|
.body(r#"{"tables": []}"#)
|
||||||
|
.unwrap()
|
||||||
|
},
|
||||||
|
client_config,
|
||||||
|
);
|
||||||
|
|
||||||
|
let namespace_client = conn.namespace_client().await;
|
||||||
|
assert!(namespace_client.is_ok());
|
||||||
|
|
||||||
|
match conn.namespace_client_config().await {
|
||||||
|
Err(Error::NotSupported { message })
|
||||||
|
if message.contains("dynamic headers are configured") => {}
|
||||||
|
Err(err) => panic!("expected NotSupported, got {err:?}"),
|
||||||
|
Ok(_) => panic!("expected namespace_client_config to reject dynamic headers"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Integration tests using RestAdapter to run RemoteDatabase against a real namespace server
|
/// Integration tests using RestAdapter to run RemoteDatabase against a real namespace server
|
||||||
mod rest_adapter_integration {
|
mod rest_adapter_integration {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
907
rust/lancedb/src/remote/oauth.rs
Normal file
907
rust/lancedb/src/remote/oauth.rs
Normal file
@@ -0,0 +1,907 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use log::debug;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
use crate::error::{Error, Result};
|
||||||
|
use crate::remote::client::HeaderProvider;
|
||||||
|
|
||||||
|
const DEFAULT_REFRESH_BUFFER_SECS: u64 = 300;
|
||||||
|
const DEFAULT_TOKEN_TTL_SECS: u64 = 3600;
|
||||||
|
const AZURE_IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token";
|
||||||
|
const AZURE_IMDS_API_VERSION: &str = "2018-02-01";
|
||||||
|
|
||||||
|
/// OAuth authentication flow configuration.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum OAuthFlow {
|
||||||
|
/// Client Credentials grant (service-to-service / M2M).
|
||||||
|
/// Requires `client_secret` in [`OAuthConfig`].
|
||||||
|
ClientCredentials,
|
||||||
|
|
||||||
|
/// Azure Managed Identity via IMDS.
|
||||||
|
/// Works on Azure VMs, AKS, App Service, and Azure Functions.
|
||||||
|
/// IMDS requests bypass proxy settings because the endpoint is link-local.
|
||||||
|
AzureManagedIdentity {
|
||||||
|
/// Client ID for user-assigned managed identity.
|
||||||
|
/// Omit for system-assigned managed identity.
|
||||||
|
client_id: Option<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// OAuth configuration for LanceDB authentication.
|
||||||
|
///
|
||||||
|
/// All token acquisition and refresh is handled in the Rust layer.
|
||||||
|
/// Python and TypeScript bindings expose this as a plain config object.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OAuthConfig {
|
||||||
|
/// OIDC issuer URL or OAuth authority URL.
|
||||||
|
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||||
|
pub issuer_url: String,
|
||||||
|
|
||||||
|
/// Application / Client ID.
|
||||||
|
pub client_id: String,
|
||||||
|
|
||||||
|
/// Client secret (required for `ClientCredentials`, optional for others).
|
||||||
|
pub client_secret: Option<String>,
|
||||||
|
|
||||||
|
/// OAuth scopes to request.
|
||||||
|
/// For Azure managed identity, exactly one scope or resource is required.
|
||||||
|
/// For example: `["api://{app_id}/.default"]`
|
||||||
|
pub scopes: Vec<String>,
|
||||||
|
|
||||||
|
/// Authentication flow to use.
|
||||||
|
pub flow: OAuthFlow,
|
||||||
|
|
||||||
|
/// Seconds before token expiry to trigger proactive refresh (default: 300).
|
||||||
|
/// Keep this well below the token TTL; if it is greater than or equal to
|
||||||
|
/// the TTL, each request refreshes the token.
|
||||||
|
pub refresh_buffer_secs: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for OAuthConfig {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("OAuthConfig")
|
||||||
|
.field("issuer_url", &self.issuer_url)
|
||||||
|
.field("client_id", &self.client_id)
|
||||||
|
.field(
|
||||||
|
"client_secret",
|
||||||
|
&self.client_secret.as_deref().map(|_| "<redacted>"),
|
||||||
|
)
|
||||||
|
.field("scopes", &self.scopes)
|
||||||
|
.field("flow", &self.flow)
|
||||||
|
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- OIDC Discovery --
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
|
struct OidcDiscovery {
|
||||||
|
token_endpoint: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- Token Response --
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct TokenResponse {
|
||||||
|
access_token: String,
|
||||||
|
/// Token lifetime in seconds.
|
||||||
|
/// Some providers (Azure IMDS) return this as a string, so we accept both.
|
||||||
|
#[serde(default, deserialize_with = "deserialize_optional_u64_or_string")]
|
||||||
|
expires_in: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
token_type: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for TokenResponse {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("TokenResponse")
|
||||||
|
.field("access_token", &"<redacted>")
|
||||||
|
.field("expires_in", &self.expires_in)
|
||||||
|
.field("token_type", &self.token_type)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize_optional_u64_or_string<'de, D>(
|
||||||
|
deserializer: D,
|
||||||
|
) -> std::result::Result<Option<u64>, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
use serde::de;
|
||||||
|
|
||||||
|
struct U64OrString;
|
||||||
|
impl<'de> de::Visitor<'de> for U64OrString {
|
||||||
|
type Value = Option<u64>;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
formatter.write_str("an integer, an integer-valued float, a numeric string, or null")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
|
||||||
|
Ok(Some(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
|
||||||
|
if v < 0 {
|
||||||
|
return Err(E::custom(format!("invalid expires_in value: {v}")));
|
||||||
|
}
|
||||||
|
Ok(Some(v as u64))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
|
||||||
|
if !v.is_finite() || v < 0.0 || v.fract() != 0.0 || v > u64::MAX as f64 {
|
||||||
|
return Err(E::custom(format!("invalid expires_in value: {v}")));
|
||||||
|
}
|
||||||
|
Ok(Some(v as u64))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Self::Value, E> {
|
||||||
|
v.parse::<u64>().map(Some).map_err(de::Error::custom)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
deserializer.deserialize_any(U64OrString)
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- Internal Token State --
|
||||||
|
|
||||||
|
struct TokenState {
|
||||||
|
access_token: Option<String>,
|
||||||
|
expires_at: Option<Instant>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TokenState {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
access_token: None,
|
||||||
|
expires_at: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_expired(&self, buffer: Duration) -> bool {
|
||||||
|
match (self.access_token.as_ref(), self.expires_at) {
|
||||||
|
(Some(_), Some(expires_at)) => Instant::now() + buffer >= expires_at,
|
||||||
|
(None, _) => true,
|
||||||
|
(Some(_), None) => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&mut self, resp: &TokenResponse) {
|
||||||
|
self.access_token = Some(resp.access_token.clone());
|
||||||
|
let expires_in = resp.expires_in.unwrap_or(DEFAULT_TOKEN_TTL_SECS);
|
||||||
|
self.expires_at = Some(Instant::now() + Duration::from_secs(expires_in));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
trait TokenSource: Send + Sync + std::fmt::Debug {
|
||||||
|
async fn fetch_token(&self) -> Result<TokenResponse>;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ClientCredentialsSource {
|
||||||
|
issuer_url: String,
|
||||||
|
client_id: String,
|
||||||
|
client_secret: String,
|
||||||
|
scopes: Vec<String>,
|
||||||
|
http_client: Client,
|
||||||
|
discovery: RwLock<Option<OidcDiscovery>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for ClientCredentialsSource {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("ClientCredentialsSource")
|
||||||
|
.field("issuer_url", &self.issuer_url)
|
||||||
|
.field("client_id", &self.client_id)
|
||||||
|
.field("client_secret", &"<redacted>")
|
||||||
|
.field("scopes", &self.scopes)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientCredentialsSource {
|
||||||
|
fn new(
|
||||||
|
issuer_url: String,
|
||||||
|
client_id: String,
|
||||||
|
client_secret: Option<String>,
|
||||||
|
scopes: Vec<String>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let client_secret = client_secret.ok_or(Error::InvalidInput {
|
||||||
|
message: "client_secret is required for ClientCredentials flow".to_string(),
|
||||||
|
})?;
|
||||||
|
Self::validate_issuer_transport(&issuer_url)?;
|
||||||
|
|
||||||
|
let http_client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(30))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| Error::Runtime {
|
||||||
|
message: format!("Failed to create HTTP client for OAuth: {e}"),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
issuer_url,
|
||||||
|
client_id,
|
||||||
|
client_secret,
|
||||||
|
scopes,
|
||||||
|
http_client,
|
||||||
|
discovery: RwLock::new(None),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_issuer_transport(issuer_url: &str) -> Result<()> {
|
||||||
|
let issuer = url::Url::parse(issuer_url).map_err(|e| Error::InvalidInput {
|
||||||
|
message: format!("Invalid OAuth issuer_url: {e}"),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
match issuer.scheme() {
|
||||||
|
"https" => Ok(()),
|
||||||
|
"http" if Self::is_loopback_issuer(&issuer) => Ok(()),
|
||||||
|
_ => Err(Error::InvalidInput {
|
||||||
|
message:
|
||||||
|
"ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
|
||||||
|
.to_string(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_loopback_issuer(issuer: &url::Url) -> bool {
|
||||||
|
let Some(host) = issuer.host_str() else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
host.eq_ignore_ascii_case("localhost")
|
||||||
|
|| host
|
||||||
|
.parse::<IpAddr>()
|
||||||
|
.map(|addr| addr.is_loopback())
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_discovery(&self) -> Result<OidcDiscovery> {
|
||||||
|
{
|
||||||
|
let cached = self.discovery.read().await;
|
||||||
|
if let Some(ref disc) = *cached {
|
||||||
|
return Ok(disc.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut cache = self.discovery.write().await;
|
||||||
|
// Double-check
|
||||||
|
if let Some(ref disc) = *cache {
|
||||||
|
return Ok(disc.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let discovery_url = format!(
|
||||||
|
"{}/.well-known/openid-configuration",
|
||||||
|
self.issuer_url.trim_end_matches('/')
|
||||||
|
);
|
||||||
|
|
||||||
|
debug!("Fetching OIDC discovery from {}", discovery_url);
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.http_client
|
||||||
|
.get(&discovery_url)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| Error::Runtime {
|
||||||
|
message: format!("Failed to fetch OIDC discovery document: {e}"),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
return Err(Error::Runtime {
|
||||||
|
message: format!(
|
||||||
|
"OIDC discovery failed with status {}: {}",
|
||||||
|
resp.status(),
|
||||||
|
resp.text().await.unwrap_or_default()
|
||||||
|
),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let disc: OidcDiscovery = resp.json().await.map_err(|e| Error::Runtime {
|
||||||
|
message: format!("Failed to parse OIDC discovery document: {e}"),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let result = disc.clone();
|
||||||
|
|
||||||
|
*cache = Some(disc);
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_token_endpoint(&self) -> Result<String> {
|
||||||
|
self.get_discovery().await.map(|disc| disc.token_endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scopes_string(&self) -> String {
|
||||||
|
self.scopes.join(" ")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn post_token_request(
|
||||||
|
&self,
|
||||||
|
endpoint: &str,
|
||||||
|
params: &[(&str, &str)],
|
||||||
|
) -> Result<TokenResponse> {
|
||||||
|
let resp = self
|
||||||
|
.http_client
|
||||||
|
.post(endpoint)
|
||||||
|
.form(params)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| Error::Runtime {
|
||||||
|
message: format!("Token request to {endpoint} failed: {e}"),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
return Err(Error::Runtime {
|
||||||
|
message: format!(
|
||||||
|
"Token request failed with status {}: {}",
|
||||||
|
resp.status(),
|
||||||
|
resp.text().await.unwrap_or_default()
|
||||||
|
),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.json().await.map_err(|e| Error::Runtime {
|
||||||
|
message: format!("Failed to parse token response: {e}"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl TokenSource for ClientCredentialsSource {
|
||||||
|
async fn fetch_token(&self) -> Result<TokenResponse> {
|
||||||
|
let token_endpoint = self.get_token_endpoint().await?;
|
||||||
|
let scope = self.scopes_string();
|
||||||
|
let params = [
|
||||||
|
("grant_type", "client_credentials"),
|
||||||
|
("client_id", self.client_id.as_str()),
|
||||||
|
("client_secret", self.client_secret.as_str()),
|
||||||
|
("scope", scope.as_str()),
|
||||||
|
];
|
||||||
|
|
||||||
|
self.post_token_request(&token_endpoint, ¶ms).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AzureImdsSource {
|
||||||
|
client_id: Option<String>,
|
||||||
|
resource: String,
|
||||||
|
http_client: Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for AzureImdsSource {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("AzureImdsSource")
|
||||||
|
.field("client_id", &self.client_id)
|
||||||
|
.field("resource", &self.resource)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AzureImdsSource {
|
||||||
|
fn new(scopes: Vec<String>, client_id: Option<String>) -> Result<Self> {
|
||||||
|
let resource = Self::resource_from_scopes(&scopes)?;
|
||||||
|
let http_client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(30))
|
||||||
|
.no_proxy()
|
||||||
|
.build()
|
||||||
|
.map_err(|e| Error::Runtime {
|
||||||
|
message: format!("Failed to create HTTP client for Azure IMDS OAuth: {e}"),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
client_id,
|
||||||
|
resource,
|
||||||
|
http_client,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resource_from_scopes(scopes: &[String]) -> Result<String> {
|
||||||
|
let [scope] = scopes else {
|
||||||
|
return Err(Error::InvalidInput {
|
||||||
|
message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource"
|
||||||
|
.to_string(),
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(scope.strip_suffix("/.default").unwrap_or(scope).to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl TokenSource for AzureImdsSource {
|
||||||
|
async fn fetch_token(&self) -> Result<TokenResponse> {
|
||||||
|
let mut url = format!(
|
||||||
|
"{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}",
|
||||||
|
urlencoding::encode(&self.resource),
|
||||||
|
);
|
||||||
|
if let Some(cid) = self.client_id.as_deref() {
|
||||||
|
url.push_str(&format!("&client_id={}", urlencoding::encode(cid)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.http_client
|
||||||
|
.get(&url)
|
||||||
|
.header("Metadata", "true")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| Error::Runtime {
|
||||||
|
message: format!("Azure IMDS request failed: {e}"),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
return Err(Error::Runtime {
|
||||||
|
message: format!(
|
||||||
|
"Azure IMDS returned status {}: {}",
|
||||||
|
resp.status(),
|
||||||
|
resp.text().await.unwrap_or_default()
|
||||||
|
),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.json().await.map_err(|e| Error::Runtime {
|
||||||
|
message: format!("Failed to parse IMDS token response: {e}"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// OAuth header provider that manages the full token lifecycle.
|
||||||
|
///
|
||||||
|
/// Implements [`HeaderProvider`] to inject `Authorization: Bearer <token>`
|
||||||
|
/// headers into every LanceDB request, with automatic token refresh.
|
||||||
|
pub struct OAuthHeaderProvider {
|
||||||
|
token_source: Box<dyn TokenSource>,
|
||||||
|
token_state: Arc<RwLock<TokenState>>,
|
||||||
|
refresh_buffer: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for OAuthHeaderProvider {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("OAuthHeaderProvider")
|
||||||
|
.field("token_source", &self.token_source)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthHeaderProvider {
|
||||||
|
/// Create a new OAuth header provider from configuration.
|
||||||
|
pub fn new(config: OAuthConfig) -> Result<Self> {
|
||||||
|
let OAuthConfig {
|
||||||
|
issuer_url,
|
||||||
|
client_id,
|
||||||
|
client_secret,
|
||||||
|
scopes,
|
||||||
|
flow,
|
||||||
|
refresh_buffer_secs,
|
||||||
|
} = config;
|
||||||
|
|
||||||
|
if scopes.is_empty() {
|
||||||
|
return Err(Error::InvalidInput {
|
||||||
|
message: "At least one OAuth scope is required".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let refresh_buffer =
|
||||||
|
Duration::from_secs(refresh_buffer_secs.unwrap_or(DEFAULT_REFRESH_BUFFER_SECS));
|
||||||
|
let token_source: Box<dyn TokenSource> = match flow {
|
||||||
|
OAuthFlow::ClientCredentials => Box::new(ClientCredentialsSource::new(
|
||||||
|
issuer_url,
|
||||||
|
client_id,
|
||||||
|
client_secret,
|
||||||
|
scopes,
|
||||||
|
)?),
|
||||||
|
OAuthFlow::AzureManagedIdentity { client_id } => {
|
||||||
|
Box::new(AzureImdsSource::new(scopes, client_id)?)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
token_source,
|
||||||
|
token_state: Arc::new(RwLock::new(TokenState::new())),
|
||||||
|
refresh_buffer,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a valid access token, refreshing if necessary.
|
||||||
|
async fn get_valid_token(&self) -> Result<String> {
|
||||||
|
// Fast path: check if current token is still valid
|
||||||
|
{
|
||||||
|
let state = self.token_state.read().await;
|
||||||
|
if !state.is_expired(self.refresh_buffer)
|
||||||
|
&& let Some(ref token) = state.access_token
|
||||||
|
{
|
||||||
|
return Ok(token.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Slow path: acquire or refresh token
|
||||||
|
let mut state = self.token_state.write().await;
|
||||||
|
|
||||||
|
// Double-check after acquiring write lock
|
||||||
|
if !state.is_expired(self.refresh_buffer)
|
||||||
|
&& let Some(ref token) = state.access_token
|
||||||
|
{
|
||||||
|
return Ok(token.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!("Acquiring new OAuth token via {:?}", self.token_source);
|
||||||
|
let resp = self.token_source.fetch_token().await?;
|
||||||
|
|
||||||
|
state.update(&resp);
|
||||||
|
Ok(resp.access_token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl HeaderProvider for OAuthHeaderProvider {
|
||||||
|
async fn get_headers(&self) -> Result<HashMap<String, String>> {
|
||||||
|
let token = self.get_valid_token().await?;
|
||||||
|
Ok(HashMap::from([(
|
||||||
|
"authorization".to_string(),
|
||||||
|
format!("Bearer {token}"),
|
||||||
|
)]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_token_state_expiry() {
|
||||||
|
let mut state = TokenState::new();
|
||||||
|
assert!(state.is_expired(Duration::from_secs(0)));
|
||||||
|
|
||||||
|
state.access_token = Some("tok".to_string());
|
||||||
|
state.expires_at = Some(Instant::now() + Duration::from_secs(600));
|
||||||
|
assert!(!state.is_expired(Duration::from_secs(300)));
|
||||||
|
assert!(state.is_expired(Duration::from_secs(601)));
|
||||||
|
|
||||||
|
state.expires_at = None;
|
||||||
|
assert!(state.is_expired(Duration::from_secs(0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_token_state_uses_default_expiry() {
|
||||||
|
let mut state = TokenState::new();
|
||||||
|
let response = TokenResponse {
|
||||||
|
access_token: "tok".to_string(),
|
||||||
|
expires_in: None,
|
||||||
|
token_type: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
state.update(&response);
|
||||||
|
|
||||||
|
assert!(!state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS - 1)));
|
||||||
|
assert!(state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS + 1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_token_response_accepts_float_expires_in() {
|
||||||
|
let response: TokenResponse =
|
||||||
|
serde_json::from_str(r#"{"access_token":"tok","expires_in":3600.0}"#).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(response.expires_in, Some(3600));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_token_response_rejects_negative_expires_in() {
|
||||||
|
let err =
|
||||||
|
serde_json::from_str::<TokenResponse>(r#"{"access_token":"tok","expires_in":-1}"#)
|
||||||
|
.unwrap_err();
|
||||||
|
|
||||||
|
assert!(err.to_string().contains("invalid expires_in value: -1"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_token_response_debug_redacts_access_token() {
|
||||||
|
let response = TokenResponse {
|
||||||
|
access_token: "secret-token".to_string(),
|
||||||
|
expires_in: Some(3600),
|
||||||
|
token_type: Some("Bearer".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let debug = format!("{response:?}");
|
||||||
|
assert!(!debug.contains("secret-token"));
|
||||||
|
assert!(debug.contains("access_token: \"<redacted>\""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_scopes_string() {
|
||||||
|
let source = ClientCredentialsSource::new(
|
||||||
|
"https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||||
|
"app-id".to_string(),
|
||||||
|
Some("secret".to_string()),
|
||||||
|
vec!["scope1".to_string(), "scope2".to_string()],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(source.scopes_string(), "scope1 scope2");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_oauth_config_debug_redacts_client_secret() {
|
||||||
|
let config = OAuthConfig {
|
||||||
|
issuer_url: "https://issuer.example.com".to_string(),
|
||||||
|
client_id: "client-id".to_string(),
|
||||||
|
client_secret: Some("super-secret".to_string()),
|
||||||
|
scopes: vec!["scope".to_string()],
|
||||||
|
flow: OAuthFlow::ClientCredentials,
|
||||||
|
refresh_buffer_secs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let debug = format!("{config:?}");
|
||||||
|
assert!(!debug.contains("super-secret"));
|
||||||
|
assert!(debug.contains("client_secret: Some(\"<redacted>\")"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_oauth_header_provider_debug_redacts_client_secret() {
|
||||||
|
let config = OAuthConfig {
|
||||||
|
issuer_url: "https://issuer.example.com".to_string(),
|
||||||
|
client_id: "client-id".to_string(),
|
||||||
|
client_secret: Some("super-secret".to_string()),
|
||||||
|
scopes: vec!["scope".to_string()],
|
||||||
|
flow: OAuthFlow::ClientCredentials,
|
||||||
|
refresh_buffer_secs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||||
|
let debug = format!("{provider:?}");
|
||||||
|
assert!(!debug.contains("super-secret"));
|
||||||
|
assert!(debug.contains("client_secret: \"<redacted>\""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_managed_identity_resource_from_default_scope() {
|
||||||
|
assert_eq!(
|
||||||
|
AzureImdsSource::resource_from_scopes(&["api://test/.default".to_string()]).unwrap(),
|
||||||
|
"api://test"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_managed_identity_resource_without_default_suffix() {
|
||||||
|
assert_eq!(
|
||||||
|
AzureImdsSource::resource_from_scopes(&["api://test".to_string()]).unwrap(),
|
||||||
|
"api://test"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_managed_identity_rejects_multiple_scopes() {
|
||||||
|
let config = OAuthConfig {
|
||||||
|
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||||
|
client_id: "app-id".to_string(),
|
||||||
|
client_secret: None,
|
||||||
|
scopes: vec![
|
||||||
|
"api://test-a/.default".to_string(),
|
||||||
|
"api://test-b/.default".to_string(),
|
||||||
|
],
|
||||||
|
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
|
||||||
|
refresh_buffer_secs: None,
|
||||||
|
};
|
||||||
|
assert!(OAuthHeaderProvider::new(config).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_token_endpoint_requires_discovery_success() {
|
||||||
|
let (issuer_url, server) = spawn_discovery_error_server().await;
|
||||||
|
let source = ClientCredentialsSource::new(
|
||||||
|
issuer_url,
|
||||||
|
"client-id".to_string(),
|
||||||
|
Some("secret".to_string()),
|
||||||
|
vec!["scope".to_string()],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let err = source.get_token_endpoint().await.unwrap_err();
|
||||||
|
assert!(matches!(
|
||||||
|
err,
|
||||||
|
Error::Runtime { message }
|
||||||
|
if message.contains("OIDC discovery failed with status 503")
|
||||||
|
));
|
||||||
|
server.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_credentials_requires_secret() {
|
||||||
|
let config = OAuthConfig {
|
||||||
|
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||||
|
client_id: "app-id".to_string(),
|
||||||
|
client_secret: None,
|
||||||
|
scopes: vec!["scope".to_string()],
|
||||||
|
flow: OAuthFlow::ClientCredentials,
|
||||||
|
refresh_buffer_secs: None,
|
||||||
|
};
|
||||||
|
assert!(OAuthHeaderProvider::new(config).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_credentials_rejects_insecure_non_loopback_issuer() {
|
||||||
|
let config = OAuthConfig {
|
||||||
|
issuer_url: "http://issuer.example.com".to_string(),
|
||||||
|
client_id: "app-id".to_string(),
|
||||||
|
client_secret: Some("secret".to_string()),
|
||||||
|
scopes: vec!["scope".to_string()],
|
||||||
|
flow: OAuthFlow::ClientCredentials,
|
||||||
|
refresh_buffer_secs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let err = OAuthHeaderProvider::new(config).unwrap_err();
|
||||||
|
assert!(matches!(
|
||||||
|
err,
|
||||||
|
Error::InvalidInput { message }
|
||||||
|
if message == "ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_empty_scopes_rejected() {
|
||||||
|
let config = OAuthConfig {
|
||||||
|
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||||
|
client_id: "app-id".to_string(),
|
||||||
|
client_secret: None,
|
||||||
|
scopes: vec![],
|
||||||
|
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
|
||||||
|
refresh_buffer_secs: None,
|
||||||
|
};
|
||||||
|
assert!(OAuthHeaderProvider::new(config).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_client_credentials_token_lifecycle() {
|
||||||
|
let (issuer_url, token_requests, server) = spawn_oauth_server().await;
|
||||||
|
let config = OAuthConfig {
|
||||||
|
issuer_url,
|
||||||
|
client_id: "client-id".to_string(),
|
||||||
|
client_secret: Some("secret".to_string()),
|
||||||
|
scopes: vec!["scope".to_string()],
|
||||||
|
flow: OAuthFlow::ClientCredentials,
|
||||||
|
refresh_buffer_secs: Some(0),
|
||||||
|
};
|
||||||
|
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||||
|
|
||||||
|
let headers = provider.get_headers().await.unwrap();
|
||||||
|
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
|
||||||
|
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
|
||||||
|
|
||||||
|
let headers = provider.get_headers().await.unwrap();
|
||||||
|
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
|
||||||
|
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
|
||||||
|
|
||||||
|
provider.token_state.write().await.expires_at =
|
||||||
|
Some(Instant::now() - Duration::from_secs(1));
|
||||||
|
|
||||||
|
let headers = provider.get_headers().await.unwrap();
|
||||||
|
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-2");
|
||||||
|
assert_eq!(token_requests.load(Ordering::SeqCst), 2);
|
||||||
|
|
||||||
|
server.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn spawn_oauth_server() -> (String, Arc<AtomicUsize>, JoinHandle<()>) {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let issuer_url = format!("http://{addr}");
|
||||||
|
let token_requests = Arc::new(AtomicUsize::new(0));
|
||||||
|
let server_token_requests = Arc::clone(&token_requests);
|
||||||
|
|
||||||
|
let server = tokio::spawn(async move {
|
||||||
|
for _ in 0..3 {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let (request_line, body) = read_http_request(&mut stream).await;
|
||||||
|
|
||||||
|
if request_line.starts_with("GET /.well-known/openid-configuration ") {
|
||||||
|
let discovery = format!(r#"{{"token_endpoint":"http://{addr}/token"}}"#);
|
||||||
|
write_json_response(&mut stream, "200 OK", &discovery).await;
|
||||||
|
} else if request_line.starts_with("POST /token ") {
|
||||||
|
assert!(body.contains("grant_type=client_credentials"));
|
||||||
|
assert!(body.contains("client_id=client-id"));
|
||||||
|
assert!(body.contains("client_secret=secret"));
|
||||||
|
assert!(body.contains("scope=scope"));
|
||||||
|
|
||||||
|
let token_num = server_token_requests.fetch_add(1, Ordering::SeqCst) + 1;
|
||||||
|
let token = format!(
|
||||||
|
r#"{{"access_token":"token-{token_num}","expires_in":3600,"token_type":"Bearer"}}"#
|
||||||
|
);
|
||||||
|
write_json_response(&mut stream, "200 OK", &token).await;
|
||||||
|
} else {
|
||||||
|
write_json_response(&mut stream, "404 Not Found", "{}").await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
(issuer_url, token_requests, server)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn spawn_discovery_error_server() -> (String, JoinHandle<()>) {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let issuer_url = format!("http://{addr}");
|
||||||
|
|
||||||
|
let server = tokio::spawn(async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let (request_line, _) = read_http_request(&mut stream).await;
|
||||||
|
assert!(request_line.starts_with("GET /.well-known/openid-configuration "));
|
||||||
|
write_json_response(&mut stream, "503 Service Unavailable", "{}").await;
|
||||||
|
});
|
||||||
|
|
||||||
|
(issuer_url, server)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn read_http_request(stream: &mut TcpStream) -> (String, String) {
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
let mut header_end = None;
|
||||||
|
|
||||||
|
while header_end.is_none() {
|
||||||
|
let mut chunk = [0; 1024];
|
||||||
|
let read = stream.read(&mut chunk).await.unwrap();
|
||||||
|
assert_ne!(read, 0, "connection closed before request headers");
|
||||||
|
buffer.extend_from_slice(&chunk[..read]);
|
||||||
|
header_end = find_subsequence(&buffer, b"\r\n\r\n").map(|pos| pos + 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
let header_end = header_end.unwrap();
|
||||||
|
let headers = String::from_utf8_lossy(&buffer[..header_end]).to_string();
|
||||||
|
let request_line = headers.lines().next().unwrap_or_default().to_string();
|
||||||
|
let content_length = headers
|
||||||
|
.lines()
|
||||||
|
.find_map(|line| {
|
||||||
|
let (name, value) = line.split_once(':')?;
|
||||||
|
name.eq_ignore_ascii_case("content-length")
|
||||||
|
.then(|| value.trim().parse::<usize>().ok())
|
||||||
|
.flatten()
|
||||||
|
})
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
while buffer.len() < header_end + content_length {
|
||||||
|
let mut chunk = [0; 1024];
|
||||||
|
let read = stream.read(&mut chunk).await.unwrap();
|
||||||
|
assert_ne!(read, 0, "connection closed before request body");
|
||||||
|
buffer.extend_from_slice(&chunk[..read]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let body =
|
||||||
|
String::from_utf8_lossy(&buffer[header_end..header_end + content_length]).to_string();
|
||||||
|
|
||||||
|
(request_line, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the byte offset of the first occurrence of `needle` in `haystack`,
/// or `None` when `needle` does not occur.
///
/// An empty `needle` matches at offset 0, mirroring `str::find("")`.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    // `slice::windows` panics when asked for zero-sized windows, so the
    // empty-needle case has to be answered before reaching it.
    if needle.is_empty() {
        return Some(0);
    }

    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
|
||||||
|
|
||||||
|
async fn write_json_response(stream: &mut TcpStream, status: &str, body: &str) {
|
||||||
|
let response = format!(
|
||||||
|
"HTTP/1.1 {status}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||||
|
body.len()
|
||||||
|
);
|
||||||
|
stream.write_all(response.as_bytes()).await.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -579,24 +579,45 @@ fn array_to_f32_vec(arr: &Arc<dyn arrow_array::Array>) -> Result<Vec<f32>> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Magic bytes that prefix (and suffix) the Arrow IPC *file* format.
|
||||||
|
const ARROW_IPC_FILE_MAGIC: &[u8] = b"ARROW1";
|
||||||
|
|
||||||
/// Parse Arrow IPC response from the namespace server.
|
/// Parse Arrow IPC response from the namespace server.
|
||||||
|
///
|
||||||
|
/// The server may return either the Arrow IPC *file* format or the *stream*
|
||||||
|
/// format. REST/phalanx returns the file format (it begins with the `ARROW1`
|
||||||
|
/// magic); reading that with a `StreamReader` fails with "failed to fill whole
|
||||||
|
/// buffer". Detect the magic and pick the matching reader so both are handled.
|
||||||
async fn parse_arrow_ipc_response(bytes: bytes::Bytes) -> Result<DatasetRecordBatchStream> {
|
async fn parse_arrow_ipc_response(bytes: bytes::Bytes) -> Result<DatasetRecordBatchStream> {
|
||||||
use arrow_ipc::reader::StreamReader;
|
use arrow_ipc::reader::{FileReader, StreamReader};
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
|
|
||||||
let cursor = Cursor::new(bytes);
|
let (schema, batches) = if bytes.starts_with(ARROW_IPC_FILE_MAGIC) {
|
||||||
let reader = StreamReader::try_new(cursor, None).map_err(|e| Error::Runtime {
|
let reader = FileReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
|
||||||
message: format!("Failed to parse Arrow IPC response: {}", e),
|
message: format!("Failed to parse Arrow IPC file response: {}", e),
|
||||||
})?;
|
|
||||||
|
|
||||||
// Collect all record batches
|
|
||||||
let schema = reader.schema();
|
|
||||||
let batches: Vec<_> = reader
|
|
||||||
.into_iter()
|
|
||||||
.collect::<std::result::Result<Vec<_>, _>>()
|
|
||||||
.map_err(|e| Error::Runtime {
|
|
||||||
message: format!("Failed to read Arrow IPC batches: {}", e),
|
|
||||||
})?;
|
})?;
|
||||||
|
let schema = reader.schema();
|
||||||
|
let batches = reader
|
||||||
|
.into_iter()
|
||||||
|
.collect::<std::result::Result<Vec<_>, _>>()
|
||||||
|
.map_err(|e| Error::Runtime {
|
||||||
|
message: format!("Failed to read Arrow IPC file batches: {}", e),
|
||||||
|
})?;
|
||||||
|
(schema, batches)
|
||||||
|
} else {
|
||||||
|
let reader =
|
||||||
|
StreamReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
|
||||||
|
message: format!("Failed to parse Arrow IPC response: {}", e),
|
||||||
|
})?;
|
||||||
|
let schema = reader.schema();
|
||||||
|
let batches = reader
|
||||||
|
.into_iter()
|
||||||
|
.collect::<std::result::Result<Vec<_>, _>>()
|
||||||
|
.map_err(|e| Error::Runtime {
|
||||||
|
message: format!("Failed to read Arrow IPC batches: {}", e),
|
||||||
|
})?;
|
||||||
|
(schema, batches)
|
||||||
|
};
|
||||||
|
|
||||||
// Create a stream from the batches
|
// Create a stream from the batches
|
||||||
let stream = futures::stream::iter(batches.into_iter().map(Ok));
|
let stream = futures::stream::iter(batches.into_iter().map(Ok));
|
||||||
@@ -624,6 +645,59 @@ mod tests {
|
|||||||
FixedSizeListArray::try_new_from_values(Float32Array::from(values), dimension).unwrap()
|
FixedSizeListArray::try_new_from_values(Float32Array::from(values), dimension).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_parse_arrow_ipc_response_handles_file_and_stream() {
|
||||||
|
use arrow_array::{Int32Array, RecordBatch};
|
||||||
|
use arrow_ipc::writer::{FileWriter, StreamWriter};
|
||||||
|
use arrow_schema::{DataType, Field, Schema};
|
||||||
|
|
||||||
|
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
|
||||||
|
let batch = RecordBatch::try_new(
|
||||||
|
schema.clone(),
|
||||||
|
vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Arrow IPC *file* format -- what REST/phalanx returns. Previously this
|
||||||
|
// failed with "failed to fill whole buffer" because we used a StreamReader.
|
||||||
|
let mut file_buf = Vec::new();
|
||||||
|
{
|
||||||
|
let mut writer = FileWriter::try_new(&mut file_buf, &schema).unwrap();
|
||||||
|
writer.write(&batch).unwrap();
|
||||||
|
writer.finish().unwrap();
|
||||||
|
}
|
||||||
|
assert!(file_buf.starts_with(ARROW_IPC_FILE_MAGIC));
|
||||||
|
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(file_buf))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.try_collect::<Vec<_>>()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.iter()
|
||||||
|
.map(|b| b.num_rows())
|
||||||
|
.sum();
|
||||||
|
assert_eq!(rows, 3);
|
||||||
|
|
||||||
|
// Arrow IPC *stream* format must still parse.
|
||||||
|
let mut stream_buf = Vec::new();
|
||||||
|
{
|
||||||
|
let mut writer = StreamWriter::try_new(&mut stream_buf, &schema).unwrap();
|
||||||
|
writer.write(&batch).unwrap();
|
||||||
|
writer.finish().unwrap();
|
||||||
|
}
|
||||||
|
assert!(!stream_buf.starts_with(ARROW_IPC_FILE_MAGIC));
|
||||||
|
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(stream_buf))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.try_collect::<Vec<_>>()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.iter()
|
||||||
|
.map(|b| b.num_rows())
|
||||||
|
.sum();
|
||||||
|
assert_eq!(rows, 3);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_convert_to_namespace_query_vector() {
|
fn test_convert_to_namespace_query_vector() {
|
||||||
let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]));
|
let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]));
|
||||||
|
|||||||
Reference in New Issue
Block a user