From 956a8ee7145f63abc2263ba0e8df62b791feb603 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Sat, 27 Jun 2026 00:01:11 -0700 Subject: [PATCH] feat(python): expose OAuth connection config --- python/python/lancedb/__init__.py | 8 +++ python/python/lancedb/_lancedb.pyi | 1 + python/python/lancedb/remote/__init__.py | 3 + python/python/lancedb/remote/oauth.py | 75 ++++++++++++++++++++++++ python/src/connection.rs | 8 ++- python/src/lib.rs | 1 + python/src/oauth.rs | 72 +++++++++++++++++++++++ python/tests/test_oauth.py | 33 +++++++++++ 8 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 python/python/lancedb/remote/oauth.py create mode 100644 python/src/oauth.rs create mode 100644 python/tests/test_oauth.py diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index e748e1402..5fa700156 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -89,6 +89,8 @@ def connect( If presented, connect to LanceDB cloud. Otherwise, connect to a database on file system or cloud storage. Can be set via environment variable `LANCEDB_API_KEY`. + OAuth configuration is currently supported only by ``connect_async``; + synchronous LanceDB Cloud connections require an API key. region: str, default "us-east-1" The region to use for LanceDB Cloud. host_override: str, optional @@ -340,6 +342,7 @@ async def connect_async( session: Optional[Session] = None, manifest_enabled: bool = False, namespace_client_properties: Optional[Dict[str, str]] = None, + oauth_config=None, ) -> AsyncConnection: """Connect to a LanceDB database. @@ -389,6 +392,10 @@ async def connect_async( namespace_client_properties : dict, optional Additional directory namespace client properties to use with ``manifest_enabled=True``. + oauth_config : OAuthConfig, optional + OAuth configuration for LanceDB Cloud/Enterprise. This is supported by + ``connect_async`` only; synchronous ``connect`` uses API key + authentication for ``db://`` URIs. Examples -------- @@ -435,6 +442,7 @@ async def connect_async( session, manifest_enabled, namespace_client_properties, + oauth_config, ) ) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 8ddb28604..3f3c986dd 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -280,6 +280,7 @@ async def connect( session: Optional[Session], manifest_enabled: bool = False, namespace_client_properties: Optional[Dict[str, str]] = None, + oauth_config: Optional[Any] = None, ) -> Connection: ... class RecordBatchStream: diff --git a/python/python/lancedb/remote/__init__.py b/python/python/lancedb/remote/__init__.py index 289e28942..a6ef55eb5 100644 --- a/python/python/lancedb/remote/__init__.py +++ b/python/python/lancedb/remote/__init__.py @@ -9,6 +9,7 @@ from typing import List, Optional from lancedb import __version__ from .header import HeaderProvider +from .oauth import OAuthConfig, OAuthFlowType __all__ = [ "TimeoutConfig", @@ -16,6 +17,8 @@ __all__ = [ "TlsConfig", "ClientConfig", "HeaderProvider", + "OAuthConfig", + "OAuthFlowType", ] diff --git a/python/python/lancedb/remote/oauth.py b/python/python/lancedb/remote/oauth.py new file mode 100644 index 000000000..9175c3614 --- /dev/null +++ b/python/python/lancedb/remote/oauth.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional + + +class OAuthFlowType(str, Enum): + """OAuth authentication flow types.""" + + CLIENT_CREDENTIALS = "client_credentials" + """Client Credentials grant (service-to-service / M2M).""" + + AZURE_MANAGED_IDENTITY = "azure_managed_identity" + """Azure Managed Identity via IMDS.""" + + +@dataclass +class OAuthConfig: + """OAuth configuration for LanceDB authentication. + + All token acquisition and refresh is handled in the Rust layer. + This config is passed through to Rust via PyO3. + + Parameters + ---------- + issuer_url : str + OIDC issuer URL or OAuth authority URL. + For Azure: ``https://login.microsoftonline.com/{tenant_id}/v2.0`` + client_id : str + Application / Client ID. + scopes : List[str] + OAuth scopes to request. + For Azure managed identity, exactly one scope or resource is required. + For example: ``["api://{app_id}/.default"]`` + flow : OAuthFlowType + Authentication flow to use. Default: CLIENT_CREDENTIALS. + client_secret : Optional[str] + Client secret (required for CLIENT_CREDENTIALS). + managed_identity_client_id : Optional[str] + Client ID for user-assigned managed identity (AZURE_MANAGED_IDENTITY). + refresh_buffer_secs : Optional[int] + Seconds before expiry to trigger proactive refresh (default: 300). + Keep this well below the token TTL; if it is greater than or equal to + the TTL, each request refreshes the token. + + Examples + -------- + Client Credentials (service-to-service): + + >>> config = OAuthConfig( + ... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0", + ... client_id="app-id", + ... client_secret="secret", + ... scopes=["api://lancedb-api/.default"], + ... ) + + Azure Managed Identity: + + >>> config = OAuthConfig( + ... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0", + ... client_id="app-id", + ... scopes=["api://lancedb-api/.default"], + ... flow=OAuthFlowType.AZURE_MANAGED_IDENTITY, + ... ) + """ + + issuer_url: str + client_id: str + scopes: List[str] + flow: OAuthFlowType = OAuthFlowType.CLIENT_CREDENTIALS + client_secret: Optional[str] = field(default=None, repr=False) + managed_identity_client_id: Optional[str] = None + refresh_buffer_secs: Optional[int] = None diff --git a/python/src/connection.rs b/python/src/connection.rs index 007480326..1bba3cefc 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -539,7 +539,7 @@ impl Connection { } #[pyfunction] -#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None))] +#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None, oauth_config=None))] #[allow(clippy::too_many_arguments)] pub fn connect( py: Python<'_>, @@ -553,6 +553,7 @@ pub fn connect( session: Option, manifest_enabled: bool, namespace_client_properties: Option>, + oauth_config: Option, ) -> PyResult> { future_into_py(py, async move { let mut builder = lancedb::connect(&uri); @@ -582,6 +583,11 @@ pub fn connect( if let Some(client_config) = client_config { builder = builder.client_config(client_config.into()); } + if let Some(oauth_config) = oauth_config { + let config: lancedb::remote::oauth::OAuthConfig = + oauth_config.try_into().infer_error()?; + builder = builder.oauth_config(config); + } if let Some(session) = session { builder = builder.session(session.inner.clone()); } diff --git a/python/src/lib.rs b/python/src/lib.rs index fdf8f5cb7..72043c484 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -26,6 +26,7 @@ pub mod expr; pub mod header; pub mod index; pub mod namespace; +pub mod oauth; pub mod permutation; pub mod query; pub mod runtime; diff --git a/python/src/oauth.rs b/python/src/oauth.rs new file mode 100644 index 000000000..11ea011e2 --- /dev/null +++ b/python/src/oauth.rs @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use pyo3::FromPyObject; + +use lancedb::error::Error; +use lancedb::remote::oauth::{OAuthConfig, OAuthFlow}; + +/// Python-side OAuth configuration, extracted via FromPyObject. +/// Maps to `lancedb.remote.oauth.OAuthConfig` Python dataclass. +#[derive(FromPyObject)] +pub struct PyOAuthConfig { + pub issuer_url: String, + pub client_id: String, + pub scopes: Vec, + pub flow: String, + pub client_secret: Option, + pub managed_identity_client_id: Option, + pub refresh_buffer_secs: Option, +} + +impl TryFrom for OAuthConfig { + type Error = Error; + + fn try_from(py: PyOAuthConfig) -> Result { + let flow = match py.flow.as_str() { + "client_credentials" => OAuthFlow::ClientCredentials, + "azure_managed_identity" => OAuthFlow::AzureManagedIdentity { + client_id: py.managed_identity_client_id, + }, + other => { + return Err(Error::InvalidInput { + message: format!("Unknown OAuth flow type: {other}"), + }); + } + }; + + Ok(Self { + issuer_url: py.issuer_url, + client_id: py.client_id, + client_secret: py.client_secret, + scopes: py.scopes, + flow, + refresh_buffer_secs: py.refresh_buffer_secs, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_unknown_oauth_flow_returns_invalid_input() { + let config = PyOAuthConfig { + issuer_url: "https://issuer.example.com".to_string(), + client_id: "client-id".to_string(), + scopes: vec!["scope".to_string()], + flow: "typo".to_string(), + client_secret: None, + managed_identity_client_id: None, + refresh_buffer_secs: None, + }; + + let err = OAuthConfig::try_from(config).unwrap_err(); + assert!(matches!( + err, + Error::InvalidInput { message } + if message == "Unknown OAuth flow type: typo" + )); + } +} diff --git a/python/tests/test_oauth.py b/python/tests/test_oauth.py new file mode 100644 index 000000000..89f5b3f8d --- /dev/null +++ b/python/tests/test_oauth.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +import importlib.util +import sys +from pathlib import Path + + +def _load_oauth_module(): + oauth_path = ( + Path(__file__).parents[1] / "python" / "lancedb" / "remote" / "oauth.py" + ) + spec = importlib.util.spec_from_file_location("lancedb_remote_oauth", oauth_path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def test_oauth_config_repr_redacts_client_secret(): + oauth = _load_oauth_module() + + config = oauth.OAuthConfig( + issuer_url="https://issuer.example.com", + client_id="client-id", + scopes=["scope"], + client_secret="super-secret", + ) + + rendered = repr(config) + assert "super-secret" not in rendered + assert "client_secret" not in rendered