From 7cb68bb079e165b83df3f99006ff0703d739a032 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Tue, 12 May 2026 12:53:19 -0700 Subject: [PATCH] feat: add native OAuth/OIDC authentication support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add OAuthConfig and OAuthHeaderProvider to the Rust core with support for two OAuth flows: ClientCredentials and AzureManagedIdentity. Token acquisition and auto-refresh happen entirely in Rust. Python and TypeScript expose OAuthConfig as a plain config object that maps to the Rust header provider via FFI — no dynamic callbacks cross the language boundary. ConnectBuilder gains an oauth_config() method that replaces the API key requirement when OAuth is configured. --- Cargo.lock | 1 + docs/src/js/enumerations/OAuthFlowType.md | 29 + docs/src/js/globals.md | 3 + docs/src/js/interfaces/ConnectionOptions.md | 13 + docs/src/js/interfaces/NativeOAuthConfig.md | 86 +++ docs/src/js/interfaces/OAuthConfig.md | 109 +++ nodejs/lancedb/index.ts | 3 + nodejs/lancedb/oauth.ts | 72 ++ nodejs/src/connection.rs | 6 + nodejs/src/lib.rs | 5 + nodejs/src/remote.rs | 82 +++ python/python/lancedb/__init__.py | 2 + python/python/lancedb/_lancedb.pyi | 1 + python/python/lancedb/remote/__init__.py | 3 + python/python/lancedb/remote/oauth.py | 73 ++ python/src/connection.rs | 8 +- python/src/lib.rs | 1 + python/src/oauth.rs | 72 ++ rust/lancedb/Cargo.toml | 4 +- rust/lancedb/src/connection.rs | 45 +- rust/lancedb/src/remote.rs | 2 + rust/lancedb/src/remote/client.rs | 43 +- rust/lancedb/src/remote/oauth.rs | 714 ++++++++++++++++++++ 23 files changed, 1365 insertions(+), 12 deletions(-) create mode 100644 docs/src/js/enumerations/OAuthFlowType.md create mode 100644 docs/src/js/interfaces/NativeOAuthConfig.md create mode 100644 docs/src/js/interfaces/OAuthConfig.md create mode 100644 nodejs/lancedb/oauth.ts create mode 100644 
python/python/lancedb/remote/oauth.py create mode 100644 python/src/oauth.rs create mode 100644 rust/lancedb/src/remote/oauth.rs diff --git a/Cargo.lock b/Cargo.lock index 03bf4657a..bf5f7bb0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5383,6 +5383,7 @@ dependencies = [ "tokenizers", "tokio", "url", + "urlencoding", "uuid", "walkdir", ] diff --git a/docs/src/js/enumerations/OAuthFlowType.md b/docs/src/js/enumerations/OAuthFlowType.md new file mode 100644 index 000000000..fe546a140 --- /dev/null +++ b/docs/src/js/enumerations/OAuthFlowType.md @@ -0,0 +1,29 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / OAuthFlowType + +# Enumeration: OAuthFlowType + +OAuth authentication flow types. + +## Enumeration Members + +### AzureManagedIdentity + +```ts +AzureManagedIdentity: "azure_managed_identity"; +``` + +Azure Managed Identity via IMDS. + +*** + +### ClientCredentials + +```ts +ClientCredentials: "client_credentials"; +``` + +Client Credentials grant (service-to-service / M2M). 
diff --git a/docs/src/js/globals.md b/docs/src/js/globals.md index 3efa3a360..79d842346 100644 --- a/docs/src/js/globals.md +++ b/docs/src/js/globals.md @@ -12,6 +12,7 @@ ## Enumerations - [FullTextQueryType](enumerations/FullTextQueryType.md) +- [OAuthFlowType](enumerations/OAuthFlowType.md) - [Occur](enumerations/Occur.md) - [Operator](enumerations/Operator.md) @@ -85,6 +86,8 @@ - [ListNamespacesResponse](interfaces/ListNamespacesResponse.md) - [LsmWriteSpec](interfaces/LsmWriteSpec.md) - [MergeResult](interfaces/MergeResult.md) +- [NativeOAuthConfig](interfaces/NativeOAuthConfig.md) +- [OAuthConfig](interfaces/OAuthConfig.md) - [OpenTableOptions](interfaces/OpenTableOptions.md) - [OptimizeOptions](interfaces/OptimizeOptions.md) - [OptimizeStats](interfaces/OptimizeStats.md) diff --git a/docs/src/js/interfaces/ConnectionOptions.md b/docs/src/js/interfaces/ConnectionOptions.md index de2083a9b..42f80f077 100644 --- a/docs/src/js/interfaces/ConnectionOptions.md +++ b/docs/src/js/interfaces/ConnectionOptions.md @@ -64,6 +64,19 @@ client used by manifest-enabled native connections. *** +### oauthConfig? + +```ts +optional oauthConfig: NativeOAuthConfig; +``` + +(For LanceDB cloud only): OAuth configuration for IdP-based +authentication (e.g., Azure Entra ID). When set, token acquisition +and refresh are handled entirely in Rust. TypeScript users should pass +the public `OAuthConfig` type exported from `@lancedb/lancedb`. + +*** + ### readConsistencyInterval? ```ts diff --git a/docs/src/js/interfaces/NativeOAuthConfig.md b/docs/src/js/interfaces/NativeOAuthConfig.md new file mode 100644 index 000000000..b569b75a8 --- /dev/null +++ b/docs/src/js/interfaces/NativeOAuthConfig.md @@ -0,0 +1,86 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / NativeOAuthConfig + +# Interface: NativeOAuthConfig + +OAuth configuration for LanceDB authentication. + +This is the generated napi-rs binding shape. 
TypeScript users should prefer +the public `OAuthConfig` type exported from `@lancedb/lancedb`. + +All token acquisition and refresh is handled in the Rust layer. + +## Properties + +### clientId + +```ts +clientId: string; +``` + +Application / Client ID. + +*** + +### clientSecret? + +```ts +optional clientSecret: string; +``` + +Client secret (required for client_credentials). + +*** + +### flow? + +```ts +optional flow: string; +``` + +Authentication flow: "client_credentials" or "azure_managed_identity" + +*** + +### issuerUrl + +```ts +issuerUrl: string; +``` + +OIDC issuer URL or OAuth authority URL. +For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0` + +*** + +### managedIdentityClientId? + +```ts +optional managedIdentityClientId: string; +``` + +Client ID for user-assigned managed identity (azure_managed_identity). + +*** + +### refreshBufferSecs? + +```ts +optional refreshBufferSecs: number; +``` + +Seconds before expiry to trigger proactive refresh (default: 300). + +*** + +### scopes + +```ts +scopes: string[]; +``` + +OAuth scopes to request. For Azure managed identity, exactly one scope +or resource is required. For example: `["api://{app_id}/.default"]` diff --git a/docs/src/js/interfaces/OAuthConfig.md b/docs/src/js/interfaces/OAuthConfig.md new file mode 100644 index 000000000..4485616bd --- /dev/null +++ b/docs/src/js/interfaces/OAuthConfig.md @@ -0,0 +1,109 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / OAuthConfig + +# Interface: OAuthConfig + +OAuth configuration for LanceDB authentication. + +This is the public TypeScript OAuth configuration type. The generated +`NativeOAuthConfig` type has the same runtime shape but is an implementation +detail of the napi-rs binding. + +All token acquisition and refresh is handled in the Rust layer. +This config is passed through to Rust via napi-rs. 
+ +## Examples + +```typescript +const config: OAuthConfig = { + issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0", + clientId: "app-id", + clientSecret: "secret", + scopes: ["api://lancedb-api/.default"], +}; +``` + +```typescript +const config: OAuthConfig = { + issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0", + clientId: "app-id", + scopes: ["api://lancedb-api/.default"], + flow: OAuthFlowType.AzureManagedIdentity, +}; +``` + +## Properties + +### clientId + +```ts +clientId: string; +``` + +Application / Client ID. + +*** + +### clientSecret? + +```ts +optional clientSecret: string; +``` + +Client secret (required for ClientCredentials). + +*** + +### flow? + +```ts +optional flow: OAuthFlowType; +``` + +Authentication flow (default: ClientCredentials). + +*** + +### issuerUrl + +```ts +issuerUrl: string; +``` + +OIDC issuer URL or OAuth authority URL. +For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0` + +*** + +### managedIdentityClientId? + +```ts +optional managedIdentityClientId: string; +``` + +Client ID for user-assigned managed identity (AzureManagedIdentity). + +*** + +### refreshBufferSecs? + +```ts +optional refreshBufferSecs: number; +``` + +Seconds before expiry to trigger proactive refresh (default: 300). + +*** + +### scopes + +```ts +scopes: string[]; +``` + +OAuth scopes to request. +For Azure managed identity, exactly one scope or resource is required. 
+For example: `["api://{app_id}/.default"]` diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts index c74cf1caa..b9939726d 100644 --- a/nodejs/lancedb/index.ts +++ b/nodejs/lancedb/index.ts @@ -52,6 +52,7 @@ export { SplitHashOptions, SplitSequentialOptions, ShuffleOptions, + OAuthConfig as NativeOAuthConfig, } from "./native.js"; export { @@ -130,6 +131,8 @@ export { TokenResponse, } from "./header"; +export { OAuthConfig, OAuthFlowType } from "./oauth"; + export { MergeInsertBuilder, WriteExecutionOptions } from "./merge"; export * as embedding from "./embedding"; diff --git a/nodejs/lancedb/oauth.ts b/nodejs/lancedb/oauth.ts new file mode 100644 index 000000000..20011b581 --- /dev/null +++ b/nodejs/lancedb/oauth.ts @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +/** + * OAuth authentication flow types. + */ +export enum OAuthFlowType { + /** Client Credentials grant (service-to-service / M2M). */ + ClientCredentials = "client_credentials", + /** Azure Managed Identity via IMDS. */ + AzureManagedIdentity = "azure_managed_identity", +} + +/** + * OAuth configuration for LanceDB authentication. + * + * This is the public TypeScript OAuth configuration type. The generated + * `NativeOAuthConfig` type has the same runtime shape but is an implementation + * detail of the napi-rs binding. + * + * All token acquisition and refresh is handled in the Rust layer. + * This config is passed through to Rust via napi-rs. 
+ * + * @example Client Credentials (service-to-service): + * ```typescript + * const config: OAuthConfig = { + * issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0", + * clientId: "app-id", + * clientSecret: "secret", + * scopes: ["api://lancedb-api/.default"], + * }; + * ``` + * + * @example Azure Managed Identity: + * ```typescript + * const config: OAuthConfig = { + * issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0", + * clientId: "app-id", + * scopes: ["api://lancedb-api/.default"], + * flow: OAuthFlowType.AzureManagedIdentity, + * }; + * ``` + */ +export interface OAuthConfig { + /** + * OIDC issuer URL or OAuth authority URL. + * For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0` + */ + issuerUrl: string; + + /** Application / Client ID. */ + clientId: string; + + /** + * OAuth scopes to request. + * For Azure managed identity, exactly one scope or resource is required. + * For example: `["api://{app_id}/.default"]` + */ + scopes: string[]; + + /** Authentication flow (default: ClientCredentials). */ + flow?: OAuthFlowType; + + /** Client secret (required for ClientCredentials). */ + clientSecret?: string; + + /** Client ID for user-assigned managed identity (AzureManagedIdentity). */ + managedIdentityClientId?: string; + + /** Seconds before expiry to trigger proactive refresh (default: 300). 
*/ + refreshBufferSecs?: number; +} diff --git a/nodejs/src/connection.rs b/nodejs/src/connection.rs index 18d6644de..1a18651f0 100644 --- a/nodejs/src/connection.rs +++ b/nodejs/src/connection.rs @@ -112,6 +112,12 @@ impl Connection { builder = builder.client_config(rust_config); + if let Some(oauth_config) = options.oauth_config { + let config: lancedb::remote::oauth::OAuthConfig = + oauth_config.try_into().default_error()?; + builder = builder.oauth_config(config); + } + if let Some(api_key) = options.api_key { builder = builder.api_key(&api_key); } diff --git a/nodejs/src/lib.rs b/nodejs/src/lib.rs index 53c630c93..b95602f7b 100644 --- a/nodejs/src/lib.rs +++ b/nodejs/src/lib.rs @@ -65,6 +65,11 @@ pub struct ConnectionOptions { /// (For LanceDB cloud only): the host to use for LanceDB cloud. Used /// for testing purposes. pub host_override: Option, + /// (For LanceDB cloud only): OAuth configuration for IdP-based + /// authentication (e.g., Azure Entra ID). When set, token acquisition + /// and refresh are handled entirely in Rust. TypeScript users should pass + /// the public `OAuthConfig` type exported from `@lancedb/lancedb`. + pub oauth_config: Option, } #[napi(object)] diff --git a/nodejs/src/remote.rs b/nodejs/src/remote.rs index 8cfcbc984..b0abd77ec 100644 --- a/nodejs/src/remote.rs +++ b/nodejs/src/remote.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; +use lancedb::error::Error; use napi_derive::*; /// Timeout configuration for remote HTTP client. @@ -140,6 +141,62 @@ impl From for lancedb::remote::TlsConfig { } } +/// OAuth configuration for LanceDB authentication. +/// +/// This is the generated napi-rs binding shape. TypeScript users should prefer +/// the public `OAuthConfig` type exported from `@lancedb/lancedb`. +/// +/// All token acquisition and refresh is handled in the Rust layer. +#[napi(object)] +#[derive(Debug, Clone)] +pub struct OAuthConfig { + /// OIDC issuer URL or OAuth authority URL. 
+ /// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0` + pub issuer_url: String, + /// Application / Client ID. + pub client_id: String, + /// OAuth scopes to request. For Azure managed identity, exactly one scope + /// or resource is required. For example: `["api://{app_id}/.default"]` + pub scopes: Vec, + /// Authentication flow: "client_credentials" or "azure_managed_identity" + pub flow: Option, + /// Client secret (required for client_credentials). + pub client_secret: Option, + /// Client ID for user-assigned managed identity (azure_managed_identity). + pub managed_identity_client_id: Option, + /// Seconds before expiry to trigger proactive refresh (default: 300). + pub refresh_buffer_secs: Option, +} + +impl TryFrom for lancedb::remote::oauth::OAuthConfig { + type Error = Error; + + fn try_from(config: OAuthConfig) -> Result { + use lancedb::remote::oauth::OAuthFlow; + + let flow = match config.flow.as_deref().unwrap_or("client_credentials") { + "client_credentials" => OAuthFlow::ClientCredentials, + "azure_managed_identity" => OAuthFlow::AzureManagedIdentity { + client_id: config.managed_identity_client_id, + }, + other => { + return Err(Error::InvalidInput { + message: format!("Unknown OAuth flow type: {other}"), + }); + } + }; + + Ok(Self { + issuer_url: config.issuer_url, + client_id: config.client_id, + client_secret: config.client_secret, + scopes: config.scopes, + flow, + refresh_buffer_secs: config.refresh_buffer_secs.map(|v| v as u64), + }) + } +} + impl From for lancedb::remote::ClientConfig { fn from(config: ClientConfig) -> Self { Self { @@ -156,3 +213,28 @@ impl From for lancedb::remote::ClientConfig { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_unknown_oauth_flow_returns_invalid_input() { + let config = OAuthConfig { + issuer_url: "https://issuer.example.com".to_string(), + client_id: "client-id".to_string(), + scopes: vec!["scope".to_string()], + flow: Some("typo".to_string()), + client_secret: None, 
+ managed_identity_client_id: None, + refresh_buffer_secs: None, + }; + + let err = lancedb::remote::oauth::OAuthConfig::try_from(config).unwrap_err(); + assert!(matches!( + err, + Error::InvalidInput { message } + if message == "Unknown OAuth flow type: typo" + )); + } +} diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index e748e1402..7eb3d897e 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -340,6 +340,7 @@ async def connect_async( session: Optional[Session] = None, manifest_enabled: bool = False, namespace_client_properties: Optional[Dict[str, str]] = None, + oauth_config=None, ) -> AsyncConnection: """Connect to a LanceDB database. @@ -435,6 +436,7 @@ async def connect_async( session, manifest_enabled, namespace_client_properties, + oauth_config, ) ) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 8ddb28604..3f3c986dd 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -280,6 +280,7 @@ async def connect( session: Optional[Session], manifest_enabled: bool = False, namespace_client_properties: Optional[Dict[str, str]] = None, + oauth_config: Optional[Any] = None, ) -> Connection: ... 
class RecordBatchStream: diff --git a/python/python/lancedb/remote/__init__.py b/python/python/lancedb/remote/__init__.py index 289e28942..a6ef55eb5 100644 --- a/python/python/lancedb/remote/__init__.py +++ b/python/python/lancedb/remote/__init__.py @@ -9,6 +9,7 @@ from typing import List, Optional from lancedb import __version__ from .header import HeaderProvider +from .oauth import OAuthConfig, OAuthFlowType __all__ = [ "TimeoutConfig", @@ -16,6 +17,8 @@ __all__ = [ "TlsConfig", "ClientConfig", "HeaderProvider", + "OAuthConfig", + "OAuthFlowType", ] diff --git a/python/python/lancedb/remote/oauth.py b/python/python/lancedb/remote/oauth.py new file mode 100644 index 000000000..d560f70bc --- /dev/null +++ b/python/python/lancedb/remote/oauth.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + + +class OAuthFlowType(str, Enum): + """OAuth authentication flow types.""" + + CLIENT_CREDENTIALS = "client_credentials" + """Client Credentials grant (service-to-service / M2M).""" + + AZURE_MANAGED_IDENTITY = "azure_managed_identity" + """Azure Managed Identity via IMDS.""" + + +@dataclass +class OAuthConfig: + """OAuth configuration for LanceDB authentication. + + All token acquisition and refresh is handled in the Rust layer. + This config is passed through to Rust via PyO3. + + Parameters + ---------- + issuer_url : str + OIDC issuer URL or OAuth authority URL. + For Azure: ``https://login.microsoftonline.com/{tenant_id}/v2.0`` + client_id : str + Application / Client ID. + scopes : List[str] + OAuth scopes to request. + For Azure managed identity, exactly one scope or resource is required. + For example: ``["api://{app_id}/.default"]`` + flow : OAuthFlowType + Authentication flow to use. Default: CLIENT_CREDENTIALS. + client_secret : Optional[str] + Client secret (required for CLIENT_CREDENTIALS). 
+ managed_identity_client_id : Optional[str] + Client ID for user-assigned managed identity (AZURE_MANAGED_IDENTITY). + refresh_buffer_secs : Optional[int] + Seconds before expiry to trigger proactive refresh (default: 300). + + Examples + -------- + Client Credentials (service-to-service): + + >>> config = OAuthConfig( + ... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0", + ... client_id="app-id", + ... client_secret="secret", + ... scopes=["api://lancedb-api/.default"], + ... ) + + Azure Managed Identity: + + >>> config = OAuthConfig( + ... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0", + ... client_id="app-id", + ... scopes=["api://lancedb-api/.default"], + ... flow=OAuthFlowType.AZURE_MANAGED_IDENTITY, + ... ) + """ + + issuer_url: str + client_id: str + scopes: List[str] + flow: OAuthFlowType = OAuthFlowType.CLIENT_CREDENTIALS + client_secret: Optional[str] = None + managed_identity_client_id: Optional[str] = None + refresh_buffer_secs: Optional[int] = None diff --git a/python/src/connection.rs b/python/src/connection.rs index 51713fc93..d8e8aa914 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -539,7 +539,7 @@ impl Connection { } #[pyfunction] -#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None))] +#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None, oauth_config=None))] #[allow(clippy::too_many_arguments)] pub fn connect( py: Python<'_>, @@ -553,6 +553,7 @@ pub fn connect( session: Option, manifest_enabled: bool, namespace_client_properties: Option>, + oauth_config: Option, ) -> PyResult> { future_into_py(py, async move { let mut builder = lancedb::connect(&uri); @@ -582,6 +583,11 
@@ pub fn connect( if let Some(client_config) = client_config { builder = builder.client_config(client_config.into()); } + if let Some(oauth_config) = oauth_config { + let config: lancedb::remote::oauth::OAuthConfig = + oauth_config.try_into().infer_error()?; + builder = builder.oauth_config(config); + } if let Some(session) = session { builder = builder.session(session.inner.clone()); } diff --git a/python/src/lib.rs b/python/src/lib.rs index fdf8f5cb7..72043c484 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -26,6 +26,7 @@ pub mod expr; pub mod header; pub mod index; pub mod namespace; +pub mod oauth; pub mod permutation; pub mod query; pub mod runtime; diff --git a/python/src/oauth.rs b/python/src/oauth.rs new file mode 100644 index 000000000..11ea011e2 --- /dev/null +++ b/python/src/oauth.rs @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use pyo3::FromPyObject; + +use lancedb::error::Error; +use lancedb::remote::oauth::{OAuthConfig, OAuthFlow}; + +/// Python-side OAuth configuration, extracted via FromPyObject. +/// Maps to `lancedb.remote.oauth.OAuthConfig` Python dataclass. 
+#[derive(FromPyObject)] +pub struct PyOAuthConfig { + pub issuer_url: String, + pub client_id: String, + pub scopes: Vec, + pub flow: String, + pub client_secret: Option, + pub managed_identity_client_id: Option, + pub refresh_buffer_secs: Option, +} + +impl TryFrom for OAuthConfig { + type Error = Error; + + fn try_from(py: PyOAuthConfig) -> Result { + let flow = match py.flow.as_str() { + "client_credentials" => OAuthFlow::ClientCredentials, + "azure_managed_identity" => OAuthFlow::AzureManagedIdentity { + client_id: py.managed_identity_client_id, + }, + other => { + return Err(Error::InvalidInput { + message: format!("Unknown OAuth flow type: {other}"), + }); + } + }; + + Ok(Self { + issuer_url: py.issuer_url, + client_id: py.client_id, + client_secret: py.client_secret, + scopes: py.scopes, + flow, + refresh_buffer_secs: py.refresh_buffer_secs, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_unknown_oauth_flow_returns_invalid_input() { + let config = PyOAuthConfig { + issuer_url: "https://issuer.example.com".to_string(), + client_id: "client-id".to_string(), + scopes: vec!["scope".to_string()], + flow: "typo".to_string(), + client_secret: None, + managed_identity_client_id: None, + refresh_buffer_secs: None, + }; + + let err = OAuthConfig::try_from(config).unwrap_err(); + assert!(matches!( + err, + Error::InvalidInput { message } + if message == "Unknown OAuth flow type: typo" + )); + } +} diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index ea5e01df5..9d13be7d9 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -75,6 +75,8 @@ reqwest = { version = "0.12.0", default-features = false, features = [ "stream", ], optional = true } http = { version = "1", optional = true } # Matching what is in reqwest +# OAuth dependencies (used by remote feature) +urlencoding = { version = "2", optional = true } uuid = { version = "1.7.0", features = ["v4", "v5"] } polars-arrow = { version = ">=0.37,<0.40.0", 
optional = true } polars = { version = ">=0.37,<0.40.0", optional = true } @@ -129,7 +131,7 @@ huggingface = [ "lance-namespace-impls/dir-huggingface", ] dynamodb = ["lance/dynamodb", "aws"] -remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"] +remote = ["dep:reqwest", "dep:http", "dep:urlencoding", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"] fp16kernels = ["lance-linalg/fp16kernels"] s3-test = [] bedrock = ["dep:aws-sdk-bedrockruntime"] diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 53b61641b..135b5a341 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -661,6 +661,8 @@ pub struct ConnectRequest { pub struct ConnectBuilder { request: ConnectRequest, embedding_registry: Option>, + #[cfg(feature = "remote")] + oauth_config: Option, } #[cfg(feature = "remote")] @@ -682,6 +684,8 @@ impl ConnectBuilder { session: None, }, embedding_registry: None, + #[cfg(feature = "remote")] + oauth_config: None, } } @@ -770,6 +774,19 @@ impl ConnectBuilder { self } + /// Configure OAuth authentication for LanceDB Cloud/Enterprise. + /// + /// This installs an [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider) + /// built from the given config, replacing any previously configured header + /// provider. An API key becomes optional but is still used if also set. + /// + /// Token acquisition and refresh are handled entirely in Rust. + #[cfg(feature = "remote")] + pub fn oauth_config(mut self, config: crate::remote::oauth::OAuthConfig) -> Self { + self.oauth_config = Some(config); + self + } + /// Provide a custom [`EmbeddingRegistry`] to use for this connection. 
pub fn embedding_registry(mut self, registry: Arc) -> Self { self.embedding_registry = Some(registry); @@ -915,9 +932,29 @@ impl ConnectBuilder { let region = options.region.ok_or_else(|| Error::InvalidInput { message: "A region is required when connecting to LanceDb Cloud".to_string(), })?; - let api_key = options.api_key.ok_or_else(|| Error::InvalidInput { - message: "An api_key is required when connecting to LanceDb Cloud".to_string(), - })?; + + // When OAuth is configured, api_key is not required + let api_key = match (&self.oauth_config, &options.api_key) { + (Some(_), None) => String::new(), + (Some(_), Some(key)) => key.clone(), + (None, Some(key)) => key.clone(), + (None, None) => { + return Err(Error::InvalidInput { + message: + "An api_key or oauth_config is required when connecting to LanceDb Cloud" + .to_string(), + }); + } + }; + + let mut client_config = self.request.client_config; + + // Apply OAuth header provider if configured + if let Some(oauth_config) = self.oauth_config { + let provider = crate::remote::oauth::OAuthHeaderProvider::new(oauth_config)?; + client_config.header_provider = + Some(Arc::new(provider) as Arc); + } let storage_options = StorageOptions(options.storage_options.clone()); let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new( @@ -925,7 +962,7 @@ impl ConnectBuilder { &api_key, &region, options.host_override, - self.request.client_config, + client_config, storage_options.into(), self.request.read_consistency_interval, )?); diff --git a/rust/lancedb/src/remote.rs b/rust/lancedb/src/remote.rs index 866ecdcd8..25e19c537 100644 --- a/rust/lancedb/src/remote.rs +++ b/rust/lancedb/src/remote.rs @@ -8,6 +8,7 @@ pub(crate) mod client; pub(crate) mod db; +pub mod oauth; mod retry; pub(crate) mod table; pub(crate) mod util; @@ -20,3 +21,4 @@ const JSON_CONTENT_TYPE: &str = "application/json"; pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig}; pub use db::{RemoteDatabaseOptions, 
RemoteDatabaseOptionsBuilder}; +pub use oauth::{OAuthConfig, OAuthFlow, OAuthHeaderProvider}; diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 6a44f7f1c..e66f7e36f 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -459,12 +459,14 @@ impl RestfulLanceDbClient { config: &ClientConfig, ) -> Result { let mut headers = HeaderMap::new(); - headers.insert( - HeaderName::from_static("x-api-key"), - HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput { - message: "non-ascii api key provided".to_string(), - })?, - ); + if !api_key.is_empty() { + headers.insert( + HeaderName::from_static("x-api-key"), + HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput { + message: "non-ascii api key provided".to_string(), + })?, + ); + } if region == "local" { let host = format!("{}.local.api.lancedb.com", db_name); headers.insert( @@ -1037,6 +1039,35 @@ mod tests { } } + #[test] + fn test_default_headers_skip_empty_api_key() { + let headers = RestfulLanceDbClient::::default_headers( + "", + "us-east-1", + "db", + false, + &RemoteOptions(HashMap::new()), + None, + &ClientConfig::default(), + ) + .unwrap(); + + assert!(!headers.contains_key("x-api-key")); + + let headers = RestfulLanceDbClient::::default_headers( + "api-key", + "us-east-1", + "db", + false, + &RemoteOptions(HashMap::new()), + None, + &ClientConfig::default(), + ) + .unwrap(); + + assert_eq!(headers.get("x-api-key").unwrap(), "api-key"); + } + #[tokio::test] async fn test_client_config_with_header_provider() { let mut headers = HashMap::new(); diff --git a/rust/lancedb/src/remote/oauth.rs b/rust/lancedb/src/remote/oauth.rs new file mode 100644 index 000000000..b6c98ebcd --- /dev/null +++ b/rust/lancedb/src/remote/oauth.rs @@ -0,0 +1,714 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, 
Instant}; + +use log::{debug, warn}; +use reqwest::Client; +use serde::Deserialize; +use tokio::sync::RwLock; + +use crate::error::{Error, Result}; +use crate::remote::client::HeaderProvider; + +const DEFAULT_REFRESH_BUFFER_SECS: u64 = 300; +const DEFAULT_TOKEN_TTL_SECS: u64 = 3600; +const AZURE_IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token"; +const AZURE_IMDS_API_VERSION: &str = "2018-02-01"; + +/// OAuth authentication flow configuration. +#[derive(Debug, Clone)] +pub enum OAuthFlow { + /// Client Credentials grant (service-to-service / M2M). + /// Requires `client_secret` in [`OAuthConfig`]. + ClientCredentials, + + /// Azure Managed Identity via IMDS. + /// Works on Azure VMs, AKS, App Service, and Azure Functions. + AzureManagedIdentity { + /// Client ID for user-assigned managed identity. + /// Omit for system-assigned managed identity. + client_id: Option, + }, +} + +/// OAuth configuration for LanceDB authentication. +/// +/// All token acquisition and refresh is handled in the Rust layer. +/// Python and TypeScript bindings expose this as a plain config object. +#[derive(Debug, Clone)] +pub struct OAuthConfig { + /// OIDC issuer URL or OAuth authority URL. + /// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0` + pub issuer_url: String, + + /// Application / Client ID. + pub client_id: String, + + /// Client secret (required for `ClientCredentials`, optional for others). + pub client_secret: Option, + + /// OAuth scopes to request. + /// For Azure managed identity, exactly one scope or resource is required. + /// For example: `["api://{app_id}/.default"]` + pub scopes: Vec, + + /// Authentication flow to use. + pub flow: OAuthFlow, + + /// Seconds before token expiry to trigger proactive refresh (default: 300). 
+ pub refresh_buffer_secs: Option, +} + +// -- OIDC Discovery -- + +#[derive(Debug, Deserialize)] +struct OidcDiscovery { + token_endpoint: String, +} + +// -- Token Response -- + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + /// Token lifetime in seconds. + /// Some providers (Azure IMDS) return this as a string, so we accept both. + #[serde(default, deserialize_with = "deserialize_optional_u64_or_string")] + expires_in: Option, + #[serde(default)] + #[allow(dead_code)] + token_type: Option, +} + +fn deserialize_optional_u64_or_string<'de, D>( + deserializer: D, +) -> std::result::Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + use serde::de; + + struct U64OrString; + impl<'de> de::Visitor<'de> for U64OrString { + type Value = Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a u64, a numeric string, or null") + } + + fn visit_u64(self, v: u64) -> std::result::Result { + Ok(Some(v)) + } + + fn visit_i64(self, v: i64) -> std::result::Result { + Ok(Some(v as u64)) + } + + fn visit_str(self, v: &str) -> std::result::Result { + v.parse::().map(Some).map_err(de::Error::custom) + } + + fn visit_none(self) -> std::result::Result { + Ok(None) + } + + fn visit_unit(self) -> std::result::Result { + Ok(None) + } + } + + deserializer.deserialize_any(U64OrString) +} + +// -- Internal Token State -- + +struct TokenState { + access_token: Option, + expires_at: Option, +} + +impl TokenState { + fn new() -> Self { + Self { + access_token: None, + expires_at: None, + } + } + + fn is_expired(&self, buffer: Duration) -> bool { + match (self.access_token.as_ref(), self.expires_at) { + (Some(_), Some(expires_at)) => Instant::now() + buffer >= expires_at, + (None, _) => true, + (Some(_), None) => true, + } + } + + fn update(&mut self, resp: &TokenResponse) { + self.access_token = Some(resp.access_token.clone()); + let expires_in = 
resp.expires_in.unwrap_or(DEFAULT_TOKEN_TTL_SECS); + self.expires_at = Some(Instant::now() + Duration::from_secs(expires_in)); + } +} + +/// OAuth header provider that manages the full token lifecycle. +/// +/// Implements [`HeaderProvider`] to inject `Authorization: Bearer ` +/// headers into every LanceDB request, with automatic token refresh. +pub struct OAuthHeaderProvider { + config: OAuthConfig, + http_client: Client, + token_state: Arc>, + /// Cached OIDC discovery document + discovery: Arc>>, + refresh_buffer: Duration, +} + +impl std::fmt::Debug for OAuthHeaderProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OAuthHeaderProvider") + .field("issuer_url", &self.config.issuer_url) + .field("client_id", &self.config.client_id) + .field("flow", &self.config.flow) + .finish() + } +} + +impl OAuthHeaderProvider { + /// Create a new OAuth header provider from configuration. + pub fn new(config: OAuthConfig) -> Result { + // Validate config upfront + if matches!(config.flow, OAuthFlow::ClientCredentials) && config.client_secret.is_none() { + return Err(Error::InvalidInput { + message: "client_secret is required for ClientCredentials flow".to_string(), + }); + } + if config.scopes.is_empty() { + return Err(Error::InvalidInput { + message: "At least one OAuth scope is required".to_string(), + }); + } + if matches!(config.flow, OAuthFlow::AzureManagedIdentity { .. 
}) && config.scopes.len() != 1 + { + return Err(Error::InvalidInput { + message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource" + .to_string(), + }); + } + + let http_client = Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .map_err(|e| Error::Runtime { + message: format!("Failed to create HTTP client for OAuth: {e}"), + })?; + + let refresh_buffer = Duration::from_secs( + config + .refresh_buffer_secs + .unwrap_or(DEFAULT_REFRESH_BUFFER_SECS), + ); + + Ok(Self { + config, + http_client, + token_state: Arc::new(RwLock::new(TokenState::new())), + discovery: Arc::new(RwLock::new(None)), + refresh_buffer, + }) + } + + /// Get a valid access token, refreshing if necessary. + async fn get_valid_token(&self) -> Result { + // Fast path: check if current token is still valid + { + let state = self.token_state.read().await; + if !state.is_expired(self.refresh_buffer) + && let Some(ref token) = state.access_token + { + return Ok(token.clone()); + } + } + + // Slow path: acquire or refresh token + let mut state = self.token_state.write().await; + + // Double-check after acquiring write lock + if !state.is_expired(self.refresh_buffer) + && let Some(ref token) = state.access_token + { + return Ok(token.clone()); + } + + debug!("Acquiring new OAuth token via {:?} flow", self.config.flow); + let resp = self.acquire_token().await?; + + state.update(&resp); + Ok(resp.access_token) + } + + /// Acquire a new token using the configured flow. 
+ async fn acquire_token(&self) -> Result { + match &self.config.flow { + OAuthFlow::ClientCredentials => self.acquire_client_credentials().await, + OAuthFlow::AzureManagedIdentity { client_id } => { + self.acquire_managed_identity(client_id.as_deref()).await + } + } + } + + // -- OIDC Discovery -- + + async fn get_discovery(&self) -> Result { + { + let cached = self.discovery.read().await; + if let Some(ref disc) = *cached { + return Ok(OidcDiscovery { + token_endpoint: disc.token_endpoint.clone(), + }); + } + } + + let mut cache = self.discovery.write().await; + // Double-check + if let Some(ref disc) = *cache { + return Ok(OidcDiscovery { + token_endpoint: disc.token_endpoint.clone(), + }); + } + + let discovery_url = format!( + "{}/.well-known/openid-configuration", + self.config.issuer_url.trim_end_matches('/') + ); + + debug!("Fetching OIDC discovery from {}", discovery_url); + + let resp = self + .http_client + .get(&discovery_url) + .send() + .await + .map_err(|e| Error::Runtime { + message: format!("Failed to fetch OIDC discovery document: {e}"), + })?; + + if !resp.status().is_success() { + return Err(Error::Runtime { + message: format!( + "OIDC discovery failed with status {}: {}", + resp.status(), + resp.text().await.unwrap_or_default() + ), + }); + } + + let disc: OidcDiscovery = resp.json().await.map_err(|e| Error::Runtime { + message: format!("Failed to parse OIDC discovery document: {e}"), + })?; + + let result = OidcDiscovery { + token_endpoint: disc.token_endpoint.clone(), + }; + + *cache = Some(disc); + Ok(result) + } + + fn get_token_endpoint_from_issuer(&self) -> String { + // Derive Azure v2.0 token endpoint from issuer URL + // issuer: https://login.microsoftonline.com/{tenant}/v2.0 + // token: https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token + let base = self.config.issuer_url.trim_end_matches("/v2.0"); + format!("{base}/oauth2/v2.0/token") + } + + async fn get_token_endpoint(&self) -> Result { + match self.get_discovery().await 
{ + Ok(disc) => Ok(disc.token_endpoint), + Err(e) => { + warn!("OIDC discovery failed, falling back to derived endpoint: {e}"); + Ok(self.get_token_endpoint_from_issuer()) + } + } + } + + fn scopes_string(&self) -> String { + self.config.scopes.join(" ") + } + + fn managed_identity_resource(&self) -> Result { + let [scope] = self.config.scopes.as_slice() else { + return Err(Error::InvalidInput { + message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource" + .to_string(), + }); + }; + + Ok(scope.strip_suffix("/.default").unwrap_or(scope).to_string()) + } + + // -- Client Credentials Flow -- + + async fn acquire_client_credentials(&self) -> Result { + let client_secret = self + .config + .client_secret + .as_ref() + .ok_or(Error::InvalidInput { + message: "client_secret is required for ClientCredentials flow".to_string(), + })?; + + let token_endpoint = self.get_token_endpoint().await?; + + let params = [ + ("grant_type", "client_credentials"), + ("client_id", &self.config.client_id), + ("client_secret", client_secret), + ("scope", &self.scopes_string()), + ]; + + self.post_token_request(&token_endpoint, ¶ms).await + } + + // -- Azure Managed Identity Flow -- + + async fn acquire_managed_identity(&self, mi_client_id: Option<&str>) -> Result { + let resource = self.managed_identity_resource()?; + + let mut url = format!( + "{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}", + urlencoding::encode(&resource), + ); + if let Some(cid) = mi_client_id { + url.push_str(&format!("&client_id={}", urlencoding::encode(cid))); + } + + let resp = self + .http_client + .get(&url) + .header("Metadata", "true") + .send() + .await + .map_err(|e| Error::Runtime { + message: format!("Azure IMDS request failed: {e}"), + })?; + + if !resp.status().is_success() { + return Err(Error::Runtime { + message: format!( + "Azure IMDS returned status {}: {}", + resp.status(), + resp.text().await.unwrap_or_default() + ), + }); + } + + 
resp.json().await.map_err(|e| Error::Runtime { + message: format!("Failed to parse IMDS token response: {e}"), + }) + } + + // -- Shared Helpers -- + + async fn post_token_request( + &self, + endpoint: &str, + params: &[(&str, &str)], + ) -> Result { + let resp = self + .http_client + .post(endpoint) + .form(params) + .send() + .await + .map_err(|e| Error::Runtime { + message: format!("Token request to {endpoint} failed: {e}"), + })?; + + if !resp.status().is_success() { + return Err(Error::Runtime { + message: format!( + "Token request failed with status {}: {}", + resp.status(), + resp.text().await.unwrap_or_default() + ), + }); + } + + resp.json().await.map_err(|e| Error::Runtime { + message: format!("Failed to parse token response: {e}"), + }) + } +} + +#[async_trait::async_trait] +impl HeaderProvider for OAuthHeaderProvider { + async fn get_headers(&self) -> Result> { + let token = self.get_valid_token().await?; + Ok(HashMap::from([( + "authorization".to_string(), + format!("Bearer {token}"), + )])) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::{TcpListener, TcpStream}; + use tokio::task::JoinHandle; + + #[test] + fn test_token_state_expiry() { + let mut state = TokenState::new(); + assert!(state.is_expired(Duration::from_secs(0))); + + state.access_token = Some("tok".to_string()); + state.expires_at = Some(Instant::now() + Duration::from_secs(600)); + assert!(!state.is_expired(Duration::from_secs(300))); + assert!(state.is_expired(Duration::from_secs(601))); + + state.expires_at = None; + assert!(state.is_expired(Duration::from_secs(0))); + } + + #[test] + fn test_token_state_uses_default_expiry() { + let mut state = TokenState::new(); + let response = TokenResponse { + access_token: "tok".to_string(), + expires_in: None, + token_type: None, + }; + + state.update(&response); + + 
assert!(!state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS - 1))); + assert!(state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS + 1))); + } + + #[test] + fn test_scopes_string() { + let config = OAuthConfig { + issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), + client_id: "app-id".to_string(), + client_secret: Some("secret".to_string()), + scopes: vec!["scope1".to_string(), "scope2".to_string()], + flow: OAuthFlow::ClientCredentials, + refresh_buffer_secs: None, + }; + let provider = OAuthHeaderProvider::new(config).unwrap(); + assert_eq!(provider.scopes_string(), "scope1 scope2"); + } + + #[test] + fn test_managed_identity_resource_from_default_scope() { + let config = OAuthConfig { + issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), + client_id: "app-id".to_string(), + client_secret: None, + scopes: vec!["api://test/.default".to_string()], + flow: OAuthFlow::AzureManagedIdentity { client_id: None }, + refresh_buffer_secs: None, + }; + let provider = OAuthHeaderProvider::new(config).unwrap(); + assert_eq!(provider.managed_identity_resource().unwrap(), "api://test"); + } + + #[test] + fn test_managed_identity_resource_without_default_suffix() { + let config = OAuthConfig { + issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), + client_id: "app-id".to_string(), + client_secret: None, + scopes: vec!["api://test".to_string()], + flow: OAuthFlow::AzureManagedIdentity { client_id: None }, + refresh_buffer_secs: None, + }; + let provider = OAuthHeaderProvider::new(config).unwrap(); + assert_eq!(provider.managed_identity_resource().unwrap(), "api://test"); + } + + #[test] + fn test_managed_identity_rejects_multiple_scopes() { + let config = OAuthConfig { + issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), + client_id: "app-id".to_string(), + client_secret: None, + scopes: vec![ + "api://test-a/.default".to_string(), + "api://test-b/.default".to_string(), + ], + 
flow: OAuthFlow::AzureManagedIdentity { client_id: None }, + refresh_buffer_secs: None, + }; + assert!(OAuthHeaderProvider::new(config).is_err()); + } + + #[test] + fn test_token_endpoint_derivation() { + let config = OAuthConfig { + issuer_url: "https://login.microsoftonline.com/my-tenant/v2.0".to_string(), + client_id: "id".to_string(), + client_secret: None, + scopes: vec!["api://test/.default".to_string()], + flow: OAuthFlow::AzureManagedIdentity { client_id: None }, + refresh_buffer_secs: None, + }; + let provider = OAuthHeaderProvider::new(config).unwrap(); + assert_eq!( + provider.get_token_endpoint_from_issuer(), + "https://login.microsoftonline.com/my-tenant/oauth2/v2.0/token" + ); + } + + #[test] + fn test_client_credentials_requires_secret() { + let config = OAuthConfig { + issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), + client_id: "app-id".to_string(), + client_secret: None, + scopes: vec!["scope".to_string()], + flow: OAuthFlow::ClientCredentials, + refresh_buffer_secs: None, + }; + assert!(OAuthHeaderProvider::new(config).is_err()); + } + + #[test] + fn test_empty_scopes_rejected() { + let config = OAuthConfig { + issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), + client_id: "app-id".to_string(), + client_secret: None, + scopes: vec![], + flow: OAuthFlow::AzureManagedIdentity { client_id: None }, + refresh_buffer_secs: None, + }; + assert!(OAuthHeaderProvider::new(config).is_err()); + } + + #[tokio::test] + async fn test_client_credentials_token_lifecycle() { + let (issuer_url, token_requests, server) = spawn_oauth_server().await; + let config = OAuthConfig { + issuer_url, + client_id: "client-id".to_string(), + client_secret: Some("secret".to_string()), + scopes: vec!["scope".to_string()], + flow: OAuthFlow::ClientCredentials, + refresh_buffer_secs: Some(0), + }; + let provider = OAuthHeaderProvider::new(config).unwrap(); + + let headers = provider.get_headers().await.unwrap(); + 
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1"); + assert_eq!(token_requests.load(Ordering::SeqCst), 1); + + let headers = provider.get_headers().await.unwrap(); + assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1"); + assert_eq!(token_requests.load(Ordering::SeqCst), 1); + + provider.token_state.write().await.expires_at = + Some(Instant::now() - Duration::from_secs(1)); + + let headers = provider.get_headers().await.unwrap(); + assert_eq!(headers.get("authorization").unwrap(), "Bearer token-2"); + assert_eq!(token_requests.load(Ordering::SeqCst), 2); + + server.await.unwrap(); + } + + async fn spawn_oauth_server() -> (String, Arc, JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let issuer_url = format!("http://{addr}"); + let token_requests = Arc::new(AtomicUsize::new(0)); + let server_token_requests = Arc::clone(&token_requests); + + let server = tokio::spawn(async move { + for _ in 0..3 { + let (mut stream, _) = listener.accept().await.unwrap(); + let (request_line, body) = read_http_request(&mut stream).await; + + if request_line.starts_with("GET /.well-known/openid-configuration ") { + let discovery = format!(r#"{{"token_endpoint":"http://{addr}/token"}}"#); + write_json_response(&mut stream, "200 OK", &discovery).await; + } else if request_line.starts_with("POST /token ") { + assert!(body.contains("grant_type=client_credentials")); + assert!(body.contains("client_id=client-id")); + assert!(body.contains("client_secret=secret")); + assert!(body.contains("scope=scope")); + + let token_num = server_token_requests.fetch_add(1, Ordering::SeqCst) + 1; + let token = format!( + r#"{{"access_token":"token-{token_num}","expires_in":3600,"token_type":"Bearer"}}"# + ); + write_json_response(&mut stream, "200 OK", &token).await; + } else { + write_json_response(&mut stream, "404 Not Found", "{}").await; + } + } + }); + + (issuer_url, token_requests, 
server) + } + + async fn read_http_request(stream: &mut TcpStream) -> (String, String) { + let mut buffer = Vec::new(); + let mut header_end = None; + + while header_end.is_none() { + let mut chunk = [0; 1024]; + let read = stream.read(&mut chunk).await.unwrap(); + assert_ne!(read, 0, "connection closed before request headers"); + buffer.extend_from_slice(&chunk[..read]); + header_end = find_subsequence(&buffer, b"\r\n\r\n").map(|pos| pos + 4); + } + + let header_end = header_end.unwrap(); + let headers = String::from_utf8_lossy(&buffer[..header_end]).to_string(); + let request_line = headers.lines().next().unwrap_or_default().to_string(); + let content_length = headers + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::().ok()) + .flatten() + }) + .unwrap_or(0); + + while buffer.len() < header_end + content_length { + let mut chunk = [0; 1024]; + let read = stream.read(&mut chunk).await.unwrap(); + assert_ne!(read, 0, "connection closed before request body"); + buffer.extend_from_slice(&chunk[..read]); + } + + let body = + String::from_utf8_lossy(&buffer[header_end..header_end + content_length]).to_string(); + + (request_line, body) + } + + fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option { + haystack + .windows(needle.len()) + .position(|window| window == needle) + } + + async fn write_json_response(stream: &mut TcpStream, status: &str, body: &str) { + let response = format!( + "HTTP/1.1 {status}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ); + stream.write_all(response.as_bytes()).await.unwrap(); + } +}