mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-23 22:20:40 +00:00
Compare commits
2 Commits
v0.31.0-be
...
jack/idp-o
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f08dad782 | ||
|
|
8f55d9b54f |
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -5383,6 +5383,7 @@ dependencies = [
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"url",
|
||||
"urlencoding",
|
||||
"uuid",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
29
docs/src/js/enumerations/OAuthFlowType.md
Normal file
29
docs/src/js/enumerations/OAuthFlowType.md
Normal file
@@ -0,0 +1,29 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / OAuthFlowType
|
||||
|
||||
# Enumeration: OAuthFlowType
|
||||
|
||||
OAuth authentication flow types.
|
||||
|
||||
## Enumeration Members
|
||||
|
||||
### AzureManagedIdentity
|
||||
|
||||
```ts
|
||||
AzureManagedIdentity: "azure_managed_identity";
|
||||
```
|
||||
|
||||
Azure Managed Identity via IMDS.
|
||||
|
||||
***
|
||||
|
||||
### ClientCredentials
|
||||
|
||||
```ts
|
||||
ClientCredentials: "client_credentials";
|
||||
```
|
||||
|
||||
Client Credentials grant (service-to-service / M2M).
|
||||
@@ -12,6 +12,7 @@
|
||||
## Enumerations
|
||||
|
||||
- [FullTextQueryType](enumerations/FullTextQueryType.md)
|
||||
- [OAuthFlowType](enumerations/OAuthFlowType.md)
|
||||
- [Occur](enumerations/Occur.md)
|
||||
- [Operator](enumerations/Operator.md)
|
||||
|
||||
@@ -85,6 +86,8 @@
|
||||
- [ListNamespacesResponse](interfaces/ListNamespacesResponse.md)
|
||||
- [LsmWriteSpec](interfaces/LsmWriteSpec.md)
|
||||
- [MergeResult](interfaces/MergeResult.md)
|
||||
- [NativeOAuthConfig](interfaces/NativeOAuthConfig.md)
|
||||
- [OAuthConfig](interfaces/OAuthConfig.md)
|
||||
- [OpenTableOptions](interfaces/OpenTableOptions.md)
|
||||
- [OptimizeOptions](interfaces/OptimizeOptions.md)
|
||||
- [OptimizeStats](interfaces/OptimizeStats.md)
|
||||
|
||||
@@ -64,6 +64,19 @@ client used by manifest-enabled native connections.
|
||||
|
||||
***
|
||||
|
||||
### oauthConfig?
|
||||
|
||||
```ts
|
||||
optional oauthConfig: NativeOAuthConfig;
|
||||
```
|
||||
|
||||
(For LanceDB cloud only): OAuth configuration for IdP-based
|
||||
authentication (e.g., Azure Entra ID). When set, token acquisition
|
||||
and refresh are handled entirely in Rust. TypeScript users should pass
|
||||
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
|
||||
***
|
||||
|
||||
### readConsistencyInterval?
|
||||
|
||||
```ts
|
||||
|
||||
88
docs/src/js/interfaces/NativeOAuthConfig.md
Normal file
88
docs/src/js/interfaces/NativeOAuthConfig.md
Normal file
@@ -0,0 +1,88 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / NativeOAuthConfig
|
||||
|
||||
# Interface: NativeOAuthConfig
|
||||
|
||||
OAuth configuration for LanceDB authentication.
|
||||
|
||||
This is the generated napi-rs binding shape. TypeScript users should prefer
|
||||
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
|
||||
All token acquisition and refresh is handled in the Rust layer.
|
||||
|
||||
## Properties
|
||||
|
||||
### clientId
|
||||
|
||||
```ts
|
||||
clientId: string;
|
||||
```
|
||||
|
||||
Application / Client ID.
|
||||
|
||||
***
|
||||
|
||||
### clientSecret?
|
||||
|
||||
```ts
|
||||
optional clientSecret: string;
|
||||
```
|
||||
|
||||
Client secret (required for client_credentials).
|
||||
|
||||
***
|
||||
|
||||
### flow?
|
||||
|
||||
```ts
|
||||
optional flow: string;
|
||||
```
|
||||
|
||||
Authentication flow: "client_credentials" or "azure_managed_identity"
|
||||
|
||||
***
|
||||
|
||||
### issuerUrl
|
||||
|
||||
```ts
|
||||
issuerUrl: string;
|
||||
```
|
||||
|
||||
OIDC issuer URL or OAuth authority URL.
|
||||
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
|
||||
***
|
||||
|
||||
### managedIdentityClientId?
|
||||
|
||||
```ts
|
||||
optional managedIdentityClientId: string;
|
||||
```
|
||||
|
||||
Client ID for user-assigned managed identity (azure_managed_identity).
|
||||
|
||||
***
|
||||
|
||||
### refreshBufferSecs?
|
||||
|
||||
```ts
|
||||
optional refreshBufferSecs: number;
|
||||
```
|
||||
|
||||
Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
Keep this well below the token TTL; if it is greater than or equal to
|
||||
the TTL, each request refreshes the token.
|
||||
|
||||
***
|
||||
|
||||
### scopes
|
||||
|
||||
```ts
|
||||
scopes: string[];
|
||||
```
|
||||
|
||||
OAuth scopes to request. For Azure managed identity, exactly one scope
|
||||
or resource is required. For example: `["api://{app_id}/.default"]`
|
||||
111
docs/src/js/interfaces/OAuthConfig.md
Normal file
111
docs/src/js/interfaces/OAuthConfig.md
Normal file
@@ -0,0 +1,111 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / OAuthConfig
|
||||
|
||||
# Interface: OAuthConfig
|
||||
|
||||
OAuth configuration for LanceDB authentication.
|
||||
|
||||
This is the public TypeScript OAuth configuration type. The generated
|
||||
`NativeOAuthConfig` type has the same runtime shape but is an implementation
|
||||
detail of the napi-rs binding.
|
||||
|
||||
All token acquisition and refresh is handled in the Rust layer.
|
||||
This config is passed through to Rust via napi-rs.
|
||||
|
||||
## Examples
|
||||
|
||||
```typescript
|
||||
const config: OAuthConfig = {
|
||||
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
clientId: "app-id",
|
||||
clientSecret: "secret",
|
||||
scopes: ["api://lancedb-api/.default"],
|
||||
};
|
||||
```
|
||||
|
||||
```typescript
|
||||
const config: OAuthConfig = {
|
||||
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
clientId: "app-id",
|
||||
scopes: ["api://lancedb-api/.default"],
|
||||
flow: OAuthFlowType.AzureManagedIdentity,
|
||||
};
|
||||
```
|
||||
|
||||
## Properties
|
||||
|
||||
### clientId
|
||||
|
||||
```ts
|
||||
clientId: string;
|
||||
```
|
||||
|
||||
Application / Client ID.
|
||||
|
||||
***
|
||||
|
||||
### clientSecret?
|
||||
|
||||
```ts
|
||||
optional clientSecret: string;
|
||||
```
|
||||
|
||||
Client secret (required for ClientCredentials).
|
||||
|
||||
***
|
||||
|
||||
### flow?
|
||||
|
||||
```ts
|
||||
optional flow: OAuthFlowType;
|
||||
```
|
||||
|
||||
Authentication flow (default: ClientCredentials).
|
||||
|
||||
***
|
||||
|
||||
### issuerUrl
|
||||
|
||||
```ts
|
||||
issuerUrl: string;
|
||||
```
|
||||
|
||||
OIDC issuer URL or OAuth authority URL.
|
||||
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
|
||||
***
|
||||
|
||||
### managedIdentityClientId?
|
||||
|
||||
```ts
|
||||
optional managedIdentityClientId: string;
|
||||
```
|
||||
|
||||
Client ID for user-assigned managed identity (AzureManagedIdentity).
|
||||
|
||||
***
|
||||
|
||||
### refreshBufferSecs?
|
||||
|
||||
```ts
|
||||
optional refreshBufferSecs: number;
|
||||
```
|
||||
|
||||
Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
Keep this well below the token TTL; if it is greater than or equal to
|
||||
the TTL, each request refreshes the token.
|
||||
|
||||
***
|
||||
|
||||
### scopes
|
||||
|
||||
```ts
|
||||
scopes: string[];
|
||||
```
|
||||
|
||||
OAuth scopes to request.
|
||||
For Azure managed identity, exactly one scope or resource is required.
|
||||
For example: `["api://{app_id}/.default"]`
|
||||
@@ -52,6 +52,7 @@ export {
|
||||
SplitHashOptions,
|
||||
SplitSequentialOptions,
|
||||
ShuffleOptions,
|
||||
OAuthConfig as NativeOAuthConfig,
|
||||
} from "./native.js";
|
||||
|
||||
export {
|
||||
@@ -130,6 +131,8 @@ export {
|
||||
TokenResponse,
|
||||
} from "./header";
|
||||
|
||||
export { OAuthConfig, OAuthFlowType } from "./oauth";
|
||||
|
||||
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
|
||||
|
||||
export * as embedding from "./embedding";
|
||||
|
||||
76
nodejs/lancedb/oauth.ts
Normal file
76
nodejs/lancedb/oauth.ts
Normal file
@@ -0,0 +1,76 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
/**
|
||||
* OAuth authentication flow types.
|
||||
*/
|
||||
export enum OAuthFlowType {
|
||||
/** Client Credentials grant (service-to-service / M2M). */
|
||||
ClientCredentials = "client_credentials",
|
||||
/** Azure Managed Identity via IMDS. */
|
||||
AzureManagedIdentity = "azure_managed_identity",
|
||||
}
|
||||
|
||||
/**
|
||||
* OAuth configuration for LanceDB authentication.
|
||||
*
|
||||
* This is the public TypeScript OAuth configuration type. The generated
|
||||
* `NativeOAuthConfig` type has the same runtime shape but is an implementation
|
||||
* detail of the napi-rs binding.
|
||||
*
|
||||
* All token acquisition and refresh is handled in the Rust layer.
|
||||
* This config is passed through to Rust via napi-rs.
|
||||
*
|
||||
* @example Client Credentials (service-to-service):
|
||||
* ```typescript
|
||||
* const config: OAuthConfig = {
|
||||
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
* clientId: "app-id",
|
||||
* clientSecret: "secret",
|
||||
* scopes: ["api://lancedb-api/.default"],
|
||||
* };
|
||||
* ```
|
||||
*
|
||||
* @example Azure Managed Identity:
|
||||
* ```typescript
|
||||
* const config: OAuthConfig = {
|
||||
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
* clientId: "app-id",
|
||||
* scopes: ["api://lancedb-api/.default"],
|
||||
* flow: OAuthFlowType.AzureManagedIdentity,
|
||||
* };
|
||||
* ```
|
||||
*/
|
||||
export interface OAuthConfig {
|
||||
/**
|
||||
* OIDC issuer URL or OAuth authority URL.
|
||||
* For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
*/
|
||||
issuerUrl: string;
|
||||
|
||||
/** Application / Client ID. */
|
||||
clientId: string;
|
||||
|
||||
/**
|
||||
* OAuth scopes to request.
|
||||
* For Azure managed identity, exactly one scope or resource is required.
|
||||
* For example: `["api://{app_id}/.default"]`
|
||||
*/
|
||||
scopes: string[];
|
||||
|
||||
/** Authentication flow (default: ClientCredentials). */
|
||||
flow?: OAuthFlowType;
|
||||
|
||||
/** Client secret (required for ClientCredentials). */
|
||||
clientSecret?: string;
|
||||
|
||||
/** Client ID for user-assigned managed identity (AzureManagedIdentity). */
|
||||
managedIdentityClientId?: string;
|
||||
|
||||
/**
|
||||
* Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
* Keep this well below the token TTL; if it is greater than or equal to
|
||||
* the TTL, each request refreshes the token.
|
||||
*/
|
||||
refreshBufferSecs?: number;
|
||||
}
|
||||
@@ -112,6 +112,12 @@ impl Connection {
|
||||
|
||||
builder = builder.client_config(rust_config);
|
||||
|
||||
if let Some(oauth_config) = options.oauth_config {
|
||||
let config: lancedb::remote::oauth::OAuthConfig =
|
||||
oauth_config.try_into().default_error()?;
|
||||
builder = builder.oauth_config(config);
|
||||
}
|
||||
|
||||
if let Some(api_key) = options.api_key {
|
||||
builder = builder.api_key(&api_key);
|
||||
}
|
||||
|
||||
@@ -65,6 +65,11 @@ pub struct ConnectionOptions {
|
||||
/// (For LanceDB cloud only): the host to use for LanceDB cloud. Used
|
||||
/// for testing purposes.
|
||||
pub host_override: Option<String>,
|
||||
/// (For LanceDB cloud only): OAuth configuration for IdP-based
|
||||
/// authentication (e.g., Azure Entra ID). When set, token acquisition
|
||||
/// and refresh are handled entirely in Rust. TypeScript users should pass
|
||||
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
pub oauth_config: Option<remote::OAuthConfig>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use lancedb::error::Error;
|
||||
use napi_derive::*;
|
||||
|
||||
/// Timeout configuration for remote HTTP client.
|
||||
@@ -140,6 +141,84 @@ impl From<TlsConfig> for lancedb::remote::TlsConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// OAuth configuration for LanceDB authentication.
|
||||
///
|
||||
/// This is the generated napi-rs binding shape. TypeScript users should prefer
|
||||
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
///
|
||||
/// All token acquisition and refresh is handled in the Rust layer.
|
||||
#[napi(object)]
|
||||
#[derive(Clone)]
|
||||
pub struct OAuthConfig {
|
||||
/// OIDC issuer URL or OAuth authority URL.
|
||||
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
pub issuer_url: String,
|
||||
/// Application / Client ID.
|
||||
pub client_id: String,
|
||||
/// OAuth scopes to request. For Azure managed identity, exactly one scope
|
||||
/// or resource is required. For example: `["api://{app_id}/.default"]`
|
||||
pub scopes: Vec<String>,
|
||||
/// Authentication flow: "client_credentials" or "azure_managed_identity"
|
||||
pub flow: Option<String>,
|
||||
/// Client secret (required for client_credentials).
|
||||
pub client_secret: Option<String>,
|
||||
/// Client ID for user-assigned managed identity (azure_managed_identity).
|
||||
pub managed_identity_client_id: Option<String>,
|
||||
/// Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
/// Keep this well below the token TTL; if it is greater than or equal to
|
||||
/// the TTL, each request refreshes the token.
|
||||
pub refresh_buffer_secs: Option<u32>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OAuthConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OAuthConfig")
|
||||
.field("issuer_url", &self.issuer_url)
|
||||
.field("client_id", &self.client_id)
|
||||
.field("scopes", &self.scopes)
|
||||
.field("flow", &self.flow)
|
||||
.field(
|
||||
"client_secret",
|
||||
&self.client_secret.as_deref().map(|_| "<redacted>"),
|
||||
)
|
||||
.field(
|
||||
"managed_identity_client_id",
|
||||
&self.managed_identity_client_id,
|
||||
)
|
||||
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<OAuthConfig> for lancedb::remote::oauth::OAuthConfig {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(config: OAuthConfig) -> Result<Self, Self::Error> {
|
||||
use lancedb::remote::oauth::OAuthFlow;
|
||||
|
||||
let flow = match config.flow.as_deref().unwrap_or("client_credentials") {
|
||||
"client_credentials" => OAuthFlow::ClientCredentials,
|
||||
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
|
||||
client_id: config.managed_identity_client_id,
|
||||
},
|
||||
other => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!("Unknown OAuth flow type: {other}"),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
issuer_url: config.issuer_url,
|
||||
client_id: config.client_id,
|
||||
client_secret: config.client_secret,
|
||||
scopes: config.scopes,
|
||||
flow,
|
||||
refresh_buffer_secs: config.refresh_buffer_secs.map(|v| v as u64),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
fn from(config: ClientConfig) -> Self {
|
||||
Self {
|
||||
@@ -156,3 +235,45 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_unknown_oauth_flow_returns_invalid_input() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: Some("typo".to_string()),
|
||||
client_secret: None,
|
||||
managed_identity_client_id: None,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let err = lancedb::remote::oauth::OAuthConfig::try_from(config).unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::InvalidInput { message }
|
||||
if message == "Unknown OAuth flow type: typo"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_config_debug_redacts_client_secret() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: Some("client_credentials".to_string()),
|
||||
client_secret: Some("super-secret".to_string()),
|
||||
managed_identity_client_id: None,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let debug = format!("{config:?}");
|
||||
assert!(!debug.contains("super-secret"));
|
||||
assert!(debug.contains("client_secret: Some(\"<redacted>\")"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,6 +89,8 @@ def connect(
|
||||
If presented, connect to LanceDB cloud.
|
||||
Otherwise, connect to a database on file system or cloud storage.
|
||||
Can be set via environment variable `LANCEDB_API_KEY`.
|
||||
OAuth configuration is currently supported only by ``connect_async``;
|
||||
synchronous LanceDB Cloud connections require an API key.
|
||||
region: str, default "us-east-1"
|
||||
The region to use for LanceDB Cloud.
|
||||
host_override: str, optional
|
||||
@@ -340,6 +342,7 @@ async def connect_async(
|
||||
session: Optional[Session] = None,
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
oauth_config=None,
|
||||
) -> AsyncConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
@@ -389,6 +392,10 @@ async def connect_async(
|
||||
namespace_client_properties : dict, optional
|
||||
Additional directory namespace client properties to use with
|
||||
``manifest_enabled=True``.
|
||||
oauth_config : OAuthConfig, optional
|
||||
OAuth configuration for LanceDB Cloud/Enterprise. This is supported by
|
||||
``connect_async`` only; synchronous ``connect`` uses API key
|
||||
authentication for ``db://`` URIs.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -435,6 +442,7 @@ async def connect_async(
|
||||
session,
|
||||
manifest_enabled,
|
||||
namespace_client_properties,
|
||||
oauth_config,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -280,6 +280,7 @@ async def connect(
|
||||
session: Optional[Session],
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
oauth_config: Optional[Any] = None,
|
||||
) -> Connection: ...
|
||||
|
||||
class RecordBatchStream:
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import List, Optional
|
||||
from lancedb import __version__
|
||||
|
||||
from .header import HeaderProvider
|
||||
from .oauth import OAuthConfig, OAuthFlowType
|
||||
|
||||
__all__ = [
|
||||
"TimeoutConfig",
|
||||
@@ -16,6 +17,8 @@ __all__ = [
|
||||
"TlsConfig",
|
||||
"ClientConfig",
|
||||
"HeaderProvider",
|
||||
"OAuthConfig",
|
||||
"OAuthFlowType",
|
||||
]
|
||||
|
||||
|
||||
|
||||
75
python/python/lancedb/remote/oauth.py
Normal file
75
python/python/lancedb/remote/oauth.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class OAuthFlowType(str, Enum):
|
||||
"""OAuth authentication flow types."""
|
||||
|
||||
CLIENT_CREDENTIALS = "client_credentials"
|
||||
"""Client Credentials grant (service-to-service / M2M)."""
|
||||
|
||||
AZURE_MANAGED_IDENTITY = "azure_managed_identity"
|
||||
"""Azure Managed Identity via IMDS."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthConfig:
|
||||
"""OAuth configuration for LanceDB authentication.
|
||||
|
||||
All token acquisition and refresh is handled in the Rust layer.
|
||||
This config is passed through to Rust via PyO3.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
issuer_url : str
|
||||
OIDC issuer URL or OAuth authority URL.
|
||||
For Azure: ``https://login.microsoftonline.com/{tenant_id}/v2.0``
|
||||
client_id : str
|
||||
Application / Client ID.
|
||||
scopes : List[str]
|
||||
OAuth scopes to request.
|
||||
For Azure managed identity, exactly one scope or resource is required.
|
||||
For example: ``["api://{app_id}/.default"]``
|
||||
flow : OAuthFlowType
|
||||
Authentication flow to use. Default: CLIENT_CREDENTIALS.
|
||||
client_secret : Optional[str]
|
||||
Client secret (required for CLIENT_CREDENTIALS).
|
||||
managed_identity_client_id : Optional[str]
|
||||
Client ID for user-assigned managed identity (AZURE_MANAGED_IDENTITY).
|
||||
refresh_buffer_secs : Optional[int]
|
||||
Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
Keep this well below the token TTL; if it is greater than or equal to
|
||||
the TTL, each request refreshes the token.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Client Credentials (service-to-service):
|
||||
|
||||
>>> config = OAuthConfig(
|
||||
... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
... client_id="app-id",
|
||||
... client_secret="secret",
|
||||
... scopes=["api://lancedb-api/.default"],
|
||||
... )
|
||||
|
||||
Azure Managed Identity:
|
||||
|
||||
>>> config = OAuthConfig(
|
||||
... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
... client_id="app-id",
|
||||
... scopes=["api://lancedb-api/.default"],
|
||||
... flow=OAuthFlowType.AZURE_MANAGED_IDENTITY,
|
||||
... )
|
||||
"""
|
||||
|
||||
issuer_url: str
|
||||
client_id: str
|
||||
scopes: List[str]
|
||||
flow: OAuthFlowType = OAuthFlowType.CLIENT_CREDENTIALS
|
||||
client_secret: Optional[str] = field(default=None, repr=False)
|
||||
managed_identity_client_id: Optional[str] = None
|
||||
refresh_buffer_secs: Optional[int] = None
|
||||
@@ -539,7 +539,7 @@ impl Connection {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None))]
|
||||
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None, oauth_config=None))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn connect(
|
||||
py: Python<'_>,
|
||||
@@ -553,6 +553,7 @@ pub fn connect(
|
||||
session: Option<crate::session::Session>,
|
||||
manifest_enabled: bool,
|
||||
namespace_client_properties: Option<HashMap<String, String>>,
|
||||
oauth_config: Option<crate::oauth::PyOAuthConfig>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
future_into_py(py, async move {
|
||||
let mut builder = lancedb::connect(&uri);
|
||||
@@ -582,6 +583,11 @@ pub fn connect(
|
||||
if let Some(client_config) = client_config {
|
||||
builder = builder.client_config(client_config.into());
|
||||
}
|
||||
if let Some(oauth_config) = oauth_config {
|
||||
let config: lancedb::remote::oauth::OAuthConfig =
|
||||
oauth_config.try_into().infer_error()?;
|
||||
builder = builder.oauth_config(config);
|
||||
}
|
||||
if let Some(session) = session {
|
||||
builder = builder.session(session.inner.clone());
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ pub mod expr;
|
||||
pub mod header;
|
||||
pub mod index;
|
||||
pub mod namespace;
|
||||
pub mod oauth;
|
||||
pub mod permutation;
|
||||
pub mod query;
|
||||
pub mod runtime;
|
||||
|
||||
72
python/src/oauth.rs
Normal file
72
python/src/oauth.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use pyo3::FromPyObject;
|
||||
|
||||
use lancedb::error::Error;
|
||||
use lancedb::remote::oauth::{OAuthConfig, OAuthFlow};
|
||||
|
||||
/// Python-side OAuth configuration, extracted via FromPyObject.
|
||||
/// Maps to `lancedb.remote.oauth.OAuthConfig` Python dataclass.
|
||||
#[derive(FromPyObject)]
|
||||
pub struct PyOAuthConfig {
|
||||
pub issuer_url: String,
|
||||
pub client_id: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub flow: String,
|
||||
pub client_secret: Option<String>,
|
||||
pub managed_identity_client_id: Option<String>,
|
||||
pub refresh_buffer_secs: Option<u64>,
|
||||
}
|
||||
|
||||
impl TryFrom<PyOAuthConfig> for OAuthConfig {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(py: PyOAuthConfig) -> Result<Self, Self::Error> {
|
||||
let flow = match py.flow.as_str() {
|
||||
"client_credentials" => OAuthFlow::ClientCredentials,
|
||||
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
|
||||
client_id: py.managed_identity_client_id,
|
||||
},
|
||||
other => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!("Unknown OAuth flow type: {other}"),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
issuer_url: py.issuer_url,
|
||||
client_id: py.client_id,
|
||||
client_secret: py.client_secret,
|
||||
scopes: py.scopes,
|
||||
flow,
|
||||
refresh_buffer_secs: py.refresh_buffer_secs,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_unknown_oauth_flow_returns_invalid_input() {
|
||||
let config = PyOAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: "typo".to_string(),
|
||||
client_secret: None,
|
||||
managed_identity_client_id: None,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let err = OAuthConfig::try_from(config).unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::InvalidInput { message }
|
||||
if message == "Unknown OAuth flow type: typo"
|
||||
));
|
||||
}
|
||||
}
|
||||
33
python/tests/test_oauth.py
Normal file
33
python/tests/test_oauth.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_oauth_module():
|
||||
oauth_path = (
|
||||
Path(__file__).parents[1] / "python" / "lancedb" / "remote" / "oauth.py"
|
||||
)
|
||||
spec = importlib.util.spec_from_file_location("lancedb_remote_oauth", oauth_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def test_oauth_config_repr_redacts_client_secret():
|
||||
oauth = _load_oauth_module()
|
||||
|
||||
config = oauth.OAuthConfig(
|
||||
issuer_url="https://issuer.example.com",
|
||||
client_id="client-id",
|
||||
scopes=["scope"],
|
||||
client_secret="super-secret",
|
||||
)
|
||||
|
||||
rendered = repr(config)
|
||||
assert "super-secret" not in rendered
|
||||
assert "client_secret" not in rendered
|
||||
@@ -75,6 +75,8 @@ reqwest = { version = "0.12.0", default-features = false, features = [
|
||||
"stream",
|
||||
], optional = true }
|
||||
http = { version = "1", optional = true } # Matching what is in reqwest
|
||||
# OAuth dependencies (used by remote feature)
|
||||
urlencoding = { version = "2", optional = true }
|
||||
uuid = { version = "1.7.0", features = ["v4", "v5"] }
|
||||
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
|
||||
polars = { version = ">=0.37,<0.40.0", optional = true }
|
||||
@@ -93,6 +95,7 @@ semver = { workspace = true }
|
||||
anyhow = "1"
|
||||
tempfile = "3.5.0"
|
||||
random_word = { version = "0.4.3", features = ["en"] }
|
||||
tokio = { version = "1.23", features = ["io-util", "macros", "net", "rt-multi-thread"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
walkdir = "2"
|
||||
aws-sdk-dynamodb = { version = "1.55.0" }
|
||||
@@ -129,7 +132,7 @@ huggingface = [
|
||||
"lance-namespace-impls/dir-huggingface",
|
||||
]
|
||||
dynamodb = ["lance/dynamodb", "aws"]
|
||||
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
|
||||
remote = ["dep:reqwest", "dep:http", "dep:urlencoding", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
s3-test = []
|
||||
bedrock = ["dep:aws-sdk-bedrockruntime"]
|
||||
|
||||
@@ -576,6 +576,10 @@ impl Connection {
|
||||
/// For LanceNamespaceDatabase, it is the underlying LanceNamespace.
|
||||
/// For ListingDatabase, it is the equivalent DirectoryNamespace.
|
||||
/// For RemoteDatabase, it is the equivalent RestNamespace.
|
||||
///
|
||||
/// Remote connections using dynamic headers, including OAuth, are not
|
||||
/// currently supported because the namespace client only accepts static
|
||||
/// headers.
|
||||
pub async fn namespace_client(&self) -> Result<Arc<dyn lance_namespace::LanceNamespace>> {
|
||||
self.internal.namespace_client().await
|
||||
}
|
||||
@@ -584,6 +588,10 @@ impl Connection {
|
||||
/// Returns (impl_type, properties) where:
|
||||
/// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace
|
||||
/// - properties: configuration properties for the namespace
|
||||
///
|
||||
/// Remote connections using dynamic headers, including OAuth, are not
|
||||
/// currently supported because the namespace client config only carries
|
||||
/// static headers.
|
||||
pub async fn namespace_client_config(
|
||||
&self,
|
||||
) -> Result<(String, std::collections::HashMap<String, String>)> {
|
||||
@@ -661,6 +669,8 @@ pub struct ConnectRequest {
|
||||
pub struct ConnectBuilder {
|
||||
request: ConnectRequest,
|
||||
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
|
||||
#[cfg(feature = "remote")]
|
||||
oauth_config: Option<crate::remote::oauth::OAuthConfig>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
@@ -682,6 +692,8 @@ impl ConnectBuilder {
|
||||
session: None,
|
||||
},
|
||||
embedding_registry: None,
|
||||
#[cfg(feature = "remote")]
|
||||
oauth_config: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -770,6 +782,19 @@ impl ConnectBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure OAuth authentication for LanceDB Cloud/Enterprise.
|
||||
///
|
||||
/// This creates an [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider)
|
||||
/// from the given config and sets it as the header provider. OAuth cannot
|
||||
/// be combined with an API key.
|
||||
///
|
||||
/// Token acquisition and refresh are handled entirely in Rust.
|
||||
#[cfg(feature = "remote")]
|
||||
pub fn oauth_config(mut self, config: crate::remote::oauth::OAuthConfig) -> Self {
|
||||
self.oauth_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
|
||||
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
|
||||
self.embedding_registry = Some(registry);
|
||||
@@ -915,9 +940,42 @@ impl ConnectBuilder {
|
||||
let region = options.region.ok_or_else(|| Error::InvalidInput {
|
||||
message: "A region is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
let api_key = options.api_key.ok_or_else(|| Error::InvalidInput {
|
||||
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
|
||||
// When OAuth is configured, api_key is not required
|
||||
let api_key = match (&self.oauth_config, &options.api_key) {
|
||||
(Some(_), None) => String::new(),
|
||||
(Some(_), Some(_)) => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "api_key and oauth_config cannot both be set when connecting to LanceDb Cloud"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
(None, Some(key)) => key.clone(),
|
||||
(None, None) => {
|
||||
return Err(Error::InvalidInput {
|
||||
message:
|
||||
"An api_key or oauth_config is required when connecting to LanceDb Cloud"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if self.oauth_config.is_some() && self.request.client_config.header_provider.is_some() {
|
||||
return Err(Error::InvalidInput {
|
||||
message:
|
||||
"oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut client_config = self.request.client_config;
|
||||
|
||||
// Apply OAuth header provider if configured
|
||||
if let Some(oauth_config) = self.oauth_config {
|
||||
let provider = crate::remote::oauth::OAuthHeaderProvider::new(oauth_config)?;
|
||||
client_config.header_provider =
|
||||
Some(Arc::new(provider) as Arc<dyn crate::remote::client::HeaderProvider>);
|
||||
}
|
||||
|
||||
let storage_options = StorageOptions(options.storage_options.clone());
|
||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
|
||||
@@ -925,7 +983,7 @@ impl ConnectBuilder {
|
||||
&api_key,
|
||||
®ion,
|
||||
options.host_override,
|
||||
self.request.client_config,
|
||||
client_config,
|
||||
storage_options.into(),
|
||||
self.request.read_consistency_interval,
|
||||
)?);
|
||||
@@ -1234,6 +1292,83 @@ mod tests {
|
||||
assert_eq!(Some(&"EXPLICIT-VALUE".to_string()), options.get(opts_key));
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
#[tokio::test]
|
||||
async fn test_connect_rejects_api_key_with_oauth_config() {
|
||||
let oauth_config = crate::remote::oauth::OAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: crate::remote::oauth::OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let result = ConnectBuilder::new("db://my-container/my-prefix")
|
||||
.region("us-east-1")
|
||||
.api_key("my-api-key")
|
||||
.oauth_config(oauth_config)
|
||||
.execute()
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(Error::InvalidInput { message })
|
||||
if message
|
||||
== "api_key and oauth_config cannot both be set when connecting to LanceDb Cloud" =>
|
||||
{}
|
||||
Err(err) => panic!("expected InvalidInput, got {err:?}"),
|
||||
Ok(_) => panic!("expected api_key and oauth_config to be rejected"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
#[tokio::test]
|
||||
async fn test_connect_rejects_header_provider_with_oauth_config() {
|
||||
#[derive(Debug)]
|
||||
struct TestHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl crate::remote::HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>> {
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
"Bearer token".to_string(),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
let oauth_config = crate::remote::oauth::OAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: crate::remote::oauth::OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
let client_config = crate::remote::ClientConfig {
|
||||
header_provider: Some(
|
||||
Arc::new(TestHeaderProvider) as Arc<dyn crate::remote::HeaderProvider>
|
||||
),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = ConnectBuilder::new("db://my-container/my-prefix")
|
||||
.region("us-east-1")
|
||||
.client_config(client_config)
|
||||
.oauth_config(oauth_config)
|
||||
.execute()
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(Error::InvalidInput { message })
|
||||
if message
|
||||
== "oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud" =>
|
||||
{}
|
||||
Err(err) => panic!("expected InvalidInput, got {err:?}"),
|
||||
Ok(_) => panic!("expected header_provider and oauth_config to be rejected"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn test_connect_relative() {
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
pub(crate) mod client;
|
||||
pub(crate) mod db;
|
||||
pub mod oauth;
|
||||
mod retry;
|
||||
pub(crate) mod table;
|
||||
pub(crate) mod util;
|
||||
@@ -20,3 +21,4 @@ const JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
|
||||
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
|
||||
pub use oauth::{OAuthConfig, OAuthFlow, OAuthHeaderProvider};
|
||||
|
||||
@@ -459,12 +459,14 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
config: &ClientConfig,
|
||||
) -> Result<HeaderMap> {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-api-key"),
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
||||
message: "non-ascii api key provided".to_string(),
|
||||
})?,
|
||||
);
|
||||
if !api_key.is_empty() {
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-api-key"),
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
||||
message: "non-ascii api key provided".to_string(),
|
||||
})?,
|
||||
);
|
||||
}
|
||||
if region == "local" {
|
||||
let host = format!("{}.local.api.lancedb.com", db_name);
|
||||
headers.insert(
|
||||
@@ -1037,6 +1039,35 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_headers_skip_empty_api_key() {
|
||||
let headers = RestfulLanceDbClient::<Sender>::default_headers(
|
||||
"",
|
||||
"us-east-1",
|
||||
"db",
|
||||
false,
|
||||
&RemoteOptions(HashMap::new()),
|
||||
None,
|
||||
&ClientConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(!headers.contains_key("x-api-key"));
|
||||
|
||||
let headers = RestfulLanceDbClient::<Sender>::default_headers(
|
||||
"api-key",
|
||||
"us-east-1",
|
||||
"db",
|
||||
false,
|
||||
&RemoteOptions(HashMap::new()),
|
||||
None,
|
||||
&ClientConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(headers.get("x-api-key").unwrap(), "api-key");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_config_with_header_provider() {
|
||||
let mut headers = HashMap::new();
|
||||
|
||||
@@ -194,6 +194,7 @@ pub struct RemoteDatabase<S: HttpSend = Sender> {
|
||||
uri: String,
|
||||
/// Headers to pass to the namespace client for authentication
|
||||
namespace_headers: HashMap<String, String>,
|
||||
has_dynamic_header_provider: bool,
|
||||
/// TLS configuration for mTLS support
|
||||
tls_config: Option<super::client::TlsConfig>,
|
||||
}
|
||||
@@ -247,6 +248,7 @@ impl RemoteDatabase {
|
||||
table_cache,
|
||||
uri: uri.to_owned(),
|
||||
namespace_headers,
|
||||
has_dynamic_header_provider: client_config.header_provider.is_some(),
|
||||
tls_config: client_config.tls_config,
|
||||
})
|
||||
}
|
||||
@@ -271,6 +273,7 @@ mod test_utils {
|
||||
table_cache: Cache::new(0),
|
||||
uri: "http://localhost".to_string(),
|
||||
namespace_headers: HashMap::new(),
|
||||
has_dynamic_header_provider: false,
|
||||
tls_config: None,
|
||||
}
|
||||
}
|
||||
@@ -286,6 +289,7 @@ mod test_utils {
|
||||
table_cache: Cache::new(0),
|
||||
uri: "http://localhost".to_string(),
|
||||
namespace_headers: config.extra_headers.clone(),
|
||||
has_dynamic_header_provider: config.header_provider.is_some(),
|
||||
tls_config: config.tls_config.clone(),
|
||||
}
|
||||
}
|
||||
@@ -756,10 +760,17 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
}
|
||||
|
||||
async fn namespace_client(&self) -> Result<Arc<dyn lance_namespace::LanceNamespace>> {
|
||||
if self.has_dynamic_header_provider {
|
||||
return Err(Error::NotSupported {
|
||||
message:
|
||||
"Cannot create a namespace client when dynamic headers are configured; use LanceDB connection namespace methods instead"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Create a RestNamespace pointing to the same remote host with the same authentication headers
|
||||
let mut builder = lance_namespace_impls::RestNamespaceBuilder::new(self.client.host())
|
||||
.delimiter(&self.client.id_delimiter)
|
||||
// TODO: support header provider
|
||||
.headers(self.namespace_headers.clone());
|
||||
|
||||
// Apply mTLS configuration if present
|
||||
@@ -781,6 +792,14 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
}
|
||||
|
||||
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)> {
|
||||
if self.has_dynamic_header_provider {
|
||||
return Err(Error::NotSupported {
|
||||
message:
|
||||
"Cannot export a namespace client config when dynamic headers are configured; use LanceDB connection namespace methods instead"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut properties = HashMap::new();
|
||||
properties.insert("uri".to_string(), self.client.host().to_string());
|
||||
properties.insert("delimiter".to_string(), self.client.id_delimiter.clone());
|
||||
@@ -1702,6 +1721,51 @@ mod tests {
|
||||
assert!(namespace_client.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_namespace_client_rejects_dynamic_headers() {
|
||||
#[derive(Debug)]
|
||||
struct TestHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
"Bearer token".to_string(),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
let client_config = ClientConfig {
|
||||
header_provider: Some(Arc::new(TestHeaderProvider) as Arc<dyn HeaderProvider>),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let conn = Connection::new_with_handler_and_config(
|
||||
|_| {
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"tables": []}"#)
|
||||
.unwrap()
|
||||
},
|
||||
client_config,
|
||||
);
|
||||
|
||||
match conn.namespace_client().await {
|
||||
Err(Error::NotSupported { message })
|
||||
if message.contains("dynamic headers are configured") => {}
|
||||
Err(err) => panic!("expected NotSupported, got {err:?}"),
|
||||
Ok(_) => panic!("expected namespace_client to reject dynamic headers"),
|
||||
}
|
||||
|
||||
match conn.namespace_client_config().await {
|
||||
Err(Error::NotSupported { message })
|
||||
if message.contains("dynamic headers are configured") => {}
|
||||
Err(err) => panic!("expected NotSupported, got {err:?}"),
|
||||
Ok(_) => panic!("expected namespace_client_config to reject dynamic headers"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Integration tests using RestAdapter to run RemoteDatabase against a real namespace server
|
||||
mod rest_adapter_integration {
|
||||
use super::*;
|
||||
|
||||
857
rust/lancedb/src/remote/oauth.rs
Normal file
857
rust/lancedb/src/remote/oauth.rs
Normal file
@@ -0,0 +1,857 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use log::debug;
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::remote::client::HeaderProvider;
|
||||
|
||||
const DEFAULT_REFRESH_BUFFER_SECS: u64 = 300;
|
||||
const DEFAULT_TOKEN_TTL_SECS: u64 = 3600;
|
||||
const AZURE_IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token";
|
||||
const AZURE_IMDS_API_VERSION: &str = "2018-02-01";
|
||||
|
||||
/// OAuth authentication flow configuration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum OAuthFlow {
|
||||
/// Client Credentials grant (service-to-service / M2M).
|
||||
/// Requires `client_secret` in [`OAuthConfig`].
|
||||
ClientCredentials,
|
||||
|
||||
/// Azure Managed Identity via IMDS.
|
||||
/// Works on Azure VMs, AKS, App Service, and Azure Functions.
|
||||
AzureManagedIdentity {
|
||||
/// Client ID for user-assigned managed identity.
|
||||
/// Omit for system-assigned managed identity.
|
||||
client_id: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// OAuth configuration for LanceDB authentication.
|
||||
///
|
||||
/// All token acquisition and refresh is handled in the Rust layer.
|
||||
/// Python and TypeScript bindings expose this as a plain config object.
|
||||
#[derive(Clone)]
|
||||
pub struct OAuthConfig {
|
||||
/// OIDC issuer URL or OAuth authority URL.
|
||||
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
pub issuer_url: String,
|
||||
|
||||
/// Application / Client ID.
|
||||
pub client_id: String,
|
||||
|
||||
/// Client secret (required for `ClientCredentials`, optional for others).
|
||||
pub client_secret: Option<String>,
|
||||
|
||||
/// OAuth scopes to request.
|
||||
/// For Azure managed identity, exactly one scope or resource is required.
|
||||
/// For example: `["api://{app_id}/.default"]`
|
||||
pub scopes: Vec<String>,
|
||||
|
||||
/// Authentication flow to use.
|
||||
pub flow: OAuthFlow,
|
||||
|
||||
/// Seconds before token expiry to trigger proactive refresh (default: 300).
|
||||
/// Keep this well below the token TTL; if it is greater than or equal to
|
||||
/// the TTL, each request refreshes the token.
|
||||
pub refresh_buffer_secs: Option<u64>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OAuthConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OAuthConfig")
|
||||
.field("issuer_url", &self.issuer_url)
|
||||
.field("client_id", &self.client_id)
|
||||
.field(
|
||||
"client_secret",
|
||||
&self.client_secret.as_deref().map(|_| "<redacted>"),
|
||||
)
|
||||
.field("scopes", &self.scopes)
|
||||
.field("flow", &self.flow)
|
||||
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// -- OIDC Discovery --
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OidcDiscovery {
|
||||
token_endpoint: String,
|
||||
}
|
||||
|
||||
// -- Token Response --
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
/// Token lifetime in seconds.
|
||||
/// Some providers (Azure IMDS) return this as a string, so we accept both.
|
||||
#[serde(default, deserialize_with = "deserialize_optional_u64_or_string")]
|
||||
expires_in: Option<u64>,
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
token_type: Option<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for TokenResponse {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TokenResponse")
|
||||
.field("access_token", &"<redacted>")
|
||||
.field("expires_in", &self.expires_in)
|
||||
.field("token_type", &self.token_type)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_optional_u64_or_string<'de, D>(
|
||||
deserializer: D,
|
||||
) -> std::result::Result<Option<u64>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de;
|
||||
|
||||
struct U64OrString;
|
||||
impl<'de> de::Visitor<'de> for U64OrString {
|
||||
type Value = Option<u64>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("an integer, an integer-valued float, a numeric string, or null")
|
||||
}
|
||||
|
||||
fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
|
||||
Ok(Some(v))
|
||||
}
|
||||
|
||||
fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
|
||||
if v < 0 {
|
||||
return Err(E::custom(format!("invalid expires_in value: {v}")));
|
||||
}
|
||||
Ok(Some(v as u64))
|
||||
}
|
||||
|
||||
fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
|
||||
if !v.is_finite() || v < 0.0 || v.fract() != 0.0 || v > u64::MAX as f64 {
|
||||
return Err(E::custom(format!("invalid expires_in value: {v}")));
|
||||
}
|
||||
Ok(Some(v as u64))
|
||||
}
|
||||
|
||||
fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Self::Value, E> {
|
||||
v.parse::<u64>().map(Some).map_err(de::Error::custom)
|
||||
}
|
||||
|
||||
fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(U64OrString)
|
||||
}
|
||||
|
||||
// -- Internal Token State --
|
||||
|
||||
struct TokenState {
|
||||
access_token: Option<String>,
|
||||
expires_at: Option<Instant>,
|
||||
}
|
||||
|
||||
impl TokenState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
access_token: None,
|
||||
expires_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_expired(&self, buffer: Duration) -> bool {
|
||||
match (self.access_token.as_ref(), self.expires_at) {
|
||||
(Some(_), Some(expires_at)) => Instant::now() + buffer >= expires_at,
|
||||
(None, _) => true,
|
||||
(Some(_), None) => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, resp: &TokenResponse) {
|
||||
self.access_token = Some(resp.access_token.clone());
|
||||
let expires_in = resp.expires_in.unwrap_or(DEFAULT_TOKEN_TTL_SECS);
|
||||
self.expires_at = Some(Instant::now() + Duration::from_secs(expires_in));
|
||||
}
|
||||
}
|
||||
|
||||
/// OAuth header provider that manages the full token lifecycle.
|
||||
///
|
||||
/// Implements [`HeaderProvider`] to inject `Authorization: Bearer <token>`
|
||||
/// headers into every LanceDB request, with automatic token refresh.
|
||||
pub struct OAuthHeaderProvider {
|
||||
config: OAuthConfig,
|
||||
http_client: Client,
|
||||
token_state: Arc<RwLock<TokenState>>,
|
||||
/// Cached OIDC discovery document
|
||||
discovery: Arc<RwLock<Option<OidcDiscovery>>>,
|
||||
refresh_buffer: Duration,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OAuthHeaderProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OAuthHeaderProvider")
|
||||
.field("issuer_url", &self.config.issuer_url)
|
||||
.field("client_id", &self.config.client_id)
|
||||
.field("flow", &self.config.flow)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl OAuthHeaderProvider {
|
||||
/// Create a new OAuth header provider from configuration.
|
||||
pub fn new(config: OAuthConfig) -> Result<Self> {
|
||||
// Validate config upfront
|
||||
if matches!(config.flow, OAuthFlow::ClientCredentials) && config.client_secret.is_none() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "client_secret is required for ClientCredentials flow".to_string(),
|
||||
});
|
||||
}
|
||||
if config.scopes.is_empty() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "At least one OAuth scope is required".to_string(),
|
||||
});
|
||||
}
|
||||
if matches!(config.flow, OAuthFlow::AzureManagedIdentity { .. }) && config.scopes.len() != 1
|
||||
{
|
||||
return Err(Error::InvalidInput {
|
||||
message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
Self::validate_issuer_transport(&config)?;
|
||||
|
||||
let http_client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to create HTTP client for OAuth: {e}"),
|
||||
})?;
|
||||
|
||||
let refresh_buffer = Duration::from_secs(
|
||||
config
|
||||
.refresh_buffer_secs
|
||||
.unwrap_or(DEFAULT_REFRESH_BUFFER_SECS),
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
http_client,
|
||||
token_state: Arc::new(RwLock::new(TokenState::new())),
|
||||
discovery: Arc::new(RwLock::new(None)),
|
||||
refresh_buffer,
|
||||
})
|
||||
}
|
||||
|
||||
fn validate_issuer_transport(config: &OAuthConfig) -> Result<()> {
|
||||
if !matches!(config.flow, OAuthFlow::ClientCredentials) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let issuer = url::Url::parse(&config.issuer_url).map_err(|e| Error::InvalidInput {
|
||||
message: format!("Invalid OAuth issuer_url: {e}"),
|
||||
})?;
|
||||
|
||||
match issuer.scheme() {
|
||||
"https" => Ok(()),
|
||||
"http" if Self::is_loopback_issuer(&issuer) => Ok(()),
|
||||
_ => Err(Error::InvalidInput {
|
||||
message:
|
||||
"ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
|
||||
.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_loopback_issuer(issuer: &url::Url) -> bool {
|
||||
let Some(host) = issuer.host_str() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
host.eq_ignore_ascii_case("localhost")
|
||||
|| host
|
||||
.parse::<IpAddr>()
|
||||
.map(|addr| addr.is_loopback())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Get a valid access token, refreshing if necessary.
|
||||
async fn get_valid_token(&self) -> Result<String> {
|
||||
// Fast path: check if current token is still valid
|
||||
{
|
||||
let state = self.token_state.read().await;
|
||||
if !state.is_expired(self.refresh_buffer)
|
||||
&& let Some(ref token) = state.access_token
|
||||
{
|
||||
return Ok(token.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: acquire or refresh token
|
||||
let mut state = self.token_state.write().await;
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if !state.is_expired(self.refresh_buffer)
|
||||
&& let Some(ref token) = state.access_token
|
||||
{
|
||||
return Ok(token.clone());
|
||||
}
|
||||
|
||||
debug!("Acquiring new OAuth token via {:?} flow", self.config.flow);
|
||||
let resp = self.acquire_token().await?;
|
||||
|
||||
state.update(&resp);
|
||||
Ok(resp.access_token)
|
||||
}
|
||||
|
||||
/// Acquire a new token using the configured flow.
|
||||
async fn acquire_token(&self) -> Result<TokenResponse> {
|
||||
match &self.config.flow {
|
||||
OAuthFlow::ClientCredentials => self.acquire_client_credentials().await,
|
||||
OAuthFlow::AzureManagedIdentity { client_id } => {
|
||||
self.acquire_managed_identity(client_id.as_deref()).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -- OIDC Discovery --
|
||||
|
||||
async fn get_discovery(&self) -> Result<OidcDiscovery> {
|
||||
{
|
||||
let cached = self.discovery.read().await;
|
||||
if let Some(ref disc) = *cached {
|
||||
return Ok(OidcDiscovery {
|
||||
token_endpoint: disc.token_endpoint.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let mut cache = self.discovery.write().await;
|
||||
// Double-check
|
||||
if let Some(ref disc) = *cache {
|
||||
return Ok(OidcDiscovery {
|
||||
token_endpoint: disc.token_endpoint.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let discovery_url = format!(
|
||||
"{}/.well-known/openid-configuration",
|
||||
self.config.issuer_url.trim_end_matches('/')
|
||||
);
|
||||
|
||||
debug!("Fetching OIDC discovery from {}", discovery_url);
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.get(&discovery_url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to fetch OIDC discovery document: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"OIDC discovery failed with status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let disc: OidcDiscovery = resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse OIDC discovery document: {e}"),
|
||||
})?;
|
||||
|
||||
let result = OidcDiscovery {
|
||||
token_endpoint: disc.token_endpoint.clone(),
|
||||
};
|
||||
|
||||
*cache = Some(disc);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn get_token_endpoint(&self) -> Result<String> {
|
||||
self.get_discovery().await.map(|disc| disc.token_endpoint)
|
||||
}
|
||||
|
||||
fn scopes_string(&self) -> String {
|
||||
self.config.scopes.join(" ")
|
||||
}
|
||||
|
||||
fn managed_identity_resource(&self) -> Result<String> {
|
||||
let [scope] = self.config.scopes.as_slice() else {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource"
|
||||
.to_string(),
|
||||
});
|
||||
};
|
||||
|
||||
Ok(scope.strip_suffix("/.default").unwrap_or(scope).to_string())
|
||||
}
|
||||
|
||||
// -- Client Credentials Flow --
|
||||
|
||||
async fn acquire_client_credentials(&self) -> Result<TokenResponse> {
|
||||
let client_secret = self
|
||||
.config
|
||||
.client_secret
|
||||
.as_ref()
|
||||
.ok_or(Error::InvalidInput {
|
||||
message: "client_secret is required for ClientCredentials flow".to_string(),
|
||||
})?;
|
||||
|
||||
let token_endpoint = self.get_token_endpoint().await?;
|
||||
|
||||
let params = [
|
||||
("grant_type", "client_credentials"),
|
||||
("client_id", &self.config.client_id),
|
||||
("client_secret", client_secret),
|
||||
("scope", &self.scopes_string()),
|
||||
];
|
||||
|
||||
self.post_token_request(&token_endpoint, ¶ms).await
|
||||
}
|
||||
|
||||
// -- Azure Managed Identity Flow --
|
||||
|
||||
async fn acquire_managed_identity(&self, mi_client_id: Option<&str>) -> Result<TokenResponse> {
|
||||
let resource = self.managed_identity_resource()?;
|
||||
|
||||
let mut url = format!(
|
||||
"{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}",
|
||||
urlencoding::encode(&resource),
|
||||
);
|
||||
if let Some(cid) = mi_client_id {
|
||||
url.push_str(&format!("&client_id={}", urlencoding::encode(cid)));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Metadata", "true")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Azure IMDS request failed: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Azure IMDS returned status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse IMDS token response: {e}"),
|
||||
})
|
||||
}
|
||||
|
||||
// -- Shared Helpers --
|
||||
|
||||
async fn post_token_request(
|
||||
&self,
|
||||
endpoint: &str,
|
||||
params: &[(&str, &str)],
|
||||
) -> Result<TokenResponse> {
|
||||
let resp = self
|
||||
.http_client
|
||||
.post(endpoint)
|
||||
.form(params)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Token request to {endpoint} failed: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Token request failed with status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse token response: {e}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for OAuthHeaderProvider {
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>> {
|
||||
let token = self.get_valid_token().await?;
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
format!("Bearer {token}"),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[test]
|
||||
fn test_token_state_expiry() {
|
||||
let mut state = TokenState::new();
|
||||
assert!(state.is_expired(Duration::from_secs(0)));
|
||||
|
||||
state.access_token = Some("tok".to_string());
|
||||
state.expires_at = Some(Instant::now() + Duration::from_secs(600));
|
||||
assert!(!state.is_expired(Duration::from_secs(300)));
|
||||
assert!(state.is_expired(Duration::from_secs(601)));
|
||||
|
||||
state.expires_at = None;
|
||||
assert!(state.is_expired(Duration::from_secs(0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_state_uses_default_expiry() {
|
||||
let mut state = TokenState::new();
|
||||
let response = TokenResponse {
|
||||
access_token: "tok".to_string(),
|
||||
expires_in: None,
|
||||
token_type: None,
|
||||
};
|
||||
|
||||
state.update(&response);
|
||||
|
||||
assert!(!state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS - 1)));
|
||||
assert!(state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS + 1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_response_accepts_float_expires_in() {
|
||||
let response: TokenResponse =
|
||||
serde_json::from_str(r#"{"access_token":"tok","expires_in":3600.0}"#).unwrap();
|
||||
|
||||
assert_eq!(response.expires_in, Some(3600));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_response_rejects_negative_expires_in() {
|
||||
let err =
|
||||
serde_json::from_str::<TokenResponse>(r#"{"access_token":"tok","expires_in":-1}"#)
|
||||
.unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("invalid expires_in value: -1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_response_debug_redacts_access_token() {
|
||||
let response = TokenResponse {
|
||||
access_token: "secret-token".to_string(),
|
||||
expires_in: Some(3600),
|
||||
token_type: Some("Bearer".to_string()),
|
||||
};
|
||||
|
||||
let debug = format!("{response:?}");
|
||||
assert!(!debug.contains("secret-token"));
|
||||
assert!(debug.contains("access_token: \"<redacted>\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scopes_string() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope1".to_string(), "scope2".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||
assert_eq!(provider.scopes_string(), "scope1 scope2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_config_debug_redacts_client_secret() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: Some("super-secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let debug = format!("{config:?}");
|
||||
assert!(!debug.contains("super-secret"));
|
||||
assert!(debug.contains("client_secret: Some(\"<redacted>\")"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_managed_identity_resource_from_default_scope() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: None,
|
||||
scopes: vec!["api://test/.default".to_string()],
|
||||
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||
assert_eq!(provider.managed_identity_resource().unwrap(), "api://test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_managed_identity_resource_without_default_suffix() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: None,
|
||||
scopes: vec!["api://test".to_string()],
|
||||
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||
assert_eq!(provider.managed_identity_resource().unwrap(), "api://test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_managed_identity_rejects_multiple_scopes() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: None,
|
||||
scopes: vec![
|
||||
"api://test-a/.default".to_string(),
|
||||
"api://test-b/.default".to_string(),
|
||||
],
|
||||
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
assert!(OAuthHeaderProvider::new(config).is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_token_endpoint_requires_discovery_success() {
|
||||
let (issuer_url, server) = spawn_discovery_error_server().await;
|
||||
let config = OAuthConfig {
|
||||
issuer_url,
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||
|
||||
let err = provider.get_token_endpoint().await.unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::Runtime { message }
|
||||
if message.contains("OIDC discovery failed with status 503")
|
||||
));
|
||||
server.await.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_credentials_requires_secret() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: None,
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
assert!(OAuthHeaderProvider::new(config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_credentials_rejects_insecure_non_loopback_issuer() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "http://issuer.example.com".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let err = OAuthHeaderProvider::new(config).unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::InvalidInput { message }
|
||||
if message == "ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_scopes_rejected() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: None,
|
||||
scopes: vec![],
|
||||
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
assert!(OAuthHeaderProvider::new(config).is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_credentials_token_lifecycle() {
|
||||
let (issuer_url, token_requests, server) = spawn_oauth_server().await;
|
||||
let config = OAuthConfig {
|
||||
issuer_url,
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: Some(0),
|
||||
};
|
||||
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||
|
||||
let headers = provider.get_headers().await.unwrap();
|
||||
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
|
||||
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
|
||||
|
||||
let headers = provider.get_headers().await.unwrap();
|
||||
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
|
||||
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
|
||||
|
||||
provider.token_state.write().await.expires_at =
|
||||
Some(Instant::now() - Duration::from_secs(1));
|
||||
|
||||
let headers = provider.get_headers().await.unwrap();
|
||||
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-2");
|
||||
assert_eq!(token_requests.load(Ordering::SeqCst), 2);
|
||||
|
||||
server.await.unwrap();
|
||||
}
|
||||
|
||||
async fn spawn_oauth_server() -> (String, Arc<AtomicUsize>, JoinHandle<()>) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let issuer_url = format!("http://{addr}");
|
||||
let token_requests = Arc::new(AtomicUsize::new(0));
|
||||
let server_token_requests = Arc::clone(&token_requests);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
for _ in 0..3 {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let (request_line, body) = read_http_request(&mut stream).await;
|
||||
|
||||
if request_line.starts_with("GET /.well-known/openid-configuration ") {
|
||||
let discovery = format!(r#"{{"token_endpoint":"http://{addr}/token"}}"#);
|
||||
write_json_response(&mut stream, "200 OK", &discovery).await;
|
||||
} else if request_line.starts_with("POST /token ") {
|
||||
assert!(body.contains("grant_type=client_credentials"));
|
||||
assert!(body.contains("client_id=client-id"));
|
||||
assert!(body.contains("client_secret=secret"));
|
||||
assert!(body.contains("scope=scope"));
|
||||
|
||||
let token_num = server_token_requests.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
let token = format!(
|
||||
r#"{{"access_token":"token-{token_num}","expires_in":3600,"token_type":"Bearer"}}"#
|
||||
);
|
||||
write_json_response(&mut stream, "200 OK", &token).await;
|
||||
} else {
|
||||
write_json_response(&mut stream, "404 Not Found", "{}").await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(issuer_url, token_requests, server)
|
||||
}
|
||||
|
||||
async fn spawn_discovery_error_server() -> (String, JoinHandle<()>) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let issuer_url = format!("http://{addr}");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let (request_line, _) = read_http_request(&mut stream).await;
|
||||
assert!(request_line.starts_with("GET /.well-known/openid-configuration "));
|
||||
write_json_response(&mut stream, "503 Service Unavailable", "{}").await;
|
||||
});
|
||||
|
||||
(issuer_url, server)
|
||||
}
|
||||
|
||||
async fn read_http_request(stream: &mut TcpStream) -> (String, String) {
|
||||
let mut buffer = Vec::new();
|
||||
let mut header_end = None;
|
||||
|
||||
while header_end.is_none() {
|
||||
let mut chunk = [0; 1024];
|
||||
let read = stream.read(&mut chunk).await.unwrap();
|
||||
assert_ne!(read, 0, "connection closed before request headers");
|
||||
buffer.extend_from_slice(&chunk[..read]);
|
||||
header_end = find_subsequence(&buffer, b"\r\n\r\n").map(|pos| pos + 4);
|
||||
}
|
||||
|
||||
let header_end = header_end.unwrap();
|
||||
let headers = String::from_utf8_lossy(&buffer[..header_end]).to_string();
|
||||
let request_line = headers.lines().next().unwrap_or_default().to_string();
|
||||
let content_length = headers
|
||||
.lines()
|
||||
.find_map(|line| {
|
||||
let (name, value) = line.split_once(':')?;
|
||||
name.eq_ignore_ascii_case("content-length")
|
||||
.then(|| value.trim().parse::<usize>().ok())
|
||||
.flatten()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
|
||||
while buffer.len() < header_end + content_length {
|
||||
let mut chunk = [0; 1024];
|
||||
let read = stream.read(&mut chunk).await.unwrap();
|
||||
assert_ne!(read, 0, "connection closed before request body");
|
||||
buffer.extend_from_slice(&chunk[..read]);
|
||||
}
|
||||
|
||||
let body =
|
||||
String::from_utf8_lossy(&buffer[header_end..header_end + content_length]).to_string();
|
||||
|
||||
(request_line, body)
|
||||
}
|
||||
|
||||
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
|
||||
haystack
|
||||
.windows(needle.len())
|
||||
.position(|window| window == needle)
|
||||
}
|
||||
|
||||
async fn write_json_response(stream: &mut TcpStream, status: &str, body: &str) {
|
||||
let response = format!(
|
||||
"HTTP/1.1 {status}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||
body.len()
|
||||
);
|
||||
stream.write_all(response.as_bytes()).await.unwrap();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user