diff --git a/rust/lancedb/src/remote/oauth.rs b/rust/lancedb/src/remote/oauth.rs index b6bc659e6..fd61db919 100644 --- a/rust/lancedb/src/remote/oauth.rs +++ b/rust/lancedb/src/remote/oauth.rs @@ -6,6 +6,7 @@ use std::net::IpAddr; use std::sync::Arc; use std::time::{Duration, Instant}; +use async_trait::async_trait; use log::debug; use reqwest::Client; use serde::Deserialize; @@ -84,7 +85,7 @@ impl std::fmt::Debug for OAuthConfig { // -- OIDC Discovery -- -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] struct OidcDiscovery { token_endpoint: String, } @@ -193,81 +194,62 @@ impl TokenState { } } -/// OAuth header provider that manages the full token lifecycle. -/// -/// Implements [`HeaderProvider`] to inject `Authorization: Bearer ` -/// headers into every LanceDB request, with automatic token refresh. -pub struct OAuthHeaderProvider { - config: OAuthConfig, - http_client: Client, - token_state: Arc>, - /// Cached OIDC discovery document - discovery: Arc>>, - refresh_buffer: Duration, +#[async_trait] +trait TokenSource: Send + Sync + std::fmt::Debug { + async fn fetch_token(&self) -> Result; } -impl std::fmt::Debug for OAuthHeaderProvider { +struct ClientCredentialsSource { + issuer_url: String, + client_id: String, + client_secret: String, + scopes: Vec, + http_client: Client, + discovery: RwLock>, +} + +impl std::fmt::Debug for ClientCredentialsSource { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("OAuthHeaderProvider") - .field("issuer_url", &self.config.issuer_url) - .field("client_id", &self.config.client_id) - .field("flow", &self.config.flow) + f.debug_struct("ClientCredentialsSource") + .field("issuer_url", &self.issuer_url) + .field("client_id", &self.client_id) + .field("client_secret", &"") + .field("scopes", &self.scopes) .finish() } } -impl OAuthHeaderProvider { - /// Create a new OAuth header provider from configuration. - pub fn new(config: OAuthConfig) -> Result { - // Validate config upfront - if matches!(config.flow, OAuthFlow::ClientCredentials) && config.client_secret.is_none() { - return Err(Error::InvalidInput { - message: "client_secret is required for ClientCredentials flow".to_string(), - }); - } - if config.scopes.is_empty() { - return Err(Error::InvalidInput { - message: "At least one OAuth scope is required".to_string(), - }); - } - if matches!(config.flow, OAuthFlow::AzureManagedIdentity { .. }) && config.scopes.len() != 1 - { - return Err(Error::InvalidInput { - message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource" - .to_string(), - }); - } - Self::validate_issuer_transport(&config)?; - - let mut http_client = Client::builder().timeout(Duration::from_secs(30)); - if matches!(config.flow, OAuthFlow::AzureManagedIdentity { .. }) { - http_client = http_client.no_proxy(); - } - let http_client = http_client.build().map_err(|e| Error::Runtime { - message: format!("Failed to create HTTP client for OAuth: {e}"), +impl ClientCredentialsSource { + fn new( + issuer_url: String, + client_id: String, + client_secret: Option, + scopes: Vec, + ) -> Result { + let client_secret = client_secret.ok_or(Error::InvalidInput { + message: "client_secret is required for ClientCredentials flow".to_string(), })?; + Self::validate_issuer_transport(&issuer_url)?; - let refresh_buffer = Duration::from_secs( - config - .refresh_buffer_secs - .unwrap_or(DEFAULT_REFRESH_BUFFER_SECS), - ); + let http_client = Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .map_err(|e| Error::Runtime { + message: format!("Failed to create HTTP client for OAuth: {e}"), + })?; Ok(Self { - config, + issuer_url, + client_id, + client_secret, + scopes, http_client, - token_state: Arc::new(RwLock::new(TokenState::new())), - discovery: Arc::new(RwLock::new(None)), - refresh_buffer, + discovery: RwLock::new(None), }) } - fn validate_issuer_transport(config: &OAuthConfig) -> Result<()> { - if !matches!(config.flow, OAuthFlow::ClientCredentials) { - return Ok(()); - } - - let issuer = url::Url::parse(&config.issuer_url).map_err(|e| Error::InvalidInput { + fn validate_issuer_transport(issuer_url: &str) -> Result<()> { + let issuer = url::Url::parse(issuer_url).map_err(|e| Error::InvalidInput { message: format!("Invalid OAuth issuer_url: {e}"), })?; @@ -294,68 +276,23 @@ impl OAuthHeaderProvider { .unwrap_or(false) } - /// Get a valid access token, refreshing if necessary. - async fn get_valid_token(&self) -> Result { - // Fast path: check if current token is still valid - { - let state = self.token_state.read().await; - if !state.is_expired(self.refresh_buffer) - && let Some(ref token) = state.access_token - { - return Ok(token.clone()); - } - } - - // Slow path: acquire or refresh token - let mut state = self.token_state.write().await; - - // Double-check after acquiring write lock - if !state.is_expired(self.refresh_buffer) - && let Some(ref token) = state.access_token - { - return Ok(token.clone()); - } - - debug!("Acquiring new OAuth token via {:?} flow", self.config.flow); - let resp = self.acquire_token().await?; - - state.update(&resp); - Ok(resp.access_token) - } - - /// Acquire a new token using the configured flow. - async fn acquire_token(&self) -> Result { - match &self.config.flow { - OAuthFlow::ClientCredentials => self.acquire_client_credentials().await, - OAuthFlow::AzureManagedIdentity { client_id } => { - self.acquire_managed_identity(client_id.as_deref()).await - } - } - } - - // -- OIDC Discovery -- - async fn get_discovery(&self) -> Result { { let cached = self.discovery.read().await; if let Some(ref disc) = *cached { - return Ok(OidcDiscovery { - token_endpoint: disc.token_endpoint.clone(), - }); + return Ok(disc.clone()); } } let mut cache = self.discovery.write().await; // Double-check if let Some(ref disc) = *cache { - return Ok(OidcDiscovery { - token_endpoint: disc.token_endpoint.clone(), - }); + return Ok(disc.clone()); } let discovery_url = format!( "{}/.well-known/openid-configuration", - self.config.issuer_url.trim_end_matches('/') + self.issuer_url.trim_end_matches('/') ); debug!("Fetching OIDC discovery from {}", discovery_url); @@ -383,9 +320,7 @@ impl OAuthHeaderProvider { message: format!("Failed to parse OIDC discovery document: {e}"), })?; - let result = OidcDiscovery { - token_endpoint: disc.token_endpoint.clone(), - }; + let result = disc.clone(); *cache = Some(disc); Ok(result) @@ -396,83 +331,9 @@ impl OAuthHeaderProvider { } fn scopes_string(&self) -> String { - self.config.scopes.join(" ") + self.scopes.join(" ") } - fn managed_identity_resource(&self) -> Result { - let [scope] = self.config.scopes.as_slice() else { - return Err(Error::InvalidInput { - message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource" - .to_string(), - }); - }; - - Ok(scope.strip_suffix("/.default").unwrap_or(scope).to_string()) - } - - // -- Client Credentials Flow -- - - async fn acquire_client_credentials(&self) -> Result { - let client_secret = self - .config - .client_secret - .as_ref() - .ok_or(Error::InvalidInput { - message: "client_secret is required for ClientCredentials flow".to_string(), - })?; - - let token_endpoint = self.get_token_endpoint().await?; - - let params = [ - ("grant_type", "client_credentials"), - ("client_id", &self.config.client_id), - ("client_secret", client_secret), - ("scope", &self.scopes_string()), - ]; - - self.post_token_request(&token_endpoint, ¶ms).await - } - - // -- Azure Managed Identity Flow -- - - async fn acquire_managed_identity(&self, mi_client_id: Option<&str>) -> Result { - let resource = self.managed_identity_resource()?; - - let mut url = format!( - "{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}", - urlencoding::encode(&resource), - ); - if let Some(cid) = mi_client_id { - url.push_str(&format!("&client_id={}", urlencoding::encode(cid))); - } - - let resp = self - .http_client - .get(&url) - .header("Metadata", "true") - .send() - .await - .map_err(|e| Error::Runtime { - message: format!("Azure IMDS request failed: {e}"), - })?; - - if !resp.status().is_success() { - return Err(Error::Runtime { - message: format!( - "Azure IMDS returned status {}: {}", - resp.status(), - resp.text().await.unwrap_or_default() - ), - }); - } - - resp.json().await.map_err(|e| Error::Runtime { - message: format!("Failed to parse IMDS token response: {e}"), - }) - } - - // -- Shared Helpers -- - async fn post_token_request( &self, endpoint: &str, @@ -504,7 +365,192 @@ impl OAuthHeaderProvider { } } -#[async_trait::async_trait] +#[async_trait] +impl TokenSource for ClientCredentialsSource { + async fn fetch_token(&self) -> Result { + let token_endpoint = self.get_token_endpoint().await?; + let scope = self.scopes_string(); + let params = [ + ("grant_type", "client_credentials"), + ("client_id", self.client_id.as_str()), + ("client_secret", self.client_secret.as_str()), + ("scope", scope.as_str()), + ]; + + self.post_token_request(&token_endpoint, ¶ms).await + } +} + +struct AzureImdsSource { + client_id: Option, + resource: String, + http_client: Client, +} + +impl std::fmt::Debug for AzureImdsSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AzureImdsSource") + .field("client_id", &self.client_id) + .field("resource", &self.resource) + .finish() + } +} + +impl AzureImdsSource { + fn new(scopes: Vec, client_id: Option) -> Result { + let resource = Self::resource_from_scopes(&scopes)?; + let http_client = Client::builder() + .timeout(Duration::from_secs(30)) + .no_proxy() + .build() + .map_err(|e| Error::Runtime { + message: format!("Failed to create HTTP client for Azure IMDS OAuth: {e}"), + })?; + + Ok(Self { + client_id, + resource, + http_client, + }) + } + + fn resource_from_scopes(scopes: &[String]) -> Result { + let [scope] = scopes else { + return Err(Error::InvalidInput { + message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource" + .to_string(), + }); + }; + + Ok(scope.strip_suffix("/.default").unwrap_or(scope).to_string()) + } +} + +#[async_trait] +impl TokenSource for AzureImdsSource { + async fn fetch_token(&self) -> Result { + let mut url = format!( + "{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}", + urlencoding::encode(&self.resource), + ); + if let Some(cid) = self.client_id.as_deref() { + url.push_str(&format!("&client_id={}", urlencoding::encode(cid))); + } + + let resp = self + .http_client + .get(&url) + .header("Metadata", "true") + .send() + .await + .map_err(|e| Error::Runtime { + message: format!("Azure IMDS request failed: {e}"), + })?; + + if !resp.status().is_success() { + return Err(Error::Runtime { + message: format!( + "Azure IMDS returned status {}: {}", + resp.status(), + resp.text().await.unwrap_or_default() + ), + }); + } + + resp.json().await.map_err(|e| Error::Runtime { + message: format!("Failed to parse IMDS token response: {e}"), + }) + } +} + +/// OAuth header provider that manages the full token lifecycle. +/// +/// Implements [`HeaderProvider`] to inject `Authorization: Bearer ` +/// headers into every LanceDB request, with automatic token refresh. +pub struct OAuthHeaderProvider { + token_source: Box, + token_state: Arc>, + refresh_buffer: Duration, +} + +impl std::fmt::Debug for OAuthHeaderProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OAuthHeaderProvider") + .field("token_source", &self.token_source) + .finish() + } +} + +impl OAuthHeaderProvider { + /// Create a new OAuth header provider from configuration. + pub fn new(config: OAuthConfig) -> Result { + let OAuthConfig { + issuer_url, + client_id, + client_secret, + scopes, + flow, + refresh_buffer_secs, + } = config; + + if scopes.is_empty() { + return Err(Error::InvalidInput { + message: "At least one OAuth scope is required".to_string(), + }); + } + + let refresh_buffer = + Duration::from_secs(refresh_buffer_secs.unwrap_or(DEFAULT_REFRESH_BUFFER_SECS)); + let token_source: Box = match flow { + OAuthFlow::ClientCredentials => Box::new(ClientCredentialsSource::new( + issuer_url, + client_id, + client_secret, + scopes, + )?), + OAuthFlow::AzureManagedIdentity { client_id } => { + Box::new(AzureImdsSource::new(scopes, client_id)?) + } + }; + + Ok(Self { + token_source, + token_state: Arc::new(RwLock::new(TokenState::new())), + refresh_buffer, + }) + } + + /// Get a valid access token, refreshing if necessary. + async fn get_valid_token(&self) -> Result { + // Fast path: check if current token is still valid + { + let state = self.token_state.read().await; + if !state.is_expired(self.refresh_buffer) + && let Some(ref token) = state.access_token + { + return Ok(token.clone()); + } + } + + // Slow path: acquire or refresh token + let mut state = self.token_state.write().await; + + // Double-check after acquiring write lock + if !state.is_expired(self.refresh_buffer) + && let Some(ref token) = state.access_token + { + return Ok(token.clone()); + } + + debug!("Acquiring new OAuth token via {:?}", self.token_source); + let resp = self.token_source.fetch_token().await?; + + state.update(&resp); + Ok(resp.access_token) + } +} + +#[async_trait] impl HeaderProvider for OAuthHeaderProvider { async fn get_headers(&self) -> Result> { let token = self.get_valid_token().await?; @@ -585,16 +631,15 @@ mod tests { #[test] fn test_scopes_string() { - let config = OAuthConfig { - issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), - client_id: "app-id".to_string(), - client_secret: Some("secret".to_string()), - scopes: vec!["scope1".to_string(), "scope2".to_string()], - flow: OAuthFlow::ClientCredentials, - refresh_buffer_secs: None, - }; - let provider = OAuthHeaderProvider::new(config).unwrap(); - assert_eq!(provider.scopes_string(), "scope1 scope2"); + let source = ClientCredentialsSource::new( + "https://login.microsoftonline.com/tenant/v2.0".to_string(), + "app-id".to_string(), + Some("secret".to_string()), + vec!["scope1".to_string(), "scope2".to_string()], + ) + .unwrap(); + + assert_eq!(source.scopes_string(), "scope1 scope2"); } #[test] @@ -614,31 +659,36 @@ mod tests { } #[test] - fn test_managed_identity_resource_from_default_scope() { + fn test_oauth_header_provider_debug_redacts_client_secret() { let config = OAuthConfig { - issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), - client_id: "app-id".to_string(), - client_secret: None, - scopes: vec!["api://test/.default".to_string()], - flow: OAuthFlow::AzureManagedIdentity { client_id: None }, + issuer_url: "https://issuer.example.com".to_string(), + client_id: "client-id".to_string(), + client_secret: Some("super-secret".to_string()), + scopes: vec!["scope".to_string()], + flow: OAuthFlow::ClientCredentials, refresh_buffer_secs: None, }; + let provider = OAuthHeaderProvider::new(config).unwrap(); - assert_eq!(provider.managed_identity_resource().unwrap(), "api://test"); + let debug = format!("{provider:?}"); + assert!(!debug.contains("super-secret")); + assert!(debug.contains("client_secret: \"\"")); + } + + #[test] + fn test_managed_identity_resource_from_default_scope() { + assert_eq!( + AzureImdsSource::resource_from_scopes(&["api://test/.default".to_string()]).unwrap(), + "api://test" + ); } #[test] fn test_managed_identity_resource_without_default_suffix() { - let config = OAuthConfig { - issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(), - client_id: "app-id".to_string(), - client_secret: None, - scopes: vec!["api://test".to_string()], - flow: OAuthFlow::AzureManagedIdentity { client_id: None }, - refresh_buffer_secs: None, - }; - let provider = OAuthHeaderProvider::new(config).unwrap(); - assert_eq!(provider.managed_identity_resource().unwrap(), "api://test"); + assert_eq!( + AzureImdsSource::resource_from_scopes(&["api://test".to_string()]).unwrap(), + "api://test" + ); } #[test] @@ -660,17 +710,15 @@ mod tests { #[tokio::test] async fn test_token_endpoint_requires_discovery_success() { let (issuer_url, server) = spawn_discovery_error_server().await; - let config = OAuthConfig { + let source = ClientCredentialsSource::new( issuer_url, - client_id: "client-id".to_string(), - client_secret: Some("secret".to_string()), - scopes: vec!["scope".to_string()], - flow: OAuthFlow::ClientCredentials, - refresh_buffer_secs: None, - }; - let provider = OAuthHeaderProvider::new(config).unwrap(); + "client-id".to_string(), + Some("secret".to_string()), + vec!["scope".to_string()], + ) + .unwrap(); - let err = provider.get_token_endpoint().await.unwrap_err(); + let err = source.get_token_endpoint().await.unwrap_err(); assert!(matches!( err, Error::Runtime { message }