diff --git a/Cargo.lock b/Cargo.lock index 6c6da7cb6..187f3c72c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,7 +108,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -119,7 +119,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -1648,7 +1648,7 @@ version = "3.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -2826,7 +2826,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -3087,7 +3087,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -3996,7 +3996,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.5.10", + "socket2 0.6.3", "system-configuration", "tokio", "tower-service", @@ -4233,6 +4233,25 @@ dependencies = [ "serde", ] +[[package]] +name = "is-docker" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3" +dependencies = [ + "once_cell", +] + +[[package]] +name = "is-wsl" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5" +dependencies = [ + "is-docker", + "once_cell", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.2" @@ 
-4323,7 +4342,7 @@ dependencies = [ "portable-atomic-util", "serde_core", "wasm-bindgen", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -5025,6 +5044,7 @@ dependencies = [ "aws-sdk-kms", "aws-sdk-s3", "aws-smithy-runtime", + "base64 0.22.1", "bytes", "candle-core", "candle-nn", @@ -5063,6 +5083,7 @@ dependencies = [ "moka", "num-traits", "object_store", + "open", "pin-project", "polars", "polars-arrow", @@ -5075,12 +5096,14 @@ dependencies = [ "serde", "serde_json", "serde_with", + "sha2", "snafu 0.8.9", "tempfile", "test-log", "tokenizers", "tokio", "url", + "urlencoding", "uuid", "walkdir", ] @@ -5804,7 +5827,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -6029,6 +6052,17 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "open" +version = "5.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fbaa89d2ddc8473c78a3adf69eea8cffa28c483b8e02a971ef31527cd0fc92c" +dependencies = [ + "is-wsl", + "libc", + "pathdiff", +] + [[package]] name = "opendal" version = "0.56.0" @@ -6365,6 +6399,12 @@ dependencies = [ "stfu8", ] +[[package]] +name = "pathdiff" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" + [[package]] name = "pbkdf2" version = "0.12.2" @@ -7046,8 +7086,8 @@ version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" dependencies = [ - "heck 0.4.1", - "itertools 0.11.0", + "heck 0.5.0", + "itertools 0.14.0", "log", "multimap", "petgraph", @@ -7066,7 +7106,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", - "itertools 0.11.0", + "itertools 0.14.0", "proc-macro2", "quote", "syn 2.0.117", @@ -7231,7 +7271,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls 0.23.37", - "socket2 0.5.10", + "socket2 0.6.3", "thiserror 2.0.18", "tokio", "tracing", @@ -7269,7 +7309,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.10", + "socket2 0.6.3", "tracing", "windows-sys 0.60.2", ] @@ -7970,7 +8010,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -8041,7 +8081,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -8549,7 +8589,7 @@ version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451" dependencies = [ - "heck 0.4.1", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.117", @@ -8561,7 +8601,7 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40" dependencies = [ - "heck 0.4.1", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.117", @@ -8584,7 +8624,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -8964,7 +9004,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -9893,7 +9933,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] diff 
--git a/docs/src/js/globals.md b/docs/src/js/globals.md index 988f35dc6..fc87fe381 100644 --- a/docs/src/js/globals.md +++ b/docs/src/js/globals.md @@ -12,6 +12,7 @@ ## Enumerations - [FullTextQueryType](enumerations/FullTextQueryType.md) +- [OAuthFlowType](enumerations/OAuthFlowType.md) - [Occur](enumerations/Occur.md) - [Operator](enumerations/Operator.md) @@ -70,6 +71,8 @@ - [IvfPqOptions](interfaces/IvfPqOptions.md) - [IvfRqOptions](interfaces/IvfRqOptions.md) - [MergeResult](interfaces/MergeResult.md) +- [NativeOAuthConfig](interfaces/NativeOAuthConfig.md) +- [OAuthConfig](interfaces/OAuthConfig.md) - [OpenTableOptions](interfaces/OpenTableOptions.md) - [OptimizeOptions](interfaces/OptimizeOptions.md) - [OptimizeStats](interfaces/OptimizeStats.md) diff --git a/docs/src/js/interfaces/ConnectionOptions.md b/docs/src/js/interfaces/ConnectionOptions.md index 1ad0e127a..77cd2f812 100644 --- a/docs/src/js/interfaces/ConnectionOptions.md +++ b/docs/src/js/interfaces/ConnectionOptions.md @@ -64,6 +64,18 @@ client used by manifest-enabled native connections. *** +### oauthConfig? + +```ts +optional oauthConfig: NativeOAuthConfig; +``` + +(For LanceDB cloud only): OAuth configuration for IdP-based +authentication (e.g., Azure Entra ID). When set, token acquisition +and refresh are handled entirely in Rust. + +*** + ### readConsistencyInterval? 
/**
 * OAuth authentication flow types.
 *
 * Each member's string value is the flow identifier that the native (Rust)
 * binding matches on when it converts the config, so the values must stay in
 * sync with the identifiers handled there.
 */
export enum OAuthFlowType {
  /** Client Credentials grant (service-to-service / M2M). */
  ClientCredentials = "client_credentials",
  /** Authorization Code with PKCE (interactive browser-based auth). */
  AuthorizationCodePKCE = "authorization_code_pkce",
  /** Device Code grant (CLI / headless environments). */
  DeviceCode = "device_code",
  /** Azure Managed Identity via IMDS. */
  AzureManagedIdentity = "azure_managed_identity",
  /** Workload Identity Federation (K8s, GitHub Actions). */
  WorkloadIdentity = "workload_identity",
}

/**
 * OAuth configuration for LanceDB authentication.
 *
 * All token acquisition and refresh is handled in the Rust layer.
 * This config is passed through to Rust via napi-rs.
 *
 * @example Client Credentials (service-to-service):
 * ```typescript
 * const config: OAuthConfig = {
 *   issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
 *   clientId: "app-id",
 *   clientSecret: "secret",
 *   scopes: ["api://lancedb-api/.default"],
 * };
 * ```
 *
 * @example Azure Managed Identity:
 * ```typescript
 * const config: OAuthConfig = {
 *   issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
 *   clientId: "app-id",
 *   scopes: ["api://lancedb-api/.default"],
 *   flow: OAuthFlowType.AzureManagedIdentity,
 * };
 * ```
 */
export interface OAuthConfig {
  /**
   * OIDC issuer URL or OAuth authority URL.
   * For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
   */
  issuerUrl: string;

  /** Application / Client ID. */
  clientId: string;

  /**
   * OAuth scopes to request. The Rust layer rejects an empty list.
   * For Azure: `["api://{app_id}/.default"]`
   */
  scopes: string[];

  /** Authentication flow (default: ClientCredentials). */
  flow?: OAuthFlowType;

  /**
   * Client secret (required for ClientCredentials; the Rust layer returns an
   * error when that flow is used without it).
   */
  clientSecret?: string;

  /** Redirect URI (AuthorizationCodePKCE flow). */
  redirectUri?: string;

  /** Port for local callback server (AuthorizationCodePKCE, default: 8400). */
  callbackPort?: number;

  /** Client ID for user-assigned managed identity (AzureManagedIdentity). */
  managedIdentityClientId?: string;

  /**
   * Path to federated token file (WorkloadIdentity).
   * Must be set when `flow` is WorkloadIdentity: the native conversion does
   * not return an error for a missing value, it aborts with a panic.
   */
  tokenFile?: string;

  /** Seconds before expiry to trigger proactive refresh (default: 300). */
  refreshBufferSecs?: number;
}
+ pub client_secret: Option, + /// Redirect URI (authorization_code_pkce flow). + pub redirect_uri: Option, + /// Port for local callback server (authorization_code_pkce, default: 8400). + pub callback_port: Option, + /// Client ID for user-assigned managed identity (azure_managed_identity). + pub managed_identity_client_id: Option, + /// Path to federated token file (workload_identity). + pub token_file: Option, + /// Seconds before expiry to trigger proactive refresh (default: 300). + pub refresh_buffer_secs: Option, +} + +impl From for lancedb::remote::oauth::OAuthConfig { + fn from(config: OAuthConfig) -> Self { + use lancedb::remote::oauth::OAuthFlow; + + let flow = match config.flow.as_deref().unwrap_or("client_credentials") { + "authorization_code_pkce" => OAuthFlow::AuthorizationCodePKCE { + redirect_uri: config.redirect_uri, + callback_port: config.callback_port, + }, + "device_code" => OAuthFlow::DeviceCode, + "azure_managed_identity" => OAuthFlow::AzureManagedIdentity { + client_id: config.managed_identity_client_id, + }, + "workload_identity" => OAuthFlow::WorkloadIdentity { + token_file: config + .token_file + .expect("tokenFile is required for workload_identity flow"), + }, + other => panic!("Unknown OAuth flow type: {other}"), + }; + + Self { + issuer_url: config.issuer_url, + client_id: config.client_id, + client_secret: config.client_secret, + scopes: config.scopes, + flow, + refresh_buffer_secs: config.refresh_buffer_secs.map(|v| v as u64), + } + } +} + impl From for lancedb::remote::ClientConfig { fn from(config: ClientConfig) -> Self { Self { diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index efeed258f..89399b16f 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -320,6 +320,7 @@ async def connect_async( session: Optional[Session] = None, manifest_enabled: bool = False, namespace_client_properties: Optional[Dict[str, str]] = None, + oauth_config=None, ) -> 
class OAuthFlowType(str, Enum):
    """OAuth authentication flow types."""

    CLIENT_CREDENTIALS = "client_credentials"
    """Client Credentials grant (service-to-service / M2M)."""

    AUTHORIZATION_CODE_PKCE = "authorization_code_pkce"
    """Authorization Code with PKCE (interactive browser-based auth)."""

    DEVICE_CODE = "device_code"
    """Device Code grant (CLI / headless environments)."""

    AZURE_MANAGED_IDENTITY = "azure_managed_identity"
    """Azure Managed Identity via IMDS."""

    WORKLOAD_IDENTITY = "workload_identity"
    """Workload Identity Federation (K8s, GitHub Actions)."""


@dataclass
class OAuthConfig:
    """OAuth configuration for LanceDB authentication.

    All token acquisition and refresh is handled in the Rust layer.
    This config is passed through to Rust via PyO3.

    Parameters
    ----------
    issuer_url : str
        OIDC issuer URL or OAuth authority URL.
        For Azure: ``https://login.microsoftonline.com/{tenant_id}/v2.0``
    client_id : str
        Application / Client ID.
    scopes : List[str]
        OAuth scopes to request.
        For Azure: ``["api://{app_id}/.default"]``
    flow : OAuthFlowType
        Authentication flow to use. Default: CLIENT_CREDENTIALS.
        A plain string holding one of the flow values is accepted and
        converted to the matching ``OAuthFlowType`` member.
    client_secret : Optional[str]
        Client secret (required for CLIENT_CREDENTIALS).
    redirect_uri : Optional[str]
        Redirect URI for AUTHORIZATION_CODE_PKCE flow.
    callback_port : Optional[int]
        Port for local HTTP callback server (AUTHORIZATION_CODE_PKCE, default: 8400).
    managed_identity_client_id : Optional[str]
        Client ID for user-assigned managed identity (AZURE_MANAGED_IDENTITY).
    token_file : Optional[str]
        Path to federated token file (required for WORKLOAD_IDENTITY).
    refresh_buffer_secs : Optional[int]
        Seconds before expiry to trigger proactive refresh (default: 300).

    Raises
    ------
    ValueError
        If ``flow`` is not a known flow type, or if ``flow`` is
        WORKLOAD_IDENTITY and ``token_file`` is not set.

    Examples
    --------
    Client Credentials (service-to-service):

    >>> config = OAuthConfig(
    ...     issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
    ...     client_id="app-id",
    ...     client_secret="secret",
    ...     scopes=["api://lancedb-api/.default"],
    ... )

    Azure Managed Identity:

    >>> config = OAuthConfig(
    ...     issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
    ...     client_id="app-id",
    ...     scopes=["api://lancedb-api/.default"],
    ...     flow=OAuthFlowType.AZURE_MANAGED_IDENTITY,
    ... )
    """

    issuer_url: str
    client_id: str
    scopes: List[str]
    flow: OAuthFlowType = OAuthFlowType.CLIENT_CREDENTIALS
    client_secret: Optional[str] = None
    redirect_uri: Optional[str] = None
    callback_port: Optional[int] = None
    managed_identity_client_id: Optional[str] = None
    token_file: Optional[str] = None
    refresh_buffer_secs: Optional[int] = None

    def __post_init__(self) -> None:
        # The Rust conversion panics (rather than raising a normal Python
        # exception) on an unknown flow string or on a workload-identity
        # config without a token file, so reject both here with ValueError.
        # OAuthFlowType(...) returns the member unchanged when given a member,
        # maps a matching plain string to its member, and raises ValueError
        # for anything else.
        self.flow = OAuthFlowType(self.flow)
        if self.flow is OAuthFlowType.WORKLOAD_IDENTITY and not self.token_file:
            raise ValueError("token_file is required for the WORKLOAD_IDENTITY flow")
mod runtime; diff --git a/python/src/oauth.rs b/python/src/oauth.rs new file mode 100644 index 000000000..ad11f6951 --- /dev/null +++ b/python/src/oauth.rs @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use pyo3::FromPyObject; + +use lancedb::remote::oauth::{OAuthConfig, OAuthFlow}; + +/// Python-side OAuth configuration, extracted via FromPyObject. +/// Maps to `lancedb.remote.oauth.OAuthConfig` Python dataclass. +#[derive(FromPyObject)] +pub struct PyOAuthConfig { + pub issuer_url: String, + pub client_id: String, + pub scopes: Vec, + pub flow: String, + pub client_secret: Option, + pub redirect_uri: Option, + pub callback_port: Option, + pub managed_identity_client_id: Option, + pub token_file: Option, + pub refresh_buffer_secs: Option, +} + +impl From for OAuthConfig { + fn from(py: PyOAuthConfig) -> Self { + let flow = match py.flow.as_str() { + "client_credentials" => OAuthFlow::ClientCredentials, + "authorization_code_pkce" => OAuthFlow::AuthorizationCodePKCE { + redirect_uri: py.redirect_uri, + callback_port: py.callback_port, + }, + "device_code" => OAuthFlow::DeviceCode, + "azure_managed_identity" => OAuthFlow::AzureManagedIdentity { + client_id: py.managed_identity_client_id, + }, + "workload_identity" => OAuthFlow::WorkloadIdentity { + token_file: py + .token_file + .expect("token_file is required for workload_identity flow"), + }, + other => panic!("Unknown OAuth flow type: {other}"), + }; + + OAuthConfig { + issuer_url: py.issuer_url, + client_id: py.client_id, + client_secret: py.client_secret, + scopes: py.scopes, + flow, + refresh_buffer_secs: py.refresh_buffer_secs, + } + } +} diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index b05302b9b..8caf85b96 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -75,6 +75,11 @@ reqwest = { version = "0.12.0", default-features = false, features = [ "stream", ], optional = true } http = { version = "1", 
    /// Configure OAuth authentication for LanceDB Cloud/Enterprise.
    ///
    /// This call only stores the config on the builder. When the remote
    /// connection is built, an
    /// [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider) is created
    /// from it and installed as the client's header provider, replacing any
    /// header provider previously set through the client config.
    ///
    /// With an OAuth config present an API key is no longer required. A key
    /// that has been set is still passed on to the remote client; when none is
    /// set, an empty key is passed instead.
    ///
    /// Token acquisition and refresh are handled entirely in Rust.
    #[cfg(feature = "remote")]
    pub fn oauth_config(mut self, config: crate::remote::oauth::OAuthConfig) -> Self {
        self.oauth_config = Some(config);
        self
    }
/// OAuth authentication flow configuration.
///
/// Selects how [`OAuthHeaderProvider`] obtains an access token. Only the
/// interactive flows (`AuthorizationCodePKCE`, `DeviceCode`) renew through a
/// stored refresh token when one is available; the other flows always request
/// a fresh token.
#[derive(Debug, Clone)]
pub enum OAuthFlow {
    /// Client Credentials grant (service-to-service / M2M).
    /// Requires `client_secret` in [`OAuthConfig`].
    ClientCredentials,

    /// Authorization Code with PKCE (interactive browser-based auth).
    AuthorizationCodePKCE {
        /// Redirect URI (default: `http://localhost:{callback_port}/callback`)
        redirect_uri: Option<String>,
        /// Port for the local HTTP callback server (default: 8400)
        callback_port: Option<u16>,
    },

    /// Device Code grant (CLI / headless environments).
    /// Displays a verification URL and user code for out-of-band authentication.
    DeviceCode,

    /// Azure Managed Identity via IMDS.
    /// Works on Azure VMs, AKS, App Service, and Azure Functions.
    AzureManagedIdentity {
        /// Client ID for user-assigned managed identity.
        /// Omit for system-assigned managed identity.
        client_id: Option<String>,
    },

    /// Workload Identity Federation.
    /// Exchanges a platform token (K8s service account, GitHub OIDC) for an IdP token.
    WorkloadIdentity {
        /// Path to the federated token file
        /// (e.g. `AZURE_FEDERATED_TOKEN_FILE`).
        token_file: String,
    },
}
+ #[serde(default, deserialize_with = "deserialize_optional_u64_or_string")] + expires_in: Option, + #[serde(default)] + #[allow(dead_code)] + token_type: Option, +} + +fn deserialize_optional_u64_or_string<'de, D>( + deserializer: D, +) -> std::result::Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + use serde::de; + + struct U64OrString; + impl<'de> de::Visitor<'de> for U64OrString { + type Value = Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a u64, a numeric string, or null") + } + + fn visit_u64(self, v: u64) -> std::result::Result { + Ok(Some(v)) + } + + fn visit_i64(self, v: i64) -> std::result::Result { + Ok(Some(v as u64)) + } + + fn visit_str(self, v: &str) -> std::result::Result { + v.parse::().map(Some).map_err(de::Error::custom) + } + + fn visit_none(self) -> std::result::Result { + Ok(None) + } + + fn visit_unit(self) -> std::result::Result { + Ok(None) + } + } + + deserializer.deserialize_any(U64OrString) +} + +// -- Device Code Response -- + +#[derive(Debug, Deserialize)] +struct DeviceCodeResponse { + device_code: String, + user_code: String, + verification_uri: String, + #[serde(default)] + verification_uri_complete: Option, + expires_in: u64, + interval: Option, +} + +// -- Internal Token State -- + +struct TokenState { + access_token: Option, + refresh_token: Option, + expires_at: Option, +} + +impl TokenState { + fn new() -> Self { + Self { + access_token: None, + refresh_token: None, + expires_at: None, + } + } + + fn is_expired(&self, buffer: Duration) -> bool { + match (self.access_token.as_ref(), self.expires_at) { + (Some(_), Some(expires_at)) => Instant::now() + buffer >= expires_at, + (None, _) => true, + (Some(_), None) => false, // no expiry info, assume valid + } + } + + fn update(&mut self, resp: &TokenResponse) { + self.access_token = Some(resp.access_token.clone()); + if resp.refresh_token.is_some() { + self.refresh_token = resp.refresh_token.clone(); + 
} + self.expires_at = resp + .expires_in + .map(|secs| Instant::now() + Duration::from_secs(secs)); + } +} + +/// OAuth header provider that manages the full token lifecycle. +/// +/// Implements [`HeaderProvider`] to inject `Authorization: Bearer ` +/// headers into every LanceDB request, with automatic token refresh. +pub struct OAuthHeaderProvider { + config: OAuthConfig, + http_client: Client, + token_state: Arc>, + /// Cached OIDC discovery document + discovery: Arc>>, + refresh_buffer: Duration, +} + +impl std::fmt::Debug for OAuthHeaderProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OAuthHeaderProvider") + .field("issuer_url", &self.config.issuer_url) + .field("client_id", &self.config.client_id) + .field("flow", &self.config.flow) + .finish() + } +} + +impl OAuthHeaderProvider { + /// Create a new OAuth header provider from configuration. + pub fn new(config: OAuthConfig) -> Result { + // Validate config upfront + if matches!(config.flow, OAuthFlow::ClientCredentials) && config.client_secret.is_none() { + return Err(Error::InvalidInput { + message: "client_secret is required for ClientCredentials flow".to_string(), + }); + } + if config.scopes.is_empty() { + return Err(Error::InvalidInput { + message: "At least one OAuth scope is required".to_string(), + }); + } + + let http_client = Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .map_err(|e| Error::Runtime { + message: format!("Failed to create HTTP client for OAuth: {e}"), + })?; + + let refresh_buffer = Duration::from_secs( + config + .refresh_buffer_secs + .unwrap_or(DEFAULT_REFRESH_BUFFER_SECS), + ); + + Ok(Self { + config, + http_client, + token_state: Arc::new(RwLock::new(TokenState::new())), + discovery: Arc::new(RwLock::new(None)), + refresh_buffer, + }) + } + + /// Get a valid access token, refreshing if necessary. 
+ async fn get_valid_token(&self) -> Result { + // Fast path: check if current token is still valid + { + let state = self.token_state.read().await; + if !state.is_expired(self.refresh_buffer) + && let Some(ref token) = state.access_token + { + return Ok(token.clone()); + } + } + + // Slow path: acquire or refresh token + let mut state = self.token_state.write().await; + + // Double-check after acquiring write lock + if !state.is_expired(self.refresh_buffer) + && let Some(ref token) = state.access_token + { + return Ok(token.clone()); + } + + let uses_refresh_token = !matches!( + self.config.flow, + OAuthFlow::ClientCredentials + | OAuthFlow::AzureManagedIdentity { .. } + | OAuthFlow::WorkloadIdentity { .. } + ); + + let resp = if let Some(ref refresh_token) = state.refresh_token + && uses_refresh_token + { + debug!("Refreshing OAuth token using refresh_token"); + self.refresh_with_token(refresh_token).await? + } else { + debug!("Acquiring new OAuth token via {:?} flow", self.config.flow); + self.acquire_token().await? + }; + + state.update(&resp); + Ok(resp.access_token) + } + + /// Acquire a new token using the configured flow. 
+ async fn acquire_token(&self) -> Result { + match &self.config.flow { + OAuthFlow::ClientCredentials => self.acquire_client_credentials().await, + OAuthFlow::AuthorizationCodePKCE { + redirect_uri, + callback_port, + } => { + self.acquire_authorization_code_pkce( + redirect_uri.as_deref(), + callback_port.unwrap_or(DEFAULT_CALLBACK_PORT), + ) + .await + } + OAuthFlow::DeviceCode => self.acquire_device_code().await, + OAuthFlow::AzureManagedIdentity { client_id } => { + self.acquire_managed_identity(client_id.as_deref()).await + } + OAuthFlow::WorkloadIdentity { token_file } => { + self.acquire_workload_identity(token_file).await + } + } + } + + // -- OIDC Discovery -- + + async fn get_discovery(&self) -> Result { + { + let cached = self.discovery.read().await; + if let Some(ref disc) = *cached { + return Ok(OidcDiscovery { + token_endpoint: disc.token_endpoint.clone(), + authorization_endpoint: disc.authorization_endpoint.clone(), + device_authorization_endpoint: disc.device_authorization_endpoint.clone(), + }); + } + } + + let mut cache = self.discovery.write().await; + // Double-check + if let Some(ref disc) = *cache { + return Ok(OidcDiscovery { + token_endpoint: disc.token_endpoint.clone(), + authorization_endpoint: disc.authorization_endpoint.clone(), + device_authorization_endpoint: disc.device_authorization_endpoint.clone(), + }); + } + + let discovery_url = format!( + "{}/.well-known/openid-configuration", + self.config.issuer_url.trim_end_matches('/') + ); + + debug!("Fetching OIDC discovery from {}", discovery_url); + + let resp = self + .http_client + .get(&discovery_url) + .send() + .await + .map_err(|e| Error::Runtime { + message: format!("Failed to fetch OIDC discovery document: {e}"), + })?; + + if !resp.status().is_success() { + return Err(Error::Runtime { + message: format!( + "OIDC discovery failed with status {}: {}", + resp.status(), + resp.text().await.unwrap_or_default() + ), + }); + } + + let disc: OidcDiscovery = 
resp.json().await.map_err(|e| Error::Runtime { + message: format!("Failed to parse OIDC discovery document: {e}"), + })?; + + let result = OidcDiscovery { + token_endpoint: disc.token_endpoint.clone(), + authorization_endpoint: disc.authorization_endpoint.clone(), + device_authorization_endpoint: disc.device_authorization_endpoint.clone(), + }; + + *cache = Some(disc); + Ok(result) + } + + fn get_token_endpoint_from_issuer(&self) -> String { + // Derive Azure v2.0 token endpoint from issuer URL + // issuer: https://login.microsoftonline.com/{tenant}/v2.0 + // token: https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token + let base = self.config.issuer_url.trim_end_matches("/v2.0"); + format!("{base}/oauth2/v2.0/token") + } + + async fn get_token_endpoint(&self) -> Result { + match self.get_discovery().await { + Ok(disc) => Ok(disc.token_endpoint), + Err(e) => { + warn!("OIDC discovery failed, falling back to derived endpoint: {e}"); + Ok(self.get_token_endpoint_from_issuer()) + } + } + } + + fn scopes_string(&self) -> String { + self.config.scopes.join(" ") + } + + // -- Client Credentials Flow -- + + async fn acquire_client_credentials(&self) -> Result { + let client_secret = self + .config + .client_secret + .as_ref() + .ok_or(Error::InvalidInput { + message: "client_secret is required for ClientCredentials flow".to_string(), + })?; + + let token_endpoint = self.get_token_endpoint().await?; + + let params = [ + ("grant_type", "client_credentials"), + ("client_id", &self.config.client_id), + ("client_secret", client_secret), + ("scope", &self.scopes_string()), + ]; + + self.post_token_request(&token_endpoint, ¶ms).await + } + + // -- Authorization Code + PKCE Flow -- + + async fn acquire_authorization_code_pkce( + &self, + redirect_uri: Option<&str>, + callback_port: u16, + ) -> Result { + use rand::Rng; + use tokio::io::AsyncWriteExt; + use tokio::net::TcpListener; + + let discovery = self.get_discovery().await?; + let auth_endpoint = 
discovery.authorization_endpoint.ok_or(Error::Runtime { + message: "OIDC discovery did not provide authorization_endpoint".to_string(), + })?; + + let default_redirect = format!("http://localhost:{callback_port}/callback"); + let redirect = redirect_uri.unwrap_or(&default_redirect); + + // Generate PKCE code verifier and challenge (S256) + const PKCE_CHARSET: &[u8] = + b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"; + let code_verifier: String = { + let mut rng = rand::rng(); + (0..128) + .map(|_| { + let idx = rng.random_range(0..PKCE_CHARSET.len()); + PKCE_CHARSET[idx] as char + }) + .collect() + }; + let code_challenge = { + use sha2::{Digest, Sha256}; + let hash = Sha256::digest(code_verifier.as_bytes()); + base64_url_encode(&hash) + }; + + let state: String = { + let mut rng = rand::rng(); + (0..32) + .map(|_| { + let idx = rng.random_range(0..16u8); + b"0123456789abcdef"[idx as usize] as char + }) + .collect() + }; + + // Build authorization URL + let auth_url = format!( + "{auth_endpoint}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={state}", + urlencoding::encode(&self.config.client_id), + urlencoding::encode(redirect), + urlencoding::encode(&self.scopes_string()), + urlencoding::encode(&code_challenge), + ); + + info!("Opening browser for OAuth login..."); + info!("If the browser doesn't open, visit: {auth_url}"); + + // Try to open browser + let _ = open::that(&auth_url); + + // Start local callback server + let listener = TcpListener::bind(format!("127.0.0.1:{callback_port}")) + .await + .map_err(|e| Error::Runtime { + message: format!("Failed to bind callback server on port {callback_port}: {e}"), + })?; + + info!("Waiting for OAuth callback on port {callback_port}..."); + + let (mut stream, _) = listener.accept().await.map_err(|e| Error::Runtime { + message: format!("Failed to accept callback connection: {e}"), + })?; + + // Read the HTTP request + let mut buf = 
vec![0u8; 4096]; + let n = tokio::io::AsyncReadExt::read(&mut stream, &mut buf) + .await + .map_err(|e| Error::Runtime { + message: format!("Failed to read callback request: {e}"), + })?; + let request_str = String::from_utf8_lossy(&buf[..n]); + + // Extract authorization code from query params + let code = extract_query_param(&request_str, "code").ok_or(Error::Runtime { + message: "No authorization code in callback".to_string(), + })?; + + let returned_state = extract_query_param(&request_str, "state"); + if returned_state.as_deref() != Some(&state) { + return Err(Error::Runtime { + message: "OAuth state mismatch — possible CSRF attack".to_string(), + }); + } + + // Send success response to browser + let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n

Authentication successful!

You can close this window.

"; + let _ = stream.write_all(response.as_bytes()).await; + + // Exchange code for tokens + let token_endpoint = self.get_token_endpoint().await?; + let mut params = vec![ + ("grant_type", "authorization_code"), + ("client_id", self.config.client_id.as_str()), + ("code", &code), + ("redirect_uri", redirect), + ("code_verifier", &code_verifier), + ]; + if let Some(ref secret) = self.config.client_secret { + params.push(("client_secret", secret)); + } + + self.post_token_request(&token_endpoint, ¶ms).await + } + + // -- Device Code Flow -- + + async fn acquire_device_code(&self) -> Result { + let discovery = self.get_discovery().await?; + let device_endpoint = discovery + .device_authorization_endpoint + .ok_or(Error::Runtime { + message: "OIDC discovery did not provide device_authorization_endpoint".to_string(), + })?; + + let params = [ + ("client_id", self.config.client_id.as_str()), + ("scope", &self.scopes_string()), + ]; + + let resp = self + .http_client + .post(&device_endpoint) + .form(¶ms) + .send() + .await + .map_err(|e| Error::Runtime { + message: format!("Device code request failed: {e}"), + })?; + + if !resp.status().is_success() { + return Err(Error::Runtime { + message: format!( + "Device code request failed with status {}: {}", + resp.status(), + resp.text().await.unwrap_or_default() + ), + }); + } + + let device_resp: DeviceCodeResponse = resp.json().await.map_err(|e| Error::Runtime { + message: format!("Failed to parse device code response: {e}"), + })?; + + // Display instructions to user + info!( + "To sign in, visit {} and enter code: {}", + device_resp.verification_uri, device_resp.user_code + ); + if let Some(ref uri) = device_resp.verification_uri_complete { + info!("Or visit: {uri}"); + } + + // Poll token endpoint + let token_endpoint = self.get_token_endpoint().await?; + let poll_interval = Duration::from_secs(device_resp.interval.unwrap_or(5)); + let deadline = Instant::now() + Duration::from_secs(device_resp.expires_in); + + loop { + if 
Instant::now() >= deadline { + return Err(Error::Runtime { + message: "Device code flow timed out waiting for user authentication" + .to_string(), + }); + } + + tokio::time::sleep(poll_interval).await; + + let poll_params = [ + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ("client_id", self.config.client_id.as_str()), + ("device_code", &device_resp.device_code), + ]; + + let poll_resp = self + .http_client + .post(&token_endpoint) + .form(&poll_params) + .send() + .await + .map_err(|e| Error::Runtime { + message: format!("Device code poll failed: {e}"), + })?; + + if poll_resp.status().is_success() { + return poll_resp.json().await.map_err(|e| Error::Runtime { + message: format!("Failed to parse token response: {e}"), + }); + } + + // Check for pending / slow_down errors + let body = poll_resp.text().await.unwrap_or_default(); + if body.contains("authorization_pending") { + continue; + } + if body.contains("slow_down") { + tokio::time::sleep(Duration::from_secs(5)).await; + continue; + } + + return Err(Error::Runtime { + message: format!("Device code poll failed: {body}"), + }); + } + } + + // -- Azure Managed Identity Flow -- + + async fn acquire_managed_identity(&self, mi_client_id: Option<&str>) -> Result { + let resource = self.scopes_string().replace("/.default", ""); + + let mut url = format!( + "{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}", + urlencoding::encode(&resource), + ); + if let Some(cid) = mi_client_id { + url.push_str(&format!("&client_id={}", urlencoding::encode(cid))); + } + + let resp = self + .http_client + .get(&url) + .header("Metadata", "true") + .send() + .await + .map_err(|e| Error::Runtime { + message: format!("Azure IMDS request failed: {e}"), + })?; + + if !resp.status().is_success() { + return Err(Error::Runtime { + message: format!( + "Azure IMDS returned status {}: {}", + resp.status(), + resp.text().await.unwrap_or_default() + ), + }); + } + + resp.json().await.map_err(|e| 
Error::Runtime { + message: format!("Failed to parse IMDS token response: {e}"), + }) + } + + // -- Workload Identity Federation Flow -- + + async fn acquire_workload_identity(&self, token_file: &str) -> Result { + let federated_token = + tokio::fs::read_to_string(token_file) + .await + .map_err(|e| Error::Runtime { + message: format!("Failed to read federated token file '{token_file}': {e}"), + })?; + + let token_endpoint = self.get_token_endpoint().await?; + + let params = [ + ("grant_type", "client_credentials"), + ("client_id", self.config.client_id.as_str()), + ( + "client_assertion_type", + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + ), + ("client_assertion", federated_token.trim()), + ("scope", &self.scopes_string()), + ]; + + self.post_token_request(&token_endpoint, ¶ms).await + } + + // -- Refresh Token Flow -- + + async fn refresh_with_token(&self, refresh_token: &str) -> Result { + let token_endpoint = self.get_token_endpoint().await?; + + let mut params = vec![ + ("grant_type", "refresh_token"), + ("client_id", self.config.client_id.as_str()), + ("refresh_token", refresh_token), + ]; + if let Some(ref secret) = self.config.client_secret { + params.push(("client_secret", secret.as_str())); + } + + self.post_token_request(&token_endpoint, ¶ms).await + } + + // -- Shared Helpers -- + + async fn post_token_request( + &self, + endpoint: &str, + params: &[(&str, &str)], + ) -> Result { + let resp = self + .http_client + .post(endpoint) + .form(params) + .send() + .await + .map_err(|e| Error::Runtime { + message: format!("Token request to {endpoint} failed: {e}"), + })?; + + if !resp.status().is_success() { + return Err(Error::Runtime { + message: format!( + "Token request failed with status {}: {}", + resp.status(), + resp.text().await.unwrap_or_default() + ), + }); + } + + resp.json().await.map_err(|e| Error::Runtime { + message: format!("Failed to parse token response: {e}"), + }) + } +} + +#[async_trait::async_trait] +impl HeaderProvider 
for OAuthHeaderProvider { + async fn get_headers(&self) -> Result> { + let token = self.get_valid_token().await?; + Ok(HashMap::from([( + "authorization".to_string(), + format!("Bearer {token}"), + )])) + } +} + +// -- Utility functions -- + +fn base64_url_encode(input: &[u8]) -> String { + use base64::Engine; + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(input) +} + +/// Extract a query parameter value from a raw HTTP GET request line. +fn extract_query_param(request: &str, param: &str) -> Option { + let first_line = request.lines().next()?; + let path = first_line.split_whitespace().nth(1)?; + let query = path.split('?').nth(1)?; + for pair in query.split('&') { + let mut kv = pair.splitn(2, '='); + if let (Some(key), Some(value)) = (kv.next(), kv.next()) + && key == param + { + return Some(urlencoding::decode(value).ok()?.into_owned()); + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_query_param() { + let request = "GET /callback?code=abc123&state=xyz HTTP/1.1\r\nHost: localhost\r\n"; + assert_eq!( + extract_query_param(request, "code"), + Some("abc123".to_string()) + ); + assert_eq!( + extract_query_param(request, "state"), + Some("xyz".to_string()) + ); + assert_eq!(extract_query_param(request, "missing"), None); + } + + #[test] + fn test_extract_query_param_encoded() { + let request = "GET /callback?code=abc%20123&state=x%26y HTTP/1.1\r\n"; + assert_eq!( + extract_query_param(request, "code"), + Some("abc 123".to_string()) + ); + assert_eq!( + extract_query_param(request, "state"), + Some("x&y".to_string()) + ); + } + + #[test] + fn test_token_state_expiry() { + let mut state = TokenState::new(); + assert!(state.is_expired(Duration::from_secs(0))); + + state.access_token = Some("tok".to_string()); + state.expires_at = Some(Instant::now() + Duration::from_secs(600)); + assert!(!state.is_expired(Duration::from_secs(300))); + assert!(state.is_expired(Duration::from_secs(601))); + } + + #[test] + fn 
test_base64_url_encode() { + let input = b"hello world"; + let encoded = base64_url_encode(input); + assert!(!encoded.contains('+')); + assert!(!encoded.contains('/')); + assert!(!encoded.contains('=')); + } + + #[test] + fn test_scopes_string() { + let config = OAuthConfig { + issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), + client_id: "app-id".to_string(), + client_secret: Some("secret".to_string()), + scopes: vec!["scope1".to_string(), "scope2".to_string()], + flow: OAuthFlow::ClientCredentials, + refresh_buffer_secs: None, + }; + let provider = OAuthHeaderProvider::new(config).unwrap(); + assert_eq!(provider.scopes_string(), "scope1 scope2"); + } + + #[test] + fn test_token_endpoint_derivation() { + let config = OAuthConfig { + issuer_url: "https://login.microsoftonline.com/my-tenant/v2.0".to_string(), + client_id: "id".to_string(), + client_secret: None, + scopes: vec!["api://test/.default".to_string()], + flow: OAuthFlow::DeviceCode, + refresh_buffer_secs: None, + }; + let provider = OAuthHeaderProvider::new(config).unwrap(); + assert_eq!( + provider.get_token_endpoint_from_issuer(), + "https://login.microsoftonline.com/my-tenant/oauth2/v2.0/token" + ); + } + + #[test] + fn test_client_credentials_requires_secret() { + let config = OAuthConfig { + issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), + client_id: "app-id".to_string(), + client_secret: None, + scopes: vec!["scope".to_string()], + flow: OAuthFlow::ClientCredentials, + refresh_buffer_secs: None, + }; + assert!(OAuthHeaderProvider::new(config).is_err()); + } + + #[test] + fn test_empty_scopes_rejected() { + let config = OAuthConfig { + issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), + client_id: "app-id".to_string(), + client_secret: None, + scopes: vec![], + flow: OAuthFlow::DeviceCode, + refresh_buffer_secs: None, + }; + assert!(OAuthHeaderProvider::new(config).is_err()); + } +}