Compare commits

..

12 Commits

Author SHA1 Message Date
Jack Ye
d96c90c5b9 docs(node): update OAuth config docs 2026-06-27 00:35:55 -07:00
Jack Ye
c1a8702c65 feat(node): expose OAuth connection config 2026-06-27 00:02:59 -07:00
Jack Ye
3df3043563 feat(rust): add OAuth header provider (#3579)
## Summary

Add the Rust OAuth header provider for remote LanceDB connections.

This supports client credentials and Azure managed identity flows,
handles token caching and refresh, redacts secrets in Debug output, and
wires `ConnectBuilder::oauth_config()` into the remote client while
rejecting ambiguous API-key/header-provider combinations.
2026-06-26 23:57:16 -07:00
Ryan Green
8a5cd74e48 fix: ensure read freshness provider is built into namespace client (#3571)
By default the read freshness provider was not included in the namespace
client, preventing the read freshness headers from being included in the
request. This prevents checkout_latest() from working as expected when
using the namespace client.

This fix ensures the provided is built into the client when the
namespace impl and properties are provided.
2026-06-25 21:47:55 -07:00
Lance Release
448d5ec20f Bump version: 0.31.0-beta.2 → 0.31.0-beta.3 2026-06-25 01:55:06 +00:00
Lance Release
8718345229 Bump version: 0.34.0-beta.2 → 0.34.0-beta.3 2026-06-25 01:53:51 +00:00
LanceDB Robot
026fedc286 chore: update lance dependency to v9.0.0-beta.8 (#3580)
Updates Lance dependencies from v9.0.0-beta.4 to v9.0.0-beta.8.\n\nThis
refreshes the Rust workspace lockfile and the Java lance-core version.
Triggering Lance tag:
https://github.com/lance-format/lance/releases/tag/v9.0.0-beta.8
2026-06-24 18:52:59 -07:00
Jack Ye
fe287dc98c fix(remote): support namespace clients with dynamic headers
Bridge LanceDB dynamic header providers into Lance Namespace dynamic context providers for live remote namespace clients.
2026-06-24 15:30:00 -07:00
Jack Ye
411568b72c fix(remote): omit empty api key header (#3573)
## Summary

Skip inserting the x-api-key header when the configured API key is
empty.

This lets bearer-token or other dynamic-header authentication avoid
sending an empty static API key header alongside the real auth header.
2026-06-24 13:25:59 -07:00
LanceDB Robot
ebf8d55ede chore: update lance dependency to v9.0.0-beta.4 (#3570)
Bumps the Lance dependencies to v9.0.0-beta.4 and refreshes the
generated lockfile metadata. No compatibility fixes were required beyond
the dependency updates. Triggered by
https://github.com/lance-format/lance/releases/tag/v9.0.0-beta.4
2026-06-24 10:16:29 -05:00
Raphael Malikian
0ba70d96c3 fix: add missing stacklevel=2 to warnings.warn() and fix broken message concatenation (Fixes #3563) (#3564)
Fixes #3563

## Summary

- Add `stacklevel=2` to 10 `warnings.warn()` calls across 4 files
- Fix broken message concatenation in `table.py` where the second string
was incorrectly passed as the `category` parameter

## Problem

Multiple `warnings.warn()` calls in the `python/lancedb/` codebase were
missing the `stacklevel` parameter. Without `stacklevel=2`, warnings
point to library internals instead of the caller's code, making it
impossible for users to identify which of their function calls triggered
the warning.

Additionally, two calls in `table.py` (lines 3411 and 3420) had a more
serious bug: the deprecation message was split across two separate
string arguments, causing the second string to be passed as the
`category` parameter instead of being concatenated with the first
string. This would cause `TypeError` when the warning was triggered.

## Changes

| File | Fixes | Description |
|------|-------|-------------|
| `embeddings/colpali.py` | 1 | Add `stacklevel=2` to
`use_token_pooling` deprecation warning |
| `remote/db.py` | 3 | Add `stacklevel=2` to `request_thread_pool`,
`connection_timeout`, `read_timeout` deprecation warnings |
| `remote/table.py` | 3 | Add `stacklevel=2` to `cleanup_old_versions`,
`compact_files`, `optimize` no-op warnings |
| `table.py` | 3 | Fix broken message concatenation for
`data_storage_version` and `enable_v2_manifest_paths` deprecation
warnings + add `stacklevel=2` to `retrain` deprecation warning |

## Verification

```python
# All warnings.warn() calls now have stacklevel
python3 -c "import ast, os; ..."
# Result: All warnings.warn() calls now have stacklevel!
```

## Changelog

| Date | Change | Author |
|------|--------|--------|
| 2026-06-20 | Fix missing stacklevel=2 in 10 warnings.warn() calls +
fix broken message concatenation | rtmalikian |

### Files Changed
- `python/python/lancedb/embeddings/colpali.py` — Add stacklevel=2
- `python/python/lancedb/remote/db.py` — Add stacklevel=2 to 3
deprecation warnings
- `python/python/lancedb/remote/table.py` — Add stacklevel=2 to 3 no-op
warnings
- `python/python/lancedb/table.py` — Fix broken message concatenation +
add stacklevel=2

### Verification
- AST-based audit confirms all `warnings.warn()` calls now include
`stacklevel=2`
- Syntax check passes for all 4 modified files

---

**About the Author:** Raphael Malikian — Clinical AI Solutions
Architect. I specialise in building and fixing AI/ML systems for
healthcare, including vector databases, RAG pipelines, and clinical NLP.
If you need help with your project or think I can add value to your
organisation, feel free to reach out — I'd love to connect.

📧 rtmalikian@gmail.com
🔗 GitHub: https://github.com/rtmalikian
🔗 LinkedIn:
http://www.linkedin.com/in/raphael-t-malikian-mbbs-bsc-hons-71075436a

---

**Disclosure:** This code was developed with assistance from **Hermes
Agent** (Nous Research). All changes were reviewed, tested against the
actual codebase, and verified for correctness.

Signed-off-by: rtmalikian <rtmalikian@gmail.com>
2026-06-23 13:42:59 -07:00
Lance Release
0749532c3c Bump version: 0.31.0-beta.1 → 0.31.0-beta.2 2026-06-23 16:23:08 +00:00
43 changed files with 2079 additions and 133 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.31.0-beta.1"
current_version = "0.31.0-beta.3"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

97
Cargo.lock generated
View File

@@ -3432,8 +3432,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsst"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow-array",
"rand 0.9.4",
@@ -4735,8 +4735,8 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a"
[[package]]
name = "lance"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arc-swap",
"arrow",
@@ -4771,7 +4771,7 @@ dependencies = [
"futures",
"half",
"humantime",
"itertools 0.13.0",
"itertools 0.14.0",
"lance-arrow",
"lance-core",
"lance-datafusion",
@@ -4810,8 +4810,8 @@ dependencies = [
[[package]]
name = "lance-arrow"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4832,7 +4832,7 @@ dependencies = [
[[package]]
name = "lance-arrow-scalar"
version = "58.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4846,7 +4846,7 @@ dependencies = [
[[package]]
name = "lance-arrow-stats"
version = "58.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -4855,8 +4855,8 @@ dependencies = [
[[package]]
name = "lance-bitpacking"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrayref",
"paste",
@@ -4865,8 +4865,8 @@ dependencies = [
[[package]]
name = "lance-core"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4878,7 +4878,7 @@ dependencies = [
"datafusion-common",
"datafusion-sql",
"futures",
"itertools 0.13.0",
"itertools 0.14.0",
"lance-arrow",
"lance-derive",
"libc",
@@ -4904,8 +4904,8 @@ dependencies = [
[[package]]
name = "lance-datafusion"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow",
"arrow-array",
@@ -4935,8 +4935,8 @@ dependencies = [
[[package]]
name = "lance-datagen"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow",
"arrow-array",
@@ -4953,8 +4953,8 @@ dependencies = [
[[package]]
name = "lance-derive"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"proc-macro2",
"quote",
@@ -4963,8 +4963,8 @@ dependencies = [
[[package]]
name = "lance-encoding"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4980,7 +4980,7 @@ dependencies = [
"futures",
"hex",
"hyperloglogplus",
"itertools 0.13.0",
"itertools 0.14.0",
"lance-arrow",
"lance-bitpacking",
"lance-core",
@@ -4999,8 +4999,8 @@ dependencies = [
[[package]]
name = "lance-file"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -5030,8 +5030,8 @@ dependencies = [
[[package]]
name = "lance-index"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arc-swap",
"arrow",
@@ -5056,7 +5056,7 @@ dependencies = [
"fst",
"futures",
"half",
"itertools 0.13.0",
"itertools 0.14.0",
"jieba-rs",
"jsonb",
"lance-arrow",
@@ -5096,8 +5096,8 @@ dependencies = [
[[package]]
name = "lance-io"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow",
"arrow-arith",
@@ -5138,8 +5138,8 @@ dependencies = [
[[package]]
name = "lance-linalg"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -5155,8 +5155,8 @@ dependencies = [
[[package]]
name = "lance-namespace"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow",
"async-trait",
@@ -5168,8 +5168,8 @@ dependencies = [
[[package]]
name = "lance-namespace-impls"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow",
"arrow-ipc",
@@ -5223,15 +5223,15 @@ dependencies = [
[[package]]
name = "lance-select"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow-array",
"arrow-buffer",
"arrow-schema",
"byteorder",
"bytes",
"itertools 0.13.0",
"itertools 0.14.0",
"lance-core",
"roaring",
"tracing",
@@ -5239,8 +5239,8 @@ dependencies = [
[[package]]
name = "lance-table"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow",
"arrow-array",
@@ -5279,8 +5279,8 @@ dependencies = [
[[package]]
name = "lance-testing"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -5293,8 +5293,8 @@ dependencies = [
[[package]]
name = "lance-tokenizer"
version = "9.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.2#23211989de648fefc4454f5eee09ec176f0a465b"
version = "9.0.0-beta.8"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.8#71c4aa2174971e98acb7e256fde1e1589024f5bc"
dependencies = [
"icu_segmenter",
"jieba-rs",
@@ -5307,7 +5307,7 @@ dependencies = [
[[package]]
name = "lancedb"
version = "0.31.0-beta.1"
version = "0.31.0-beta.3"
dependencies = [
"ahash",
"anyhow",
@@ -5384,13 +5384,14 @@ dependencies = [
"tokenizers",
"tokio",
"url",
"urlencoding",
"uuid",
"walkdir",
]
[[package]]
name = "lancedb-nodejs"
version = "0.31.0-beta.1"
version = "0.31.0-beta.3"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -5415,7 +5416,7 @@ dependencies = [
[[package]]
name = "lancedb-python"
version = "0.34.0-beta.1"
version = "0.34.0-beta.3"
dependencies = [
"arrow",
"async-trait",

View File

@@ -13,20 +13,20 @@ categories = ["database-implementations"]
rust-version = "1.91.0"
[workspace.dependencies]
lance = { "version" = "=9.0.0-beta.2", default-features = false, "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=9.0.0-beta.2", default-features = false, "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=9.0.0-beta.2", default-features = false, "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=9.0.0-beta.2", "tag" = "v9.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance = { "version" = "=9.0.0-beta.8", default-features = false, "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=9.0.0-beta.8", default-features = false, "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=9.0.0-beta.8", default-features = false, "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=9.0.0-beta.8", "tag" = "v9.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "58.0.0", optional = false }

View File

@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId>
<version>0.31.0-beta.1</version>
<version>0.31.0-beta.3</version>
</dependency>
```

View File

@@ -0,0 +1,29 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / OAuthFlowType
# Enumeration: OAuthFlowType
OAuth authentication flow types.
## Enumeration Members
### AzureManagedIdentity
```ts
AzureManagedIdentity: "azure_managed_identity";
```
Azure Managed Identity via IMDS.
***
### ClientCredentials
```ts
ClientCredentials: "client_credentials";
```
Client Credentials grant (service-to-service / M2M).

View File

@@ -12,6 +12,7 @@
## Enumerations
- [FullTextQueryType](enumerations/FullTextQueryType.md)
- [OAuthFlowType](enumerations/OAuthFlowType.md)
- [Occur](enumerations/Occur.md)
- [Operator](enumerations/Operator.md)
@@ -85,6 +86,8 @@
- [ListNamespacesResponse](interfaces/ListNamespacesResponse.md)
- [LsmWriteSpec](interfaces/LsmWriteSpec.md)
- [MergeResult](interfaces/MergeResult.md)
- [NativeOAuthConfig](interfaces/NativeOAuthConfig.md)
- [OAuthConfig](interfaces/OAuthConfig.md)
- [OpenTableOptions](interfaces/OpenTableOptions.md)
- [OptimizeOptions](interfaces/OptimizeOptions.md)
- [OptimizeStats](interfaces/OptimizeStats.md)

View File

@@ -64,6 +64,19 @@ client used by manifest-enabled native connections.
***
### oauthConfig?
```ts
optional oauthConfig: NativeOAuthConfig;
```
(For LanceDB cloud only): OAuth configuration for IdP-based
authentication (e.g., Azure Entra ID). When set, token acquisition
and refresh are handled entirely in Rust. TypeScript users should pass
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
***
### readConsistencyInterval?
```ts

View File

@@ -0,0 +1,88 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / NativeOAuthConfig
# Interface: NativeOAuthConfig
OAuth configuration for LanceDB authentication.
This is the generated napi-rs binding shape. TypeScript users should prefer
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
All token acquisition and refresh is handled in the Rust layer.
## Properties
### clientId
```ts
clientId: string;
```
Application / Client ID.
***
### clientSecret?
```ts
optional clientSecret: string;
```
Client secret (required for client_credentials).
***
### flow?
```ts
optional flow: string;
```
Authentication flow: "client_credentials" or "azure_managed_identity"
***
### issuerUrl
```ts
issuerUrl: string;
```
OIDC issuer URL or OAuth authority URL.
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
***
### managedIdentityClientId?
```ts
optional managedIdentityClientId: string;
```
Client ID for user-assigned managed identity (azure_managed_identity).
***
### refreshBufferSecs?
```ts
optional refreshBufferSecs: number;
```
Seconds before expiry to trigger proactive refresh (default: 300).
Keep this well below the token TTL; if it is greater than or equal to
the TTL, each request refreshes the token.
***
### scopes
```ts
scopes: string[];
```
OAuth scopes to request. For Azure managed identity, exactly one scope
or resource is required. For example: `["api://{app_id}/.default"]`

View File

@@ -0,0 +1,111 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / OAuthConfig
# Interface: OAuthConfig
OAuth configuration for LanceDB authentication.
This is the public TypeScript OAuth configuration type. The generated
`NativeOAuthConfig` type has the same runtime shape but is an implementation
detail of the napi-rs binding.
All token acquisition and refresh is handled in the Rust layer.
This config is passed through to Rust via napi-rs.
## Examples
```typescript
const config: OAuthConfig = {
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
clientId: "app-id",
clientSecret: "secret",
scopes: ["api://lancedb-api/.default"],
};
```
```typescript
const config: OAuthConfig = {
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
clientId: "app-id",
scopes: ["api://lancedb-api/.default"],
flow: OAuthFlowType.AzureManagedIdentity,
};
```
## Properties
### clientId
```ts
clientId: string;
```
Application / Client ID.
***
### clientSecret?
```ts
optional clientSecret: string;
```
Client secret (required for ClientCredentials).
***
### flow?
```ts
optional flow: OAuthFlowType;
```
Authentication flow (default: ClientCredentials).
***
### issuerUrl
```ts
issuerUrl: string;
```
OIDC issuer URL or OAuth authority URL.
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
***
### managedIdentityClientId?
```ts
optional managedIdentityClientId: string;
```
Client ID for user-assigned managed identity (AzureManagedIdentity).
***
### refreshBufferSecs?
```ts
optional refreshBufferSecs: number;
```
Seconds before expiry to trigger proactive refresh (default: 300).
Keep this well below the token TTL; if it is greater than or equal to
the TTL, each request refreshes the token.
***
### scopes
```ts
scopes: string[];
```
OAuth scopes to request.
For Azure managed identity, exactly one scope or resource is required.
For example: `["api://{app_id}/.default"]`

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.31.0-beta.1</version>
<version>0.31.0-beta.3</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.31.0-beta.1</version>
<version>0.31.0-beta.3</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>
@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>9.0.0-beta.2</lance-core.version>
<lance-core.version>9.0.0-beta.8</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.31.0-beta.1"
version = "0.31.0-beta.3"
publish = false
license.workspace = true
description.workspace = true

View File

@@ -52,6 +52,7 @@ export {
SplitHashOptions,
SplitSequentialOptions,
ShuffleOptions,
OAuthConfig as NativeOAuthConfig,
} from "./native.js";
export {
@@ -130,6 +131,8 @@ export {
TokenResponse,
} from "./header";
export { OAuthConfig, OAuthFlowType } from "./oauth";
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
export * as embedding from "./embedding";

76
nodejs/lancedb/oauth.ts Normal file
View File

@@ -0,0 +1,76 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
/**
* OAuth authentication flow types.
*/
export enum OAuthFlowType {
/** Client Credentials grant (service-to-service / M2M). */
ClientCredentials = "client_credentials",
/** Azure Managed Identity via IMDS. */
AzureManagedIdentity = "azure_managed_identity",
}
/**
* OAuth configuration for LanceDB authentication.
*
* This is the public TypeScript OAuth configuration type. The generated
* `NativeOAuthConfig` type has the same runtime shape but is an implementation
* detail of the napi-rs binding.
*
* All token acquisition and refresh is handled in the Rust layer.
* This config is passed through to Rust via napi-rs.
*
* @example Client Credentials (service-to-service):
* ```typescript
* const config: OAuthConfig = {
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
* clientId: "app-id",
* clientSecret: "secret",
* scopes: ["api://lancedb-api/.default"],
* };
* ```
*
* @example Azure Managed Identity:
* ```typescript
* const config: OAuthConfig = {
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
* clientId: "app-id",
* scopes: ["api://lancedb-api/.default"],
* flow: OAuthFlowType.AzureManagedIdentity,
* };
* ```
*/
export interface OAuthConfig {
/**
* OIDC issuer URL or OAuth authority URL.
* For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
*/
issuerUrl: string;
/** Application / Client ID. */
clientId: string;
/**
* OAuth scopes to request.
* For Azure managed identity, exactly one scope or resource is required.
* For example: `["api://{app_id}/.default"]`
*/
scopes: string[];
/** Authentication flow (default: ClientCredentials). */
flow?: OAuthFlowType;
/** Client secret (required for ClientCredentials). */
clientSecret?: string;
/** Client ID for user-assigned managed identity (AzureManagedIdentity). */
managedIdentityClientId?: string;
/**
* Seconds before expiry to trigger proactive refresh (default: 300).
* Keep this well below the token TTL; if it is greater than or equal to
* the TTL, each request refreshes the token.
*/
refreshBufferSecs?: number;
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.31.0-beta.1",
"version": "0.31.0-beta.3",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.31.0-beta.1",
"version": "0.31.0-beta.3",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.31.0-beta.1",
"version": "0.31.0-beta.3",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.31.0-beta.1",
"version": "0.31.0-beta.3",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.31.0-beta.1",
"version": "0.31.0-beta.3",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.31.0-beta.1",
"version": "0.31.0-beta.3",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.31.0-beta.1",
"version": "0.31.0-beta.3",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.31.0-beta.1",
"version": "0.31.0-beta.3",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.31.0-beta.1",
"version": "0.31.0-beta.3",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.31.0-beta.1",
"version": "0.31.0-beta.3",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -112,6 +112,12 @@ impl Connection {
builder = builder.client_config(rust_config);
if let Some(oauth_config) = options.oauth_config {
let config: lancedb::remote::oauth::OAuthConfig =
oauth_config.try_into().default_error()?;
builder = builder.oauth_config(config);
}
if let Some(api_key) = options.api_key {
builder = builder.api_key(&api_key);
}

View File

@@ -65,6 +65,11 @@ pub struct ConnectionOptions {
/// (For LanceDB cloud only): the host to use for LanceDB cloud. Used
/// for testing purposes.
pub host_override: Option<String>,
/// (For LanceDB cloud only): OAuth configuration for IdP-based
/// authentication (e.g., Azure Entra ID). When set, token acquisition
/// and refresh are handled entirely in Rust. TypeScript users should pass
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
pub oauth_config: Option<remote::OAuthConfig>,
}
#[napi(object)]

View File

@@ -3,6 +3,7 @@
use std::collections::HashMap;
use lancedb::error::Error;
use napi_derive::*;
/// Timeout configuration for remote HTTP client.
@@ -140,6 +141,84 @@ impl From<TlsConfig> for lancedb::remote::TlsConfig {
}
}
/// OAuth configuration for LanceDB authentication.
///
/// This is the generated napi-rs binding shape. TypeScript users should prefer
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
///
/// All token acquisition and refresh is handled in the Rust layer.
#[napi(object)]
#[derive(Clone)]
pub struct OAuthConfig {
/// OIDC issuer URL or OAuth authority URL.
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
pub issuer_url: String,
/// Application / Client ID.
pub client_id: String,
/// OAuth scopes to request. For Azure managed identity, exactly one scope
/// or resource is required. For example: `["api://{app_id}/.default"]`
pub scopes: Vec<String>,
/// Authentication flow: "client_credentials" or "azure_managed_identity"
pub flow: Option<String>,
/// Client secret (required for client_credentials).
pub client_secret: Option<String>,
/// Client ID for user-assigned managed identity (azure_managed_identity).
pub managed_identity_client_id: Option<String>,
/// Seconds before expiry to trigger proactive refresh (default: 300).
/// Keep this well below the token TTL; if it is greater than or equal to
/// the TTL, each request refreshes the token.
pub refresh_buffer_secs: Option<u32>,
}
impl std::fmt::Debug for OAuthConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OAuthConfig")
.field("issuer_url", &self.issuer_url)
.field("client_id", &self.client_id)
.field("scopes", &self.scopes)
.field("flow", &self.flow)
.field(
"client_secret",
&self.client_secret.as_deref().map(|_| "<redacted>"),
)
.field(
"managed_identity_client_id",
&self.managed_identity_client_id,
)
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
.finish()
}
}
impl TryFrom<OAuthConfig> for lancedb::remote::oauth::OAuthConfig {
type Error = Error;
fn try_from(config: OAuthConfig) -> Result<Self, Self::Error> {
use lancedb::remote::oauth::OAuthFlow;
let flow = match config.flow.as_deref().unwrap_or("client_credentials") {
"client_credentials" => OAuthFlow::ClientCredentials,
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
client_id: config.managed_identity_client_id,
},
other => {
return Err(Error::InvalidInput {
message: format!("Unknown OAuth flow type: {other}"),
});
}
};
Ok(Self {
issuer_url: config.issuer_url,
client_id: config.client_id,
client_secret: config.client_secret,
scopes: config.scopes,
flow,
refresh_buffer_secs: config.refresh_buffer_secs.map(|v| v as u64),
})
}
}
impl From<ClientConfig> for lancedb::remote::ClientConfig {
fn from(config: ClientConfig) -> Self {
Self {
@@ -156,3 +235,45 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_unknown_oauth_flow_returns_invalid_input() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
scopes: vec!["scope".to_string()],
flow: Some("typo".to_string()),
client_secret: None,
managed_identity_client_id: None,
refresh_buffer_secs: None,
};
let err = lancedb::remote::oauth::OAuthConfig::try_from(config).unwrap_err();
assert!(matches!(
err,
Error::InvalidInput { message }
if message == "Unknown OAuth flow type: typo"
));
}
#[test]
fn test_oauth_config_debug_redacts_client_secret() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
scopes: vec!["scope".to_string()],
flow: Some("client_credentials".to_string()),
client_secret: Some("super-secret".to_string()),
managed_identity_client_id: None,
refresh_buffer_secs: None,
};
let debug = format!("{config:?}");
assert!(!debug.contains("super-secret"));
assert!(debug.contains("client_secret: Some(\"<redacted>\")"));
}
}

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.34.0-beta.2"
current_version = "0.34.0-beta.3"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.34.0-beta.2"
version = "0.34.0-beta.3"
publish = false
edition.workspace = true
description = "Python bindings for LanceDB"

View File

@@ -81,6 +81,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
warnings.warn(
"use_token_pooling is deprecated, use pooling_strategy=None instead",
DeprecationWarning,
stacklevel=2,
)
self.pooling_strategy = None

View File

@@ -373,6 +373,19 @@ def _convert_pyarrow_schema_to_json(schema: pa.Schema) -> JsonArrowSchema:
return JsonArrowSchema(fields=fields, metadata=meta)
def _builds_namespace_natively(
namespace_client_impl: Optional[str],
namespace_client_properties: Optional[Dict[str, str]],
) -> bool:
"""Whether ``connect_namespace_client`` builds the namespace client natively
in Rust (installing the read-freshness context provider) rather than wrapping
the pre-built Python client.
Must mirror Rust ``build_namespace_natively`` in ``python/src/connection.rs``.
"""
return namespace_client_impl == "rest" and bool(namespace_client_properties)
class LanceNamespaceDBConnection(DBConnection):
"""
A LanceDB connection that uses a namespace for table management.
@@ -432,6 +445,13 @@ class LanceNamespaceDBConnection(DBConnection):
)
self._namespace_client_impl = namespace_client_impl
self._namespace_client_properties = namespace_client_properties
# When the namespace client is built natively (see Rust
# ``build_namespace_natively``), the underlying Rust table performs
# QueryTable pushdown through the read-freshness context provider, which
# the pure-Python ``query_table`` path bypasses.
self._route_pushdown_to_rust = _builds_namespace_natively(
namespace_client_impl, namespace_client_properties
)
self._inner = AsyncConnection(
_connect_namespace_client(
namespace_client,
@@ -543,6 +563,7 @@ class LanceNamespaceDBConnection(DBConnection):
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
_async=async_table,
)
@@ -580,6 +601,7 @@ class LanceNamespaceDBConnection(DBConnection):
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
_async=async_table,
)
if branch is not None:
@@ -875,6 +897,8 @@ class AsyncLanceNamespaceDBConnection:
storage_options: Optional[Dict[str, str]] = None,
session: Optional[Session] = None,
namespace_client_pushdown_operations: Optional[List[str]] = None,
namespace_client_impl: Optional[str] = None,
namespace_client_properties: Optional[Dict[str, str]] = None,
):
"""
Initialize an async namespace-based LanceDB connection.
@@ -900,6 +924,12 @@ class AsyncLanceNamespaceDBConnection:
namespace.create_table() instead of using declare_table + local write.
Default is None (no pushdown, all operations run locally).
namespace_client_impl : Optional[str]
The namespace implementation name used to create this connection.
Required (with ``namespace_client_properties``) for the Rust client to
be built natively and install the read-freshness provider.
namespace_client_properties : Optional[Dict[str, str]]
The namespace properties used to create this connection.
"""
self._namespace_client = namespace_client
self.read_consistency_interval = read_consistency_interval
@@ -908,6 +938,14 @@ class AsyncLanceNamespaceDBConnection:
self._namespace_client_pushdown_operations = set(
namespace_client_pushdown_operations or []
)
self._namespace_client_impl = namespace_client_impl
self._namespace_client_properties = namespace_client_properties
# See LanceNamespaceDBConnection: when built natively the Rust table runs
# QueryTable pushdown through the read-freshness provider, so defer to it
# rather than the urllib3 client (which omits x-lancedb-min-timestamp).
self._route_pushdown_to_rust = _builds_namespace_natively(
namespace_client_impl, namespace_client_properties
)
self._inner = AsyncConnection(
_connect_namespace_client(
namespace_client,
@@ -921,8 +959,8 @@ class AsyncLanceNamespaceDBConnection:
namespace_client_pushdown_operations=(
list(self._namespace_client_pushdown_operations)
),
namespace_client_impl=None,
namespace_client_properties=None,
namespace_client_impl=namespace_client_impl,
namespace_client_properties=namespace_client_properties,
)
)
@@ -992,6 +1030,7 @@ class AsyncLanceNamespaceDBConnection:
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
)
async def open_table(
@@ -1029,6 +1068,7 @@ class AsyncLanceNamespaceDBConnection:
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
)
async def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
@@ -1387,4 +1427,6 @@ def connect_namespace_async(
storage_options=storage_options,
session=session,
namespace_client_pushdown_operations=namespace_client_pushdown_operations,
namespace_client_impl=namespace_client_impl,
namespace_client_properties=namespace_client_properties,
)

View File

@@ -124,6 +124,7 @@ class RemoteDBConnection(DBConnection):
"request_thread_pool is no longer used and will be removed in "
"a future release.",
DeprecationWarning,
stacklevel=2,
)
if connection_timeout is not None:
@@ -132,6 +133,7 @@ class RemoteDBConnection(DBConnection):
"release. Please use client_config.timeout_config.connect_timeout "
"instead.",
DeprecationWarning,
stacklevel=2,
)
client_config.timeout_config.connect_timeout = timedelta(
seconds=connection_timeout
@@ -142,6 +144,7 @@ class RemoteDBConnection(DBConnection):
"read_timeout is deprecated and will be removed in a future release. "
"Please use client_config.timeout_config.read_timeout instead.",
DeprecationWarning,
stacklevel=2,
)
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)

View File

@@ -845,7 +845,8 @@ class RemoteTable(Table):
"""
warnings.warn(
"cleanup_old_versions() is a no-op on LanceDB Cloud. "
"Tables are automatically cleaned up and optimized."
"Tables are automatically cleaned up and optimized.",
stacklevel=2,
)
pass
@@ -857,7 +858,8 @@ class RemoteTable(Table):
"""
warnings.warn(
"compact_files() is a no-op on LanceDB Cloud. "
"Tables are automatically compacted and optimized."
"Tables are automatically compacted and optimized.",
stacklevel=2,
)
pass
@@ -874,7 +876,8 @@ class RemoteTable(Table):
"""
warnings.warn(
"optimize() is a no-op on LanceDB Cloud. "
"Indices are optimized automatically."
"Indices are optimized automatically.",
stacklevel=2,
)
pass

View File

@@ -2022,6 +2022,7 @@ class LanceTable(Table):
namespace_client: Optional[Any] = None,
managed_versioning: Optional[bool] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
_async: AsyncTable = None,
):
if namespace_path is None:
@@ -2031,6 +2032,14 @@ class LanceTable(Table):
self._location = location # Store location for use in _dataset_path
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
# When the connection built the namespace client natively (e.g. an
# enterprise "rest" connection), the underlying Rust table already
# executes QueryTable pushdown itself -- and, unlike this Python urllib3
# path, it routes through the read-freshness context provider that emits
# the ``x-lancedb-min-timestamp`` header. So we must defer pushdown to
# Rust instead of calling the Python ``namespace_client.query_table``
# directly, or reads silently bypass read-freshness (stale results).
self._route_pushdown_to_rust = route_pushdown_to_rust
if _async is not None:
self._table = _async
else:
@@ -2241,6 +2250,7 @@ class LanceTable(Table):
namespace_path=self._namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
location=self._location,
_async=async_table,
)
@@ -2391,8 +2401,11 @@ class LanceTable(Table):
Returns
-------
pa.Table"""
if _should_push_down_query_table(
self._namespace_client, self._pushdown_operations
if (
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
):
return self._execute_query(Query()).read_all()
@@ -3344,6 +3357,7 @@ class LanceTable(Table):
location: Optional[str] = None,
namespace_client: Optional[Any] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
):
"""
Create a new table.
@@ -3406,21 +3420,24 @@ class LanceTable(Table):
self._location = location
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
self._route_pushdown_to_rust = route_pushdown_to_rust
if data_storage_version is not None:
warnings.warn(
"setting data_storage_version directly on create_table is deprecated. ",
"setting data_storage_version directly on create_table is deprecated. "
"Use database_options instead.",
DeprecationWarning,
stacklevel=2,
)
if storage_options is None:
storage_options = {}
storage_options["new_table_data_storage_version"] = data_storage_version
if enable_v2_manifest_paths is not None:
warnings.warn(
"setting enable_v2_manifest_paths directly on create_table is ",
"setting enable_v2_manifest_paths directly on create_table is "
"deprecated. Use database_options instead.",
DeprecationWarning,
stacklevel=2,
)
if storage_options is None:
storage_options = {}
@@ -3517,6 +3534,7 @@ class LanceTable(Table):
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
and self.current_branch() is None
):
from lancedb.namespace import _execute_server_side_query
@@ -4258,6 +4276,7 @@ class AsyncTable:
namespace_path: Optional[List[str]] = None,
namespace_client: Optional[Any] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
):
"""Create a new AsyncTable object.
@@ -4270,6 +4289,9 @@ class AsyncTable:
self._namespace_path = namespace_path or []
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
# See LanceTable.__init__: defer QueryTable pushdown to Rust (which emits
# the read-freshness header) for natively-built namespace clients.
self._route_pushdown_to_rust = route_pushdown_to_rust
def _set_namespace_context(
self,
@@ -4277,10 +4299,12 @@ class AsyncTable:
namespace_path: Optional[List[str]] = None,
namespace_client: Optional[Any] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
) -> "AsyncTable":
self._namespace_path = namespace_path or []
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
self._route_pushdown_to_rust = route_pushdown_to_rust
return self
def __repr__(self):
@@ -4490,8 +4514,11 @@ class AsyncTable:
-------
pa.Table
"""
if _should_push_down_query_table(
self._namespace_client, self._pushdown_operations
if (
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
):
return (await self._execute_query(Query())).read_all()
@@ -5175,8 +5202,11 @@ class AsyncTable:
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
if _should_push_down_query_table(
self._namespace_client, self._pushdown_operations
if (
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
):
from lancedb.namespace import _execute_server_side_query
@@ -5662,6 +5692,7 @@ class AsyncTable:
"The 'retrain' parameter is deprecated and will be removed in a "
"future version.",
DeprecationWarning,
stacklevel=2,
)
return await self._inner.optimize(

View File

@@ -65,6 +65,9 @@ def _namespace_lance_table(namespace_client: _NamespaceClient) -> LanceTable:
table._namespace_path = ["geneva"]
table._namespace_client = namespace_client
table._pushdown_operations = {"QueryTable"}
# This test exercises the Python-side pushdown path (non-native client), so
# pushdown is not routed to Rust.
table._route_pushdown_to_rust = False
return table
@@ -805,6 +808,37 @@ class TestPushdownOperations:
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
assert len(db._namespace_client_pushdown_operations) == 0
def test_route_pushdown_to_rust_for_native_rest(self):
"""A natively-built rest connection must defer QueryTable pushdown to
Rust so reads carry the x-lancedb-min-timestamp read-freshness header."""
db = lancedb.connect_namespace(
"rest",
{"uri": "http://localhost:12345"},
namespace_client_pushdown_operations=["QueryTable"],
)
assert db._route_pushdown_to_rust is True
def test_route_pushdown_to_rust_false_for_dir(self):
"""A non-native (dir) connection keeps the Python pushdown path."""
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
assert db._route_pushdown_to_rust is False
def test_async_route_pushdown_to_rust_for_native_rest(self):
"""The async connection must not silently bypass the read-freshness fix:
a natively-built rest connection defers pushdown to Rust (regression test
for the async path omitting the freshness header)."""
db = lancedb.connect_namespace_async(
"rest",
{"uri": "http://localhost:12345"},
namespace_client_pushdown_operations=["QueryTable"],
)
assert db._route_pushdown_to_rust is True
def test_async_route_pushdown_to_rust_false_for_dir(self):
"""The async non-native (dir) connection keeps the Python pushdown path."""
db = lancedb.connect_namespace_async("dir", {"root": self.temp_dir})
assert db._route_pushdown_to_rust is False
def test_lance_table_to_arrow_uses_query_pushdown(self):
namespace_client = _NamespaceClient()
table = _namespace_lance_table(namespace_client)

View File

@@ -610,24 +610,38 @@ pub fn connect_namespace_client(
namespace_client_impl: Option<String>,
namespace_client_properties: Option<HashMap<String, String>>,
) -> PyResult<Connection> {
let namespace_client = extract_namespace_arc(py, namespace_client)?;
let read_consistency_interval = read_consistency_interval.map(Duration::from_secs_f64);
let namespace_client_pushdown_operations =
parse_namespace_client_pushdown_operations(namespace_client_pushdown_operations)?;
let ns_impl = namespace_client_impl.unwrap_or_else(|| "python".to_string());
let ns_properties = namespace_client_properties.unwrap_or_default();
let storage_options = storage_options.unwrap_or_default();
let session = session.map(|s| s.inner.clone());
let database = LanceNamespaceDatabase::from_namespace_client(
namespace_client,
ns_impl,
ns_properties,
storage_options,
read_consistency_interval,
session,
namespace_client_pushdown_operations,
);
// Prefer building the namespace natively from (impl, properties) so the
// read-freshness provider installed
let database = if build_namespace_natively(namespace_client_impl.as_deref(), &ns_properties) {
let ns_impl = namespace_client_impl.expect("impl present per build_namespace_natively");
crate::runtime::block_on(LanceNamespaceDatabase::connect(
&ns_impl,
ns_properties,
storage_options,
read_consistency_interval,
session,
namespace_client_pushdown_operations,
))
.infer_error()?
} else {
let namespace_client = extract_namespace_arc(py, namespace_client)?;
LanceNamespaceDatabase::from_namespace_client(
namespace_client,
namespace_client_impl.unwrap_or_else(|| "python".to_string()),
ns_properties,
storage_options,
read_consistency_interval,
session,
namespace_client_pushdown_operations,
)
};
Ok(Connection::new(LanceConnection::new(
Arc::new(database),
@@ -635,6 +649,16 @@ pub fn connect_namespace_client(
)))
}
/// Whether to build the namespace natively (from impl + properties) instead of
/// wrapping a pre-built client. Native construction is required for the
/// read-freshness provider to be installed
fn build_namespace_natively(
namespace_client_impl: Option<&str>,
namespace_client_properties: &HashMap<String, String>,
) -> bool {
matches!(namespace_client_impl, Some("rest")) && !namespace_client_properties.is_empty()
}
#[derive(FromPyObject)]
pub struct PyClientConfig {
user_agent: String,
@@ -733,3 +757,36 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn props(pairs: &[(&str, &str)]) -> HashMap<String, String> {
pairs
.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect()
}
#[test]
fn native_build_only_for_rest_with_properties() {
let rest = props(&[("uri", "http://localhost:10024")]);
// rest + non-empty properties -> build natively (installs the
// read-freshness provider so checkout_latest() busts the server cache).
assert!(build_namespace_natively(Some("rest"), &rest));
// dir is local (no server cache) -> wrap the pre-built client unchanged.
assert!(!build_namespace_natively(
Some("dir"),
&props(&[("root", "/tmp")])
));
// No impl: only a pre-built client was handed in -> wrap it as-is.
assert!(!build_namespace_natively(None, &rest));
// rest but no properties: nothing to build a connection from -> wrap.
assert!(!build_namespace_natively(Some("rest"), &HashMap::new()));
}
}

View File

@@ -56,6 +56,15 @@ fn get_runtime() -> &'static runtime::Runtime {
unsafe { &*new_ptr }
}
/// Block the current thread on a future using the shared runtime.
///
/// For sync `#[pyfunction]`s that need to drive an async operation (e.g.
/// building a namespace client). Must not be called from within the runtime's
/// own worker threads.
pub fn block_on<F: std::future::Future>(fut: F) -> F::Output {
get_runtime().block_on(fut)
}
/// Runs in async-signal context after `fork()` in the child. We can only
/// touch atomics here; we deliberately leak the previous runtime because
/// dropping a tokio `Runtime` would try to join its (now-dead) worker

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.31.0-beta.1"
version = "0.31.0-beta.3"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -50,7 +50,7 @@ lance-namespace = { workspace = true }
lance-namespace-impls = { workspace = true }
moka = { workspace = true }
pin-project = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
tokio = { version = "1.23", features = ["rt-multi-thread", "sync"] }
log.workspace = true
async-trait = "0"
bytes = "1"
@@ -75,6 +75,7 @@ reqwest = { version = "0.12.0", default-features = false, features = [
"stream",
], optional = true }
http = { version = "1", optional = true } # Matching what is in reqwest
urlencoding = { version = "2", optional = true }
uuid = { version = "1.7.0", features = ["v4", "v5"] }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
@@ -93,6 +94,7 @@ semver = { workspace = true }
anyhow = "1"
tempfile = "3.5.0"
random_word = { version = "0.4.3", features = ["en"] }
tokio = { version = "1.23", features = ["io-util", "macros", "net", "rt-multi-thread", "sync"] }
uuid = { version = "1.7.0", features = ["v4"] }
walkdir = "2"
aws-sdk-dynamodb = { version = "1.55.0" }
@@ -129,7 +131,13 @@ huggingface = [
"lance-namespace-impls/dir-huggingface",
]
dynamodb = ["lance/dynamodb", "aws"]
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
remote = [
"dep:reqwest",
"dep:http",
"dep:urlencoding",
"lance-namespace-impls/rest",
"lance-namespace-impls/rest-adapter",
]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
bedrock = ["dep:aws-sdk-bedrockruntime"]

View File

@@ -576,6 +576,9 @@ impl Connection {
/// For LanceNamespaceDatabase, it is the underlying LanceNamespace.
/// For ListingDatabase, it is the equivalent DirectoryNamespace.
/// For RemoteDatabase, it is the equivalent RestNamespace.
///
/// Remote connections using dynamic headers forward them through the
/// namespace client's per-request context provider.
pub async fn namespace_client(&self) -> Result<Arc<dyn lance_namespace::LanceNamespace>> {
self.internal.namespace_client().await
}
@@ -584,6 +587,9 @@ impl Connection {
/// Returns (impl_type, properties) where:
/// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace
/// - properties: configuration properties for the namespace
///
/// Remote connections using dynamic headers cannot be exported because the
/// namespace client config only carries static headers.
pub async fn namespace_client_config(
&self,
) -> Result<(String, std::collections::HashMap<String, String>)> {
@@ -661,6 +667,8 @@ pub struct ConnectRequest {
pub struct ConnectBuilder {
request: ConnectRequest,
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
#[cfg(feature = "remote")]
oauth_config: Option<crate::remote::OAuthConfig>,
}
#[cfg(feature = "remote")]
@@ -682,6 +690,8 @@ impl ConnectBuilder {
session: None,
},
embedding_registry: None,
#[cfg(feature = "remote")]
oauth_config: None,
}
}
@@ -770,6 +780,19 @@ impl ConnectBuilder {
self
}
/// Configure OAuth authentication for LanceDB Cloud/Enterprise.
///
/// This creates an [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider)
/// from the given config and sets it as the header provider. OAuth cannot
/// be combined with an API key or another header provider.
///
/// Token acquisition and refresh are handled in Rust.
#[cfg(feature = "remote")]
pub fn oauth_config(mut self, config: crate::remote::OAuthConfig) -> Self {
self.oauth_config = Some(config);
self
}
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
self.embedding_registry = Some(registry);
@@ -915,9 +938,40 @@ impl ConnectBuilder {
let region = options.region.ok_or_else(|| Error::InvalidInput {
message: "A region is required when connecting to LanceDb Cloud".to_string(),
})?;
let api_key = options.api_key.ok_or_else(|| Error::InvalidInput {
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
})?;
let api_key = match (&self.oauth_config, &options.api_key) {
(Some(_), None) => String::new(),
(Some(_), Some(_)) => {
return Err(Error::InvalidInput {
message:
"api_key and oauth_config cannot both be set when connecting to LanceDb Cloud"
.to_string(),
});
}
(None, Some(key)) => key.clone(),
(None, None) => {
return Err(Error::InvalidInput {
message:
"An api_key or oauth_config is required when connecting to LanceDb Cloud"
.to_string(),
});
}
};
if self.oauth_config.is_some() && self.request.client_config.header_provider.is_some() {
return Err(Error::InvalidInput {
message:
"oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud"
.to_string(),
});
}
let mut client_config = self.request.client_config;
if let Some(oauth_config) = self.oauth_config {
let provider = crate::remote::OAuthHeaderProvider::new(oauth_config)?;
client_config.header_provider =
Some(Arc::new(provider) as Arc<dyn crate::remote::HeaderProvider>);
}
let storage_options = StorageOptions(options.storage_options.clone());
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
@@ -925,7 +979,7 @@ impl ConnectBuilder {
&api_key,
&region,
options.host_override,
self.request.client_config,
client_config,
storage_options.into(),
self.request.read_consistency_interval,
)?);
@@ -1234,6 +1288,83 @@ mod tests {
assert_eq!(Some(&"EXPLICIT-VALUE".to_string()), options.get(opts_key));
}
#[cfg(feature = "remote")]
#[tokio::test]
async fn test_connect_rejects_api_key_with_oauth_config() {
let oauth_config = crate::remote::OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: crate::remote::OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let result = ConnectBuilder::new("db://my-container/my-prefix")
.region("us-east-1")
.api_key("my-api-key")
.oauth_config(oauth_config)
.execute()
.await;
match result {
Err(Error::InvalidInput { message })
if message
== "api_key and oauth_config cannot both be set when connecting to LanceDb Cloud" =>
{}
Err(err) => panic!("expected InvalidInput, got {err:?}"),
Ok(_) => panic!("expected api_key and oauth_config to be rejected"),
}
}
#[cfg(feature = "remote")]
#[tokio::test]
async fn test_connect_rejects_header_provider_with_oauth_config() {
#[derive(Debug)]
struct TestHeaderProvider;
#[async_trait::async_trait]
impl crate::remote::HeaderProvider for TestHeaderProvider {
async fn get_headers(&self) -> Result<HashMap<String, String>> {
Ok(HashMap::from([(
"authorization".to_string(),
"Bearer token".to_string(),
)]))
}
}
let oauth_config = crate::remote::OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: crate::remote::OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let client_config = crate::remote::ClientConfig {
header_provider: Some(
Arc::new(TestHeaderProvider) as Arc<dyn crate::remote::HeaderProvider>
),
..Default::default()
};
let result = ConnectBuilder::new("db://my-container/my-prefix")
.region("us-east-1")
.client_config(client_config)
.oauth_config(oauth_config)
.execute()
.await;
match result {
Err(Error::InvalidInput { message })
if message
== "oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud" =>
{}
Err(err) => panic!("expected InvalidInput, got {err:?}"),
Ok(_) => panic!("expected header_provider and oauth_config to be rejected"),
}
}
#[cfg(not(windows))]
#[tokio::test]
async fn test_connect_relative() {

View File

@@ -8,6 +8,7 @@
pub(crate) mod client;
pub(crate) mod db;
pub mod oauth;
mod retry;
pub(crate) mod table;
pub(crate) mod util;
@@ -20,3 +21,4 @@ const JSON_CONTENT_TYPE: &str = "application/json";
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
pub use oauth::{OAuthConfig, OAuthFlow, OAuthHeaderProvider};

View File

@@ -459,12 +459,14 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
config: &ClientConfig,
) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_static("x-api-key"),
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
message: "non-ascii api key provided".to_string(),
})?,
);
if !api_key.is_empty() {
headers.insert(
HeaderName::from_static("x-api-key"),
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
message: "non-ascii api key provided".to_string(),
})?,
);
}
if region == "local" {
let host = format!("{}.local.api.lancedb.com", db_name);
headers.insert(
@@ -1005,6 +1007,33 @@ mod tests {
assert!(!config_tls.assert_hostname);
}
#[test]
fn test_default_headers_skip_empty_api_key() {
let headers = RestfulLanceDbClient::<Sender>::default_headers(
"",
"us-east-1",
"db-name",
false,
&RemoteOptions::default(),
None,
&ClientConfig::default(),
)
.unwrap();
assert!(!headers.contains_key("x-api-key"));
let headers = RestfulLanceDbClient::<Sender>::default_headers(
"api-key",
"us-east-1",
"db-name",
false,
&RemoteOptions::default(),
None,
&ClientConfig::default(),
)
.unwrap();
assert_eq!(headers.get("x-api-key").unwrap(), "api-key");
}
// Test implementation of HeaderProvider
#[derive(Debug, Clone)]
struct TestHeaderProvider {

View File

@@ -7,6 +7,7 @@ use std::sync::Arc;
use async_trait::async_trait;
use http::StatusCode;
use lance_io::object_store::StorageOptions;
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
use moka::future::Cache;
use reqwest::header::CONTENT_TYPE;
@@ -26,7 +27,9 @@ use crate::remote::util::stream_as_body;
use crate::table::BaseTable;
use super::ARROW_STREAM_CONTENT_TYPE;
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
use super::client::{
ClientConfig, HeaderProvider, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender,
};
use super::table::RemoteTable;
use super::util::parse_server_version;
@@ -194,10 +197,66 @@ pub struct RemoteDatabase<S: HttpSend = Sender> {
uri: String,
/// Headers to pass to the namespace client for authentication
namespace_headers: HashMap<String, String>,
namespace_context_provider: Option<Arc<dyn DynamicContextProvider>>,
/// TLS configuration for mTLS support
tls_config: Option<super::client::TlsConfig>,
}
#[derive(Clone)]
struct NamespaceHeaderProviderContext {
header_provider: Arc<dyn HeaderProvider>,
}
impl std::fmt::Debug for NamespaceHeaderProviderContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NamespaceHeaderProviderContext")
.field("header_provider", &"Some(...)")
.finish()
}
}
impl DynamicContextProvider for NamespaceHeaderProviderContext {
fn provide_context(&self, _info: &OperationInfo) -> HashMap<String, String> {
let header_provider = Arc::clone(&self.header_provider);
let handle = match std::thread::Builder::new()
.name("lancedb-namespace-headers".to_string())
.spawn(move || {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| Error::Runtime {
message: format!(
"Failed to create runtime for namespace header provider: {e}"
),
})?
.block_on(header_provider.get_headers())
}) {
Ok(handle) => handle,
Err(err) => {
log::warn!("Failed to spawn dynamic namespace header provider thread: {err}");
return HashMap::new();
}
};
let headers = handle.join();
match headers {
Ok(Ok(headers)) => headers
.into_iter()
.map(|(key, value)| (format!("headers.{key}"), value))
.collect(),
Ok(Err(err)) => {
log::warn!("Failed to get dynamic namespace headers: {err}");
HashMap::new()
}
Err(_) => {
log::warn!("Dynamic namespace header provider panicked");
HashMap::new()
}
}
}
}
impl RemoteDatabase {
pub fn try_new(
uri: &str,
@@ -228,6 +287,16 @@ impl RemoteDatabase {
})
.collect();
let namespace_context_provider =
client_config
.header_provider
.as_ref()
.map(|header_provider| {
Arc::new(NamespaceHeaderProviderContext {
header_provider: Arc::clone(header_provider),
}) as Arc<dyn DynamicContextProvider>
});
let client = RestfulLanceDbClient::try_new(
&parsed,
region,
@@ -247,6 +316,7 @@ impl RemoteDatabase {
table_cache,
uri: uri.to_owned(),
namespace_headers,
namespace_context_provider,
tls_config: client_config.tls_config,
})
}
@@ -271,6 +341,7 @@ mod test_utils {
table_cache: Cache::new(0),
uri: "http://localhost".to_string(),
namespace_headers: HashMap::new(),
namespace_context_provider: None,
tls_config: None,
}
}
@@ -281,11 +352,18 @@ mod test_utils {
T: Into<reqwest::Body>,
{
let client = client_with_handler_and_config(handler, config.clone());
let namespace_context_provider =
config.header_provider.as_ref().map(|header_provider| {
Arc::new(NamespaceHeaderProviderContext {
header_provider: Arc::clone(header_provider),
}) as Arc<dyn DynamicContextProvider>
});
Self {
client,
table_cache: Cache::new(0),
uri: "http://localhost".to_string(),
namespace_headers: config.extra_headers.clone(),
namespace_context_provider,
tls_config: config.tls_config.clone(),
}
}
@@ -759,9 +837,12 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
// Create a RestNamespace pointing to the same remote host with the same authentication headers
let mut builder = lance_namespace_impls::RestNamespaceBuilder::new(self.client.host())
.delimiter(&self.client.id_delimiter)
// TODO: support header provider
.headers(self.namespace_headers.clone());
if let Some(context_provider) = &self.namespace_context_provider {
builder = builder.context_provider(Arc::clone(context_provider));
}
// Apply mTLS configuration if present
if let Some(tls_config) = &self.tls_config {
if let Some(cert_file) = &tls_config.cert_file {
@@ -781,6 +862,14 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
}
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)> {
if self.namespace_context_provider.is_some() {
return Err(Error::NotSupported {
message:
"Cannot export a namespace client config when dynamic headers are configured; use LanceDB connection namespace methods instead"
.to_string(),
});
}
let mut properties = HashMap::new();
properties.insert("uri".to_string(), self.client.host().to_string());
properties.insert("delimiter".to_string(), self.client.id_delimiter.clone());
@@ -832,12 +921,13 @@ impl From<StorageOptions> for RemoteOptions {
#[cfg(test)]
mod tests {
use super::build_cache_key;
use super::{NamespaceHeaderProviderContext, build_cache_key};
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
use crate::connection::ConnectBuilder;
use crate::{
@@ -1702,6 +1792,75 @@ mod tests {
assert!(namespace_client.is_ok());
}
#[test]
fn test_namespace_header_provider_context_maps_headers() {
#[derive(Debug)]
struct TestHeaderProvider;
#[async_trait::async_trait]
impl HeaderProvider for TestHeaderProvider {
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
Ok(HashMap::from([(
"authorization".to_string(),
"Bearer token".to_string(),
)]))
}
}
let context_provider = NamespaceHeaderProviderContext {
header_provider: Arc::new(TestHeaderProvider) as Arc<dyn HeaderProvider>,
};
let context =
context_provider.provide_context(&OperationInfo::new("list_tables", "namespace"));
assert_eq!(
context.get("headers.authorization"),
Some(&"Bearer token".to_string())
);
}
#[tokio::test]
async fn test_namespace_client_supports_dynamic_headers() {
#[derive(Debug)]
struct TestHeaderProvider;
#[async_trait::async_trait]
impl HeaderProvider for TestHeaderProvider {
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
Ok(HashMap::from([(
"authorization".to_string(),
"Bearer token".to_string(),
)]))
}
}
let client_config = ClientConfig {
header_provider: Some(Arc::new(TestHeaderProvider) as Arc<dyn HeaderProvider>),
..Default::default()
};
let conn = Connection::new_with_handler_and_config(
|_| {
http::Response::builder()
.status(200)
.body(r#"{"tables": []}"#)
.unwrap()
},
client_config,
);
let namespace_client = conn.namespace_client().await;
assert!(namespace_client.is_ok());
match conn.namespace_client_config().await {
Err(Error::NotSupported { message })
if message.contains("dynamic headers are configured") => {}
Err(err) => panic!("expected NotSupported, got {err:?}"),
Ok(_) => panic!("expected namespace_client_config to reject dynamic headers"),
}
}
/// Integration tests using RestAdapter to run RemoteDatabase against a real namespace server
mod rest_adapter_integration {
use super::*;

View File

@@ -0,0 +1,907 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use log::debug;
use reqwest::Client;
use serde::Deserialize;
use tokio::sync::RwLock;
use crate::error::{Error, Result};
use crate::remote::client::HeaderProvider;
const DEFAULT_REFRESH_BUFFER_SECS: u64 = 300;
const DEFAULT_TOKEN_TTL_SECS: u64 = 3600;
const AZURE_IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token";
const AZURE_IMDS_API_VERSION: &str = "2018-02-01";
/// OAuth authentication flow configuration.
#[derive(Debug, Clone)]
pub enum OAuthFlow {
/// Client Credentials grant (service-to-service / M2M).
/// Requires `client_secret` in [`OAuthConfig`].
ClientCredentials,
/// Azure Managed Identity via IMDS.
/// Works on Azure VMs, AKS, App Service, and Azure Functions.
/// IMDS requests bypass proxy settings because the endpoint is link-local.
AzureManagedIdentity {
/// Client ID for user-assigned managed identity.
/// Omit for system-assigned managed identity.
client_id: Option<String>,
},
}
/// OAuth configuration for LanceDB authentication.
///
/// All token acquisition and refresh is handled in the Rust layer.
/// Python and TypeScript bindings expose this as a plain config object.
#[derive(Clone)]
pub struct OAuthConfig {
/// OIDC issuer URL or OAuth authority URL.
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
pub issuer_url: String,
/// Application / Client ID.
pub client_id: String,
/// Client secret (required for `ClientCredentials`, optional for others).
pub client_secret: Option<String>,
/// OAuth scopes to request.
/// For Azure managed identity, exactly one scope or resource is required.
/// For example: `["api://{app_id}/.default"]`
pub scopes: Vec<String>,
/// Authentication flow to use.
pub flow: OAuthFlow,
/// Seconds before token expiry to trigger proactive refresh (default: 300).
/// Keep this well below the token TTL; if it is greater than or equal to
/// the TTL, each request refreshes the token.
pub refresh_buffer_secs: Option<u64>,
}
impl std::fmt::Debug for OAuthConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OAuthConfig")
.field("issuer_url", &self.issuer_url)
.field("client_id", &self.client_id)
.field(
"client_secret",
&self.client_secret.as_deref().map(|_| "<redacted>"),
)
.field("scopes", &self.scopes)
.field("flow", &self.flow)
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
.finish()
}
}
// -- OIDC Discovery --
#[derive(Clone, Debug, Deserialize)]
struct OidcDiscovery {
token_endpoint: String,
}
// -- Token Response --
#[derive(Deserialize)]
struct TokenResponse {
access_token: String,
/// Token lifetime in seconds.
/// Some providers (Azure IMDS) return this as a string, so we accept both.
#[serde(default, deserialize_with = "deserialize_optional_u64_or_string")]
expires_in: Option<u64>,
#[serde(default)]
#[allow(dead_code)]
token_type: Option<String>,
}
impl std::fmt::Debug for TokenResponse {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TokenResponse")
.field("access_token", &"<redacted>")
.field("expires_in", &self.expires_in)
.field("token_type", &self.token_type)
.finish()
}
}
fn deserialize_optional_u64_or_string<'de, D>(
deserializer: D,
) -> std::result::Result<Option<u64>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de;
struct U64OrString;
impl<'de> de::Visitor<'de> for U64OrString {
type Value = Option<u64>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("an integer, an integer-valued float, a numeric string, or null")
}
fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
Ok(Some(v))
}
fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
if v < 0 {
return Err(E::custom(format!("invalid expires_in value: {v}")));
}
Ok(Some(v as u64))
}
fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
if !v.is_finite() || v < 0.0 || v.fract() != 0.0 || v > u64::MAX as f64 {
return Err(E::custom(format!("invalid expires_in value: {v}")));
}
Ok(Some(v as u64))
}
fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Self::Value, E> {
v.parse::<u64>().map(Some).map_err(de::Error::custom)
}
fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
Ok(None)
}
fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
Ok(None)
}
}
deserializer.deserialize_any(U64OrString)
}
// -- Internal Token State --
struct TokenState {
access_token: Option<String>,
expires_at: Option<Instant>,
}
impl TokenState {
fn new() -> Self {
Self {
access_token: None,
expires_at: None,
}
}
fn is_expired(&self, buffer: Duration) -> bool {
match (self.access_token.as_ref(), self.expires_at) {
(Some(_), Some(expires_at)) => Instant::now() + buffer >= expires_at,
(None, _) => true,
(Some(_), None) => true,
}
}
fn update(&mut self, resp: &TokenResponse) {
self.access_token = Some(resp.access_token.clone());
let expires_in = resp.expires_in.unwrap_or(DEFAULT_TOKEN_TTL_SECS);
self.expires_at = Some(Instant::now() + Duration::from_secs(expires_in));
}
}
#[async_trait]
trait TokenSource: Send + Sync + std::fmt::Debug {
async fn fetch_token(&self) -> Result<TokenResponse>;
}
struct ClientCredentialsSource {
issuer_url: String,
client_id: String,
client_secret: String,
scopes: Vec<String>,
http_client: Client,
discovery: RwLock<Option<OidcDiscovery>>,
}
impl std::fmt::Debug for ClientCredentialsSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ClientCredentialsSource")
.field("issuer_url", &self.issuer_url)
.field("client_id", &self.client_id)
.field("client_secret", &"<redacted>")
.field("scopes", &self.scopes)
.finish()
}
}
impl ClientCredentialsSource {
fn new(
issuer_url: String,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
) -> Result<Self> {
let client_secret = client_secret.ok_or(Error::InvalidInput {
message: "client_secret is required for ClientCredentials flow".to_string(),
})?;
Self::validate_issuer_transport(&issuer_url)?;
let http_client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| Error::Runtime {
message: format!("Failed to create HTTP client for OAuth: {e}"),
})?;
Ok(Self {
issuer_url,
client_id,
client_secret,
scopes,
http_client,
discovery: RwLock::new(None),
})
}
fn validate_issuer_transport(issuer_url: &str) -> Result<()> {
let issuer = url::Url::parse(issuer_url).map_err(|e| Error::InvalidInput {
message: format!("Invalid OAuth issuer_url: {e}"),
})?;
match issuer.scheme() {
"https" => Ok(()),
"http" if Self::is_loopback_issuer(&issuer) => Ok(()),
_ => Err(Error::InvalidInput {
message:
"ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
.to_string(),
}),
}
}
fn is_loopback_issuer(issuer: &url::Url) -> bool {
let Some(host) = issuer.host_str() else {
return false;
};
host.eq_ignore_ascii_case("localhost")
|| host
.parse::<IpAddr>()
.map(|addr| addr.is_loopback())
.unwrap_or(false)
}
async fn get_discovery(&self) -> Result<OidcDiscovery> {
{
let cached = self.discovery.read().await;
if let Some(ref disc) = *cached {
return Ok(disc.clone());
}
}
let mut cache = self.discovery.write().await;
// Double-check
if let Some(ref disc) = *cache {
return Ok(disc.clone());
}
let discovery_url = format!(
"{}/.well-known/openid-configuration",
self.issuer_url.trim_end_matches('/')
);
debug!("Fetching OIDC discovery from {}", discovery_url);
let resp = self
.http_client
.get(&discovery_url)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to fetch OIDC discovery document: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"OIDC discovery failed with status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
let disc: OidcDiscovery = resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse OIDC discovery document: {e}"),
})?;
let result = disc.clone();
*cache = Some(disc);
Ok(result)
}
async fn get_token_endpoint(&self) -> Result<String> {
self.get_discovery().await.map(|disc| disc.token_endpoint)
}
fn scopes_string(&self) -> String {
self.scopes.join(" ")
}
async fn post_token_request(
&self,
endpoint: &str,
params: &[(&str, &str)],
) -> Result<TokenResponse> {
let resp = self
.http_client
.post(endpoint)
.form(params)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Token request to {endpoint} failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Token request failed with status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse token response: {e}"),
})
}
}
#[async_trait]
impl TokenSource for ClientCredentialsSource {
async fn fetch_token(&self) -> Result<TokenResponse> {
let token_endpoint = self.get_token_endpoint().await?;
let scope = self.scopes_string();
let params = [
("grant_type", "client_credentials"),
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
("scope", scope.as_str()),
];
self.post_token_request(&token_endpoint, &params).await
}
}
struct AzureImdsSource {
client_id: Option<String>,
resource: String,
http_client: Client,
}
impl std::fmt::Debug for AzureImdsSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AzureImdsSource")
.field("client_id", &self.client_id)
.field("resource", &self.resource)
.finish()
}
}
impl AzureImdsSource {
fn new(scopes: Vec<String>, client_id: Option<String>) -> Result<Self> {
let resource = Self::resource_from_scopes(&scopes)?;
let http_client = Client::builder()
.timeout(Duration::from_secs(30))
.no_proxy()
.build()
.map_err(|e| Error::Runtime {
message: format!("Failed to create HTTP client for Azure IMDS OAuth: {e}"),
})?;
Ok(Self {
client_id,
resource,
http_client,
})
}
fn resource_from_scopes(scopes: &[String]) -> Result<String> {
let [scope] = scopes else {
return Err(Error::InvalidInput {
message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource"
.to_string(),
});
};
Ok(scope.strip_suffix("/.default").unwrap_or(scope).to_string())
}
}
#[async_trait]
impl TokenSource for AzureImdsSource {
async fn fetch_token(&self) -> Result<TokenResponse> {
let mut url = format!(
"{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}",
urlencoding::encode(&self.resource),
);
if let Some(cid) = self.client_id.as_deref() {
url.push_str(&format!("&client_id={}", urlencoding::encode(cid)));
}
let resp = self
.http_client
.get(&url)
.header("Metadata", "true")
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Azure IMDS request failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Azure IMDS returned status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse IMDS token response: {e}"),
})
}
}
/// OAuth header provider that manages the full token lifecycle.
///
/// Implements [`HeaderProvider`] to inject `Authorization: Bearer <token>`
/// headers into every LanceDB request, with automatic token refresh.
pub struct OAuthHeaderProvider {
token_source: Box<dyn TokenSource>,
token_state: Arc<RwLock<TokenState>>,
refresh_buffer: Duration,
}
impl std::fmt::Debug for OAuthHeaderProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OAuthHeaderProvider")
.field("token_source", &self.token_source)
.finish()
}
}
impl OAuthHeaderProvider {
/// Create a new OAuth header provider from configuration.
pub fn new(config: OAuthConfig) -> Result<Self> {
let OAuthConfig {
issuer_url,
client_id,
client_secret,
scopes,
flow,
refresh_buffer_secs,
} = config;
if scopes.is_empty() {
return Err(Error::InvalidInput {
message: "At least one OAuth scope is required".to_string(),
});
}
let refresh_buffer =
Duration::from_secs(refresh_buffer_secs.unwrap_or(DEFAULT_REFRESH_BUFFER_SECS));
let token_source: Box<dyn TokenSource> = match flow {
OAuthFlow::ClientCredentials => Box::new(ClientCredentialsSource::new(
issuer_url,
client_id,
client_secret,
scopes,
)?),
OAuthFlow::AzureManagedIdentity { client_id } => {
Box::new(AzureImdsSource::new(scopes, client_id)?)
}
};
Ok(Self {
token_source,
token_state: Arc::new(RwLock::new(TokenState::new())),
refresh_buffer,
})
}
/// Get a valid access token, refreshing if necessary.
async fn get_valid_token(&self) -> Result<String> {
// Fast path: check if current token is still valid
{
let state = self.token_state.read().await;
if !state.is_expired(self.refresh_buffer)
&& let Some(ref token) = state.access_token
{
return Ok(token.clone());
}
}
// Slow path: acquire or refresh token
let mut state = self.token_state.write().await;
// Double-check after acquiring write lock
if !state.is_expired(self.refresh_buffer)
&& let Some(ref token) = state.access_token
{
return Ok(token.clone());
}
debug!("Acquiring new OAuth token via {:?}", self.token_source);
let resp = self.token_source.fetch_token().await?;
state.update(&resp);
Ok(resp.access_token)
}
}
#[async_trait]
impl HeaderProvider for OAuthHeaderProvider {
async fn get_headers(&self) -> Result<HashMap<String, String>> {
let token = self.get_valid_token().await?;
Ok(HashMap::from([(
"authorization".to_string(),
format!("Bearer {token}"),
)]))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::task::JoinHandle;
#[test]
fn test_token_state_expiry() {
let mut state = TokenState::new();
assert!(state.is_expired(Duration::from_secs(0)));
state.access_token = Some("tok".to_string());
state.expires_at = Some(Instant::now() + Duration::from_secs(600));
assert!(!state.is_expired(Duration::from_secs(300)));
assert!(state.is_expired(Duration::from_secs(601)));
state.expires_at = None;
assert!(state.is_expired(Duration::from_secs(0)));
}
#[test]
fn test_token_state_uses_default_expiry() {
let mut state = TokenState::new();
let response = TokenResponse {
access_token: "tok".to_string(),
expires_in: None,
token_type: None,
};
state.update(&response);
assert!(!state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS - 1)));
assert!(state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS + 1)));
}
#[test]
fn test_token_response_accepts_float_expires_in() {
let response: TokenResponse =
serde_json::from_str(r#"{"access_token":"tok","expires_in":3600.0}"#).unwrap();
assert_eq!(response.expires_in, Some(3600));
}
#[test]
fn test_token_response_rejects_negative_expires_in() {
let err =
serde_json::from_str::<TokenResponse>(r#"{"access_token":"tok","expires_in":-1}"#)
.unwrap_err();
assert!(err.to_string().contains("invalid expires_in value: -1"));
}
#[test]
fn test_token_response_debug_redacts_access_token() {
let response = TokenResponse {
access_token: "secret-token".to_string(),
expires_in: Some(3600),
token_type: Some("Bearer".to_string()),
};
let debug = format!("{response:?}");
assert!(!debug.contains("secret-token"));
assert!(debug.contains("access_token: \"<redacted>\""));
}
#[test]
fn test_scopes_string() {
let source = ClientCredentialsSource::new(
"https://login.microsoftonline.com/tenant/v2.0".to_string(),
"app-id".to_string(),
Some("secret".to_string()),
vec!["scope1".to_string(), "scope2".to_string()],
)
.unwrap();
assert_eq!(source.scopes_string(), "scope1 scope2");
}
#[test]
fn test_oauth_config_debug_redacts_client_secret() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("super-secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let debug = format!("{config:?}");
assert!(!debug.contains("super-secret"));
assert!(debug.contains("client_secret: Some(\"<redacted>\")"));
}
#[test]
fn test_oauth_header_provider_debug_redacts_client_secret() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("super-secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let provider = OAuthHeaderProvider::new(config).unwrap();
let debug = format!("{provider:?}");
assert!(!debug.contains("super-secret"));
assert!(debug.contains("client_secret: \"<redacted>\""));
}
#[test]
fn test_managed_identity_resource_from_default_scope() {
assert_eq!(
AzureImdsSource::resource_from_scopes(&["api://test/.default".to_string()]).unwrap(),
"api://test"
);
}
#[test]
fn test_managed_identity_resource_without_default_suffix() {
assert_eq!(
AzureImdsSource::resource_from_scopes(&["api://test".to_string()]).unwrap(),
"api://test"
);
}
#[test]
fn test_managed_identity_rejects_multiple_scopes() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: None,
scopes: vec![
"api://test-a/.default".to_string(),
"api://test-b/.default".to_string(),
],
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
refresh_buffer_secs: None,
};
assert!(OAuthHeaderProvider::new(config).is_err());
}
#[tokio::test]
async fn test_token_endpoint_requires_discovery_success() {
let (issuer_url, server) = spawn_discovery_error_server().await;
let source = ClientCredentialsSource::new(
issuer_url,
"client-id".to_string(),
Some("secret".to_string()),
vec!["scope".to_string()],
)
.unwrap();
let err = source.get_token_endpoint().await.unwrap_err();
assert!(matches!(
err,
Error::Runtime { message }
if message.contains("OIDC discovery failed with status 503")
));
server.await.unwrap();
}
#[test]
fn test_client_credentials_requires_secret() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: None,
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
assert!(OAuthHeaderProvider::new(config).is_err());
}
#[test]
fn test_client_credentials_rejects_insecure_non_loopback_issuer() {
let config = OAuthConfig {
issuer_url: "http://issuer.example.com".to_string(),
client_id: "app-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let err = OAuthHeaderProvider::new(config).unwrap_err();
assert!(matches!(
err,
Error::InvalidInput { message }
if message == "ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
));
}
#[test]
fn test_empty_scopes_rejected() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: None,
scopes: vec![],
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
refresh_buffer_secs: None,
};
assert!(OAuthHeaderProvider::new(config).is_err());
}
#[tokio::test]
async fn test_client_credentials_token_lifecycle() {
let (issuer_url, token_requests, server) = spawn_oauth_server().await;
let config = OAuthConfig {
issuer_url,
client_id: "client-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: Some(0),
};
let provider = OAuthHeaderProvider::new(config).unwrap();
let headers = provider.get_headers().await.unwrap();
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
let headers = provider.get_headers().await.unwrap();
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
provider.token_state.write().await.expires_at =
Some(Instant::now() - Duration::from_secs(1));
let headers = provider.get_headers().await.unwrap();
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-2");
assert_eq!(token_requests.load(Ordering::SeqCst), 2);
server.await.unwrap();
}
async fn spawn_oauth_server() -> (String, Arc<AtomicUsize>, JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let issuer_url = format!("http://{addr}");
let token_requests = Arc::new(AtomicUsize::new(0));
let server_token_requests = Arc::clone(&token_requests);
let server = tokio::spawn(async move {
for _ in 0..3 {
let (mut stream, _) = listener.accept().await.unwrap();
let (request_line, body) = read_http_request(&mut stream).await;
if request_line.starts_with("GET /.well-known/openid-configuration ") {
let discovery = format!(r#"{{"token_endpoint":"http://{addr}/token"}}"#);
write_json_response(&mut stream, "200 OK", &discovery).await;
} else if request_line.starts_with("POST /token ") {
assert!(body.contains("grant_type=client_credentials"));
assert!(body.contains("client_id=client-id"));
assert!(body.contains("client_secret=secret"));
assert!(body.contains("scope=scope"));
let token_num = server_token_requests.fetch_add(1, Ordering::SeqCst) + 1;
let token = format!(
r#"{{"access_token":"token-{token_num}","expires_in":3600,"token_type":"Bearer"}}"#
);
write_json_response(&mut stream, "200 OK", &token).await;
} else {
write_json_response(&mut stream, "404 Not Found", "{}").await;
}
}
});
(issuer_url, token_requests, server)
}
async fn spawn_discovery_error_server() -> (String, JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let issuer_url = format!("http://{addr}");
let server = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let (request_line, _) = read_http_request(&mut stream).await;
assert!(request_line.starts_with("GET /.well-known/openid-configuration "));
write_json_response(&mut stream, "503 Service Unavailable", "{}").await;
});
(issuer_url, server)
}
async fn read_http_request(stream: &mut TcpStream) -> (String, String) {
let mut buffer = Vec::new();
let mut header_end = None;
while header_end.is_none() {
let mut chunk = [0; 1024];
let read = stream.read(&mut chunk).await.unwrap();
assert_ne!(read, 0, "connection closed before request headers");
buffer.extend_from_slice(&chunk[..read]);
header_end = find_subsequence(&buffer, b"\r\n\r\n").map(|pos| pos + 4);
}
let header_end = header_end.unwrap();
let headers = String::from_utf8_lossy(&buffer[..header_end]).to_string();
let request_line = headers.lines().next().unwrap_or_default().to_string();
let content_length = headers
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap_or(0);
while buffer.len() < header_end + content_length {
let mut chunk = [0; 1024];
let read = stream.read(&mut chunk).await.unwrap();
assert_ne!(read, 0, "connection closed before request body");
buffer.extend_from_slice(&chunk[..read]);
}
let body =
String::from_utf8_lossy(&buffer[header_end..header_end + content_length]).to_string();
(request_line, body)
}
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
haystack
.windows(needle.len())
.position(|window| window == needle)
}
async fn write_json_response(stream: &mut TcpStream, status: &str, body: &str) {
let response = format!(
"HTTP/1.1 {status}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
body.len()
);
stream.write_all(response.as_bytes()).await.unwrap();
}
}

View File

@@ -579,24 +579,45 @@ fn array_to_f32_vec(arr: &Arc<dyn arrow_array::Array>) -> Result<Vec<f32>> {
})
}
/// Magic bytes that prefix (and suffix) the Arrow IPC *file* format.
const ARROW_IPC_FILE_MAGIC: &[u8] = b"ARROW1";
/// Parse Arrow IPC response from the namespace server.
///
/// The server may return either the Arrow IPC *file* format or the *stream*
/// format. REST/phalanx returns the file format (it begins with the `ARROW1`
/// magic); reading that with a `StreamReader` fails with "failed to fill whole
/// buffer". Detect the magic and pick the matching reader so both are handled.
async fn parse_arrow_ipc_response(bytes: bytes::Bytes) -> Result<DatasetRecordBatchStream> {
use arrow_ipc::reader::StreamReader;
use arrow_ipc::reader::{FileReader, StreamReader};
use std::io::Cursor;
let cursor = Cursor::new(bytes);
let reader = StreamReader::try_new(cursor, None).map_err(|e| Error::Runtime {
message: format!("Failed to parse Arrow IPC response: {}", e),
})?;
// Collect all record batches
let schema = reader.schema();
let batches: Vec<_> = reader
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Runtime {
message: format!("Failed to read Arrow IPC batches: {}", e),
let (schema, batches) = if bytes.starts_with(ARROW_IPC_FILE_MAGIC) {
let reader = FileReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
message: format!("Failed to parse Arrow IPC file response: {}", e),
})?;
let schema = reader.schema();
let batches = reader
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Runtime {
message: format!("Failed to read Arrow IPC file batches: {}", e),
})?;
(schema, batches)
} else {
let reader =
StreamReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
message: format!("Failed to parse Arrow IPC response: {}", e),
})?;
let schema = reader.schema();
let batches = reader
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Runtime {
message: format!("Failed to read Arrow IPC batches: {}", e),
})?;
(schema, batches)
};
// Create a stream from the batches
let stream = futures::stream::iter(batches.into_iter().map(Ok));
@@ -624,6 +645,59 @@ mod tests {
FixedSizeListArray::try_new_from_values(Float32Array::from(values), dimension).unwrap()
}
#[tokio::test]
async fn test_parse_arrow_ipc_response_handles_file_and_stream() {
use arrow_array::{Int32Array, RecordBatch};
use arrow_ipc::writer::{FileWriter, StreamWriter};
use arrow_schema::{DataType, Field, Schema};
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
)
.unwrap();
// Arrow IPC *file* format -- what REST/phalanx returns. Previously this
// failed with "failed to fill whole buffer" because we used a StreamReader.
let mut file_buf = Vec::new();
{
let mut writer = FileWriter::try_new(&mut file_buf, &schema).unwrap();
writer.write(&batch).unwrap();
writer.finish().unwrap();
}
assert!(file_buf.starts_with(ARROW_IPC_FILE_MAGIC));
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(file_buf))
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap()
.iter()
.map(|b| b.num_rows())
.sum();
assert_eq!(rows, 3);
// Arrow IPC *stream* format must still parse.
let mut stream_buf = Vec::new();
{
let mut writer = StreamWriter::try_new(&mut stream_buf, &schema).unwrap();
writer.write(&batch).unwrap();
writer.finish().unwrap();
}
assert!(!stream_buf.starts_with(ARROW_IPC_FILE_MAGIC));
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(stream_buf))
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap()
.iter()
.map(|b| b.num_rows())
.sum();
assert_eq!(rows, 3);
}
#[test]
fn test_convert_to_namespace_query_vector() {
let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]));