Compare commits

..

1 Commits

Author SHA1 Message Date
Lance Release
a038214a93 Bump version: 0.31.0-beta.2 → 0.31.0-beta.3 2026-06-25 01:54:32 +00:00
22 changed files with 43 additions and 1785 deletions

7
Cargo.lock generated
View File

@@ -5307,7 +5307,7 @@ dependencies = [
[[package]]
name = "lancedb"
version = "0.31.0-beta.3"
version = "0.31.0-beta.2"
dependencies = [
"ahash",
"anyhow",
@@ -5384,14 +5384,13 @@ dependencies = [
"tokenizers",
"tokio",
"url",
"urlencoding",
"uuid",
"walkdir",
]
[[package]]
name = "lancedb-nodejs"
version = "0.31.0-beta.3"
version = "0.31.0-beta.2"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -5416,7 +5415,7 @@ dependencies = [
[[package]]
name = "lancedb-python"
version = "0.34.0-beta.3"
version = "0.34.0-beta.2"
dependencies = [
"arrow",
"async-trait",

View File

@@ -1,29 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / OAuthFlowType
# Enumeration: OAuthFlowType
OAuth authentication flow types.
## Enumeration Members
### AzureManagedIdentity
```ts
AzureManagedIdentity: "azure_managed_identity";
```
Azure Managed Identity via IMDS.
***
### ClientCredentials
```ts
ClientCredentials: "client_credentials";
```
Client Credentials grant (service-to-service / M2M).

View File

@@ -12,7 +12,6 @@
## Enumerations
- [FullTextQueryType](enumerations/FullTextQueryType.md)
- [OAuthFlowType](enumerations/OAuthFlowType.md)
- [Occur](enumerations/Occur.md)
- [Operator](enumerations/Operator.md)
@@ -86,8 +85,6 @@
- [ListNamespacesResponse](interfaces/ListNamespacesResponse.md)
- [LsmWriteSpec](interfaces/LsmWriteSpec.md)
- [MergeResult](interfaces/MergeResult.md)
- [NativeOAuthConfig](interfaces/NativeOAuthConfig.md)
- [OAuthConfig](interfaces/OAuthConfig.md)
- [OpenTableOptions](interfaces/OpenTableOptions.md)
- [OptimizeOptions](interfaces/OptimizeOptions.md)
- [OptimizeStats](interfaces/OptimizeStats.md)

View File

@@ -64,19 +64,6 @@ client used by manifest-enabled native connections.
***
### oauthConfig?
```ts
optional oauthConfig: NativeOAuthConfig;
```
(For LanceDB cloud only): OAuth configuration for IdP-based
authentication (e.g., Azure Entra ID). When set, token acquisition
and refresh are handled entirely in Rust. TypeScript users should pass
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
***
### readConsistencyInterval?
```ts

View File

@@ -1,88 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / NativeOAuthConfig
# Interface: NativeOAuthConfig
OAuth configuration for LanceDB authentication.
This is the generated napi-rs binding shape. TypeScript users should prefer
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
All token acquisition and refresh is handled in the Rust layer.
## Properties
### clientId
```ts
clientId: string;
```
Application / Client ID.
***
### clientSecret?
```ts
optional clientSecret: string;
```
Client secret (required for client_credentials).
***
### flow?
```ts
optional flow: string;
```
Authentication flow: "client_credentials" or "azure_managed_identity"
***
### issuerUrl
```ts
issuerUrl: string;
```
OIDC issuer URL or OAuth authority URL.
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
***
### managedIdentityClientId?
```ts
optional managedIdentityClientId: string;
```
Client ID for user-assigned managed identity (azure_managed_identity).
***
### refreshBufferSecs?
```ts
optional refreshBufferSecs: number;
```
Seconds before expiry to trigger proactive refresh (default: 300).
Keep this well below the token TTL; if it is greater than or equal to
the TTL, each request refreshes the token.
***
### scopes
```ts
scopes: string[];
```
OAuth scopes to request. For Azure managed identity, exactly one scope
or resource is required. For example: `["api://{app_id}/.default"]`

View File

@@ -1,111 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / OAuthConfig
# Interface: OAuthConfig
OAuth configuration for LanceDB authentication.
This is the public TypeScript OAuth configuration type. The generated
`NativeOAuthConfig` type has the same runtime shape but is an implementation
detail of the napi-rs binding.
All token acquisition and refresh is handled in the Rust layer.
This config is passed through to Rust via napi-rs.
## Examples
```typescript
const config: OAuthConfig = {
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
clientId: "app-id",
clientSecret: "secret",
scopes: ["api://lancedb-api/.default"],
};
```
```typescript
const config: OAuthConfig = {
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
clientId: "app-id",
scopes: ["api://lancedb-api/.default"],
flow: OAuthFlowType.AzureManagedIdentity,
};
```
## Properties
### clientId
```ts
clientId: string;
```
Application / Client ID.
***
### clientSecret?
```ts
optional clientSecret: string;
```
Client secret (required for ClientCredentials).
***
### flow?
```ts
optional flow: OAuthFlowType;
```
Authentication flow (default: ClientCredentials).
***
### issuerUrl
```ts
issuerUrl: string;
```
OIDC issuer URL or OAuth authority URL.
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
***
### managedIdentityClientId?
```ts
optional managedIdentityClientId: string;
```
Client ID for user-assigned managed identity (AzureManagedIdentity).
***
### refreshBufferSecs?
```ts
optional refreshBufferSecs: number;
```
Seconds before expiry to trigger proactive refresh (default: 300).
Keep this well below the token TTL; if it is greater than or equal to
the TTL, each request refreshes the token.
***
### scopes
```ts
scopes: string[];
```
OAuth scopes to request.
For Azure managed identity, exactly one scope or resource is required.
For example: `["api://{app_id}/.default"]`

View File

@@ -52,7 +52,6 @@ export {
SplitHashOptions,
SplitSequentialOptions,
ShuffleOptions,
OAuthConfig as NativeOAuthConfig,
} from "./native.js";
export {
@@ -131,8 +130,6 @@ export {
TokenResponse,
} from "./header";
export { OAuthConfig, OAuthFlowType } from "./oauth";
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
export * as embedding from "./embedding";

View File

@@ -1,76 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
/**
* OAuth authentication flow types.
*/
export enum OAuthFlowType {
/** Client Credentials grant (service-to-service / M2M). */
ClientCredentials = "client_credentials",
/** Azure Managed Identity via IMDS. */
AzureManagedIdentity = "azure_managed_identity",
}
/**
* OAuth configuration for LanceDB authentication.
*
* This is the public TypeScript OAuth configuration type. The generated
* `NativeOAuthConfig` type has the same runtime shape but is an implementation
* detail of the napi-rs binding.
*
* All token acquisition and refresh is handled in the Rust layer.
* This config is passed through to Rust via napi-rs.
*
* @example Client Credentials (service-to-service):
* ```typescript
* const config: OAuthConfig = {
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
* clientId: "app-id",
* clientSecret: "secret",
* scopes: ["api://lancedb-api/.default"],
* };
* ```
*
* @example Azure Managed Identity:
* ```typescript
* const config: OAuthConfig = {
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
* clientId: "app-id",
* scopes: ["api://lancedb-api/.default"],
* flow: OAuthFlowType.AzureManagedIdentity,
* };
* ```
*/
export interface OAuthConfig {
/**
* OIDC issuer URL or OAuth authority URL.
* For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
*/
issuerUrl: string;
/** Application / Client ID. */
clientId: string;
/**
* OAuth scopes to request.
* For Azure managed identity, exactly one scope or resource is required.
* For example: `["api://{app_id}/.default"]`
*/
scopes: string[];
/** Authentication flow (default: ClientCredentials). */
flow?: OAuthFlowType;
/** Client secret (required for ClientCredentials). */
clientSecret?: string;
/** Client ID for user-assigned managed identity (AzureManagedIdentity). */
managedIdentityClientId?: string;
/**
* Seconds before expiry to trigger proactive refresh (default: 300).
* Keep this well below the token TTL; if it is greater than or equal to
* the TTL, each request refreshes the token.
*/
refreshBufferSecs?: number;
}

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.31.0-beta.3",
"version": "0.31.0-beta.2",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.31.0-beta.3",
"version": "0.31.0-beta.2",
"cpu": [
"x64",
"arm64"

View File

@@ -112,12 +112,6 @@ impl Connection {
builder = builder.client_config(rust_config);
if let Some(oauth_config) = options.oauth_config {
let config: lancedb::remote::oauth::OAuthConfig =
oauth_config.try_into().default_error()?;
builder = builder.oauth_config(config);
}
if let Some(api_key) = options.api_key {
builder = builder.api_key(&api_key);
}

View File

@@ -65,11 +65,6 @@ pub struct ConnectionOptions {
/// (For LanceDB cloud only): the host to use for LanceDB cloud. Used
/// for testing purposes.
pub host_override: Option<String>,
/// (For LanceDB cloud only): OAuth configuration for IdP-based
/// authentication (e.g., Azure Entra ID). When set, token acquisition
/// and refresh are handled entirely in Rust. TypeScript users should pass
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
pub oauth_config: Option<remote::OAuthConfig>,
}
#[napi(object)]

View File

@@ -3,7 +3,6 @@
use std::collections::HashMap;
use lancedb::error::Error;
use napi_derive::*;
/// Timeout configuration for remote HTTP client.
@@ -141,84 +140,6 @@ impl From<TlsConfig> for lancedb::remote::TlsConfig {
}
}
/// OAuth configuration for LanceDB authentication.
///
/// This is the generated napi-rs binding shape. TypeScript users should prefer
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
///
/// All token acquisition and refresh is handled in the Rust layer.
#[napi(object)]
#[derive(Clone)]
pub struct OAuthConfig {
/// OIDC issuer URL or OAuth authority URL.
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
pub issuer_url: String,
/// Application / Client ID.
pub client_id: String,
/// OAuth scopes to request. For Azure managed identity, exactly one scope
/// or resource is required. For example: `["api://{app_id}/.default"]`
pub scopes: Vec<String>,
/// Authentication flow: "client_credentials" or "azure_managed_identity"
pub flow: Option<String>,
/// Client secret (required for client_credentials).
pub client_secret: Option<String>,
/// Client ID for user-assigned managed identity (azure_managed_identity).
pub managed_identity_client_id: Option<String>,
/// Seconds before expiry to trigger proactive refresh (default: 300).
/// Keep this well below the token TTL; if it is greater than or equal to
/// the TTL, each request refreshes the token.
pub refresh_buffer_secs: Option<u32>,
}
impl std::fmt::Debug for OAuthConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OAuthConfig")
.field("issuer_url", &self.issuer_url)
.field("client_id", &self.client_id)
.field("scopes", &self.scopes)
.field("flow", &self.flow)
.field(
"client_secret",
&self.client_secret.as_deref().map(|_| "<redacted>"),
)
.field(
"managed_identity_client_id",
&self.managed_identity_client_id,
)
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
.finish()
}
}
impl TryFrom<OAuthConfig> for lancedb::remote::oauth::OAuthConfig {
type Error = Error;
fn try_from(config: OAuthConfig) -> Result<Self, Self::Error> {
use lancedb::remote::oauth::OAuthFlow;
let flow = match config.flow.as_deref().unwrap_or("client_credentials") {
"client_credentials" => OAuthFlow::ClientCredentials,
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
client_id: config.managed_identity_client_id,
},
other => {
return Err(Error::InvalidInput {
message: format!("Unknown OAuth flow type: {other}"),
});
}
};
Ok(Self {
issuer_url: config.issuer_url,
client_id: config.client_id,
client_secret: config.client_secret,
scopes: config.scopes,
flow,
refresh_buffer_secs: config.refresh_buffer_secs.map(|v| v as u64),
})
}
}
impl From<ClientConfig> for lancedb::remote::ClientConfig {
fn from(config: ClientConfig) -> Self {
Self {
@@ -235,45 +156,3 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_unknown_oauth_flow_returns_invalid_input() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
scopes: vec!["scope".to_string()],
flow: Some("typo".to_string()),
client_secret: None,
managed_identity_client_id: None,
refresh_buffer_secs: None,
};
let err = lancedb::remote::oauth::OAuthConfig::try_from(config).unwrap_err();
assert!(matches!(
err,
Error::InvalidInput { message }
if message == "Unknown OAuth flow type: typo"
));
}
#[test]
fn test_oauth_config_debug_redacts_client_secret() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
scopes: vec!["scope".to_string()],
flow: Some("client_credentials".to_string()),
client_secret: Some("super-secret".to_string()),
managed_identity_client_id: None,
refresh_buffer_secs: None,
};
let debug = format!("{config:?}");
assert!(!debug.contains("super-secret"));
assert!(debug.contains("client_secret: Some(\"<redacted>\")"));
}
}

View File

@@ -373,19 +373,6 @@ def _convert_pyarrow_schema_to_json(schema: pa.Schema) -> JsonArrowSchema:
return JsonArrowSchema(fields=fields, metadata=meta)
def _builds_namespace_natively(
namespace_client_impl: Optional[str],
namespace_client_properties: Optional[Dict[str, str]],
) -> bool:
"""Whether ``connect_namespace_client`` builds the namespace client natively
in Rust (installing the read-freshness context provider) rather than wrapping
the pre-built Python client.
Must mirror Rust ``build_namespace_natively`` in ``python/src/connection.rs``.
"""
return namespace_client_impl == "rest" and bool(namespace_client_properties)
class LanceNamespaceDBConnection(DBConnection):
"""
A LanceDB connection that uses a namespace for table management.
@@ -445,13 +432,6 @@ class LanceNamespaceDBConnection(DBConnection):
)
self._namespace_client_impl = namespace_client_impl
self._namespace_client_properties = namespace_client_properties
# When the namespace client is built natively (see Rust
# ``build_namespace_natively``), the underlying Rust table performs
# QueryTable pushdown through the read-freshness context provider, which
# the pure-Python ``query_table`` path bypasses.
self._route_pushdown_to_rust = _builds_namespace_natively(
namespace_client_impl, namespace_client_properties
)
self._inner = AsyncConnection(
_connect_namespace_client(
namespace_client,
@@ -563,7 +543,6 @@ class LanceNamespaceDBConnection(DBConnection):
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
_async=async_table,
)
@@ -601,7 +580,6 @@ class LanceNamespaceDBConnection(DBConnection):
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
_async=async_table,
)
if branch is not None:
@@ -897,8 +875,6 @@ class AsyncLanceNamespaceDBConnection:
storage_options: Optional[Dict[str, str]] = None,
session: Optional[Session] = None,
namespace_client_pushdown_operations: Optional[List[str]] = None,
namespace_client_impl: Optional[str] = None,
namespace_client_properties: Optional[Dict[str, str]] = None,
):
"""
Initialize an async namespace-based LanceDB connection.
@@ -924,12 +900,6 @@ class AsyncLanceNamespaceDBConnection:
namespace.create_table() instead of using declare_table + local write.
Default is None (no pushdown, all operations run locally).
namespace_client_impl : Optional[str]
The namespace implementation name used to create this connection.
Required (with ``namespace_client_properties``) for the Rust client to
be built natively and install the read-freshness provider.
namespace_client_properties : Optional[Dict[str, str]]
The namespace properties used to create this connection.
"""
self._namespace_client = namespace_client
self.read_consistency_interval = read_consistency_interval
@@ -938,14 +908,6 @@ class AsyncLanceNamespaceDBConnection:
self._namespace_client_pushdown_operations = set(
namespace_client_pushdown_operations or []
)
self._namespace_client_impl = namespace_client_impl
self._namespace_client_properties = namespace_client_properties
# See LanceNamespaceDBConnection: when built natively the Rust table runs
# QueryTable pushdown through the read-freshness provider, so defer to it
# rather than the urllib3 client (which omits x-lancedb-min-timestamp).
self._route_pushdown_to_rust = _builds_namespace_natively(
namespace_client_impl, namespace_client_properties
)
self._inner = AsyncConnection(
_connect_namespace_client(
namespace_client,
@@ -959,8 +921,8 @@ class AsyncLanceNamespaceDBConnection:
namespace_client_pushdown_operations=(
list(self._namespace_client_pushdown_operations)
),
namespace_client_impl=namespace_client_impl,
namespace_client_properties=namespace_client_properties,
namespace_client_impl=None,
namespace_client_properties=None,
)
)
@@ -1030,7 +992,6 @@ class AsyncLanceNamespaceDBConnection:
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
)
async def open_table(
@@ -1068,7 +1029,6 @@ class AsyncLanceNamespaceDBConnection:
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
)
async def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
@@ -1427,6 +1387,4 @@ def connect_namespace_async(
storage_options=storage_options,
session=session,
namespace_client_pushdown_operations=namespace_client_pushdown_operations,
namespace_client_impl=namespace_client_impl,
namespace_client_properties=namespace_client_properties,
)

View File

@@ -2022,7 +2022,6 @@ class LanceTable(Table):
namespace_client: Optional[Any] = None,
managed_versioning: Optional[bool] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
_async: AsyncTable = None,
):
if namespace_path is None:
@@ -2032,14 +2031,6 @@ class LanceTable(Table):
self._location = location # Store location for use in _dataset_path
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
# When the connection built the namespace client natively (e.g. an
# enterprise "rest" connection), the underlying Rust table already
# executes QueryTable pushdown itself -- and, unlike this Python urllib3
# path, it routes through the read-freshness context provider that emits
# the ``x-lancedb-min-timestamp`` header. So we must defer pushdown to
# Rust instead of calling the Python ``namespace_client.query_table``
# directly, or reads silently bypass read-freshness (stale results).
self._route_pushdown_to_rust = route_pushdown_to_rust
if _async is not None:
self._table = _async
else:
@@ -2250,7 +2241,6 @@ class LanceTable(Table):
namespace_path=self._namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
location=self._location,
_async=async_table,
)
@@ -2401,11 +2391,8 @@ class LanceTable(Table):
Returns
-------
pa.Table"""
if (
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
if _should_push_down_query_table(
self._namespace_client, self._pushdown_operations
):
return self._execute_query(Query()).read_all()
@@ -3357,7 +3344,6 @@ class LanceTable(Table):
location: Optional[str] = None,
namespace_client: Optional[Any] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
):
"""
Create a new table.
@@ -3420,7 +3406,6 @@ class LanceTable(Table):
self._location = location
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
self._route_pushdown_to_rust = route_pushdown_to_rust
if data_storage_version is not None:
warnings.warn(
@@ -3534,7 +3519,6 @@ class LanceTable(Table):
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
and self.current_branch() is None
):
from lancedb.namespace import _execute_server_side_query
@@ -4276,7 +4260,6 @@ class AsyncTable:
namespace_path: Optional[List[str]] = None,
namespace_client: Optional[Any] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
):
"""Create a new AsyncTable object.
@@ -4289,9 +4272,6 @@ class AsyncTable:
self._namespace_path = namespace_path or []
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
# See LanceTable.__init__: defer QueryTable pushdown to Rust (which emits
# the read-freshness header) for natively-built namespace clients.
self._route_pushdown_to_rust = route_pushdown_to_rust
def _set_namespace_context(
self,
@@ -4299,12 +4279,10 @@ class AsyncTable:
namespace_path: Optional[List[str]] = None,
namespace_client: Optional[Any] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
) -> "AsyncTable":
self._namespace_path = namespace_path or []
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
self._route_pushdown_to_rust = route_pushdown_to_rust
return self
def __repr__(self):
@@ -4514,11 +4492,8 @@ class AsyncTable:
-------
pa.Table
"""
if (
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
if _should_push_down_query_table(
self._namespace_client, self._pushdown_operations
):
return (await self._execute_query(Query())).read_all()
@@ -5202,11 +5177,8 @@ class AsyncTable:
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
if (
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
if _should_push_down_query_table(
self._namespace_client, self._pushdown_operations
):
from lancedb.namespace import _execute_server_side_query

View File

@@ -65,9 +65,6 @@ def _namespace_lance_table(namespace_client: _NamespaceClient) -> LanceTable:
table._namespace_path = ["geneva"]
table._namespace_client = namespace_client
table._pushdown_operations = {"QueryTable"}
# This test exercises the Python-side pushdown path (non-native client), so
# pushdown is not routed to Rust.
table._route_pushdown_to_rust = False
return table
@@ -808,37 +805,6 @@ class TestPushdownOperations:
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
assert len(db._namespace_client_pushdown_operations) == 0
def test_route_pushdown_to_rust_for_native_rest(self):
"""A natively-built rest connection must defer QueryTable pushdown to
Rust so reads carry the x-lancedb-min-timestamp read-freshness header."""
db = lancedb.connect_namespace(
"rest",
{"uri": "http://localhost:12345"},
namespace_client_pushdown_operations=["QueryTable"],
)
assert db._route_pushdown_to_rust is True
def test_route_pushdown_to_rust_false_for_dir(self):
"""A non-native (dir) connection keeps the Python pushdown path."""
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
assert db._route_pushdown_to_rust is False
def test_async_route_pushdown_to_rust_for_native_rest(self):
"""The async connection must not silently bypass the read-freshness fix:
a natively-built rest connection defers pushdown to Rust (regression test
for the async path omitting the freshness header)."""
db = lancedb.connect_namespace_async(
"rest",
{"uri": "http://localhost:12345"},
namespace_client_pushdown_operations=["QueryTable"],
)
assert db._route_pushdown_to_rust is True
def test_async_route_pushdown_to_rust_false_for_dir(self):
"""The async non-native (dir) connection keeps the Python pushdown path."""
db = lancedb.connect_namespace_async("dir", {"root": self.temp_dir})
assert db._route_pushdown_to_rust is False
def test_lance_table_to_arrow_uses_query_pushdown(self):
namespace_client = _NamespaceClient()
table = _namespace_lance_table(namespace_client)

View File

@@ -610,38 +610,24 @@ pub fn connect_namespace_client(
namespace_client_impl: Option<String>,
namespace_client_properties: Option<HashMap<String, String>>,
) -> PyResult<Connection> {
let namespace_client = extract_namespace_arc(py, namespace_client)?;
let read_consistency_interval = read_consistency_interval.map(Duration::from_secs_f64);
let namespace_client_pushdown_operations =
parse_namespace_client_pushdown_operations(namespace_client_pushdown_operations)?;
let ns_impl = namespace_client_impl.unwrap_or_else(|| "python".to_string());
let ns_properties = namespace_client_properties.unwrap_or_default();
let storage_options = storage_options.unwrap_or_default();
let session = session.map(|s| s.inner.clone());
// Prefer building the namespace natively from (impl, properties) so the
// read-freshness provider installed
let database = if build_namespace_natively(namespace_client_impl.as_deref(), &ns_properties) {
let ns_impl = namespace_client_impl.expect("impl present per build_namespace_natively");
crate::runtime::block_on(LanceNamespaceDatabase::connect(
&ns_impl,
ns_properties,
storage_options,
read_consistency_interval,
session,
namespace_client_pushdown_operations,
))
.infer_error()?
} else {
let namespace_client = extract_namespace_arc(py, namespace_client)?;
LanceNamespaceDatabase::from_namespace_client(
namespace_client,
namespace_client_impl.unwrap_or_else(|| "python".to_string()),
ns_properties,
storage_options,
read_consistency_interval,
session,
namespace_client_pushdown_operations,
)
};
let database = LanceNamespaceDatabase::from_namespace_client(
namespace_client,
ns_impl,
ns_properties,
storage_options,
read_consistency_interval,
session,
namespace_client_pushdown_operations,
);
Ok(Connection::new(LanceConnection::new(
Arc::new(database),
@@ -649,16 +635,6 @@ pub fn connect_namespace_client(
)))
}
/// Whether to build the namespace natively (from impl + properties) instead of
/// wrapping a pre-built client. Native construction is required for the
/// read-freshness provider to be installed
fn build_namespace_natively(
namespace_client_impl: Option<&str>,
namespace_client_properties: &HashMap<String, String>,
) -> bool {
matches!(namespace_client_impl, Some("rest")) && !namespace_client_properties.is_empty()
}
#[derive(FromPyObject)]
pub struct PyClientConfig {
user_agent: String,
@@ -757,36 +733,3 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn props(pairs: &[(&str, &str)]) -> HashMap<String, String> {
pairs
.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect()
}
#[test]
fn native_build_only_for_rest_with_properties() {
let rest = props(&[("uri", "http://localhost:10024")]);
// rest + non-empty properties -> build natively (installs the
// read-freshness provider so checkout_latest() busts the server cache).
assert!(build_namespace_natively(Some("rest"), &rest));
// dir is local (no server cache) -> wrap the pre-built client unchanged.
assert!(!build_namespace_natively(
Some("dir"),
&props(&[("root", "/tmp")])
));
// No impl: only a pre-built client was handed in -> wrap it as-is.
assert!(!build_namespace_natively(None, &rest));
// rest but no properties: nothing to build a connection from -> wrap.
assert!(!build_namespace_natively(Some("rest"), &HashMap::new()));
}
}

View File

@@ -56,15 +56,6 @@ fn get_runtime() -> &'static runtime::Runtime {
unsafe { &*new_ptr }
}
/// Block the current thread on a future using the shared runtime.
///
/// For sync `#[pyfunction]`s that need to drive an async operation (e.g.
/// building a namespace client). Must not be called from within the runtime's
/// own worker threads.
pub fn block_on<F: std::future::Future>(fut: F) -> F::Output {
get_runtime().block_on(fut)
}
/// Runs in async-signal context after `fork()` in the child. We can only
/// touch atomics here; we deliberately leak the previous runtime because
/// dropping a tokio `Runtime` would try to join its (now-dead) worker

View File

@@ -50,7 +50,7 @@ lance-namespace = { workspace = true }
lance-namespace-impls = { workspace = true }
moka = { workspace = true }
pin-project = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread", "sync"] }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
log.workspace = true
async-trait = "0"
bytes = "1"
@@ -75,7 +75,6 @@ reqwest = { version = "0.12.0", default-features = false, features = [
"stream",
], optional = true }
http = { version = "1", optional = true } # Matching what is in reqwest
urlencoding = { version = "2", optional = true }
uuid = { version = "1.7.0", features = ["v4", "v5"] }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
@@ -94,7 +93,6 @@ semver = { workspace = true }
anyhow = "1"
tempfile = "3.5.0"
random_word = { version = "0.4.3", features = ["en"] }
tokio = { version = "1.23", features = ["io-util", "macros", "net", "rt-multi-thread", "sync"] }
uuid = { version = "1.7.0", features = ["v4"] }
walkdir = "2"
aws-sdk-dynamodb = { version = "1.55.0" }
@@ -131,13 +129,7 @@ huggingface = [
"lance-namespace-impls/dir-huggingface",
]
dynamodb = ["lance/dynamodb", "aws"]
remote = [
"dep:reqwest",
"dep:http",
"dep:urlencoding",
"lance-namespace-impls/rest",
"lance-namespace-impls/rest-adapter",
]
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
bedrock = ["dep:aws-sdk-bedrockruntime"]

View File

@@ -667,8 +667,6 @@ pub struct ConnectRequest {
pub struct ConnectBuilder {
request: ConnectRequest,
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
#[cfg(feature = "remote")]
oauth_config: Option<crate::remote::OAuthConfig>,
}
#[cfg(feature = "remote")]
@@ -690,8 +688,6 @@ impl ConnectBuilder {
session: None,
},
embedding_registry: None,
#[cfg(feature = "remote")]
oauth_config: None,
}
}
@@ -780,19 +776,6 @@ impl ConnectBuilder {
self
}
/// Configure OAuth authentication for LanceDB Cloud/Enterprise.
///
/// This creates an [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider)
/// from the given config and sets it as the header provider. OAuth cannot
/// be combined with an API key or another header provider.
///
/// Token acquisition and refresh are handled in Rust.
#[cfg(feature = "remote")]
pub fn oauth_config(mut self, config: crate::remote::OAuthConfig) -> Self {
self.oauth_config = Some(config);
self
}
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
self.embedding_registry = Some(registry);
@@ -938,40 +921,9 @@ impl ConnectBuilder {
let region = options.region.ok_or_else(|| Error::InvalidInput {
message: "A region is required when connecting to LanceDb Cloud".to_string(),
})?;
let api_key = match (&self.oauth_config, &options.api_key) {
(Some(_), None) => String::new(),
(Some(_), Some(_)) => {
return Err(Error::InvalidInput {
message:
"api_key and oauth_config cannot both be set when connecting to LanceDb Cloud"
.to_string(),
});
}
(None, Some(key)) => key.clone(),
(None, None) => {
return Err(Error::InvalidInput {
message:
"An api_key or oauth_config is required when connecting to LanceDb Cloud"
.to_string(),
});
}
};
if self.oauth_config.is_some() && self.request.client_config.header_provider.is_some() {
return Err(Error::InvalidInput {
message:
"oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud"
.to_string(),
});
}
let mut client_config = self.request.client_config;
if let Some(oauth_config) = self.oauth_config {
let provider = crate::remote::OAuthHeaderProvider::new(oauth_config)?;
client_config.header_provider =
Some(Arc::new(provider) as Arc<dyn crate::remote::HeaderProvider>);
}
let api_key = options.api_key.ok_or_else(|| Error::InvalidInput {
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
})?;
let storage_options = StorageOptions(options.storage_options.clone());
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
@@ -979,7 +931,7 @@ impl ConnectBuilder {
&api_key,
&region,
options.host_override,
client_config,
self.request.client_config,
storage_options.into(),
self.request.read_consistency_interval,
)?);
@@ -1288,83 +1240,6 @@ mod tests {
assert_eq!(Some(&"EXPLICIT-VALUE".to_string()), options.get(opts_key));
}
#[cfg(feature = "remote")]
#[tokio::test]
async fn test_connect_rejects_api_key_with_oauth_config() {
let oauth_config = crate::remote::OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: crate::remote::OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let result = ConnectBuilder::new("db://my-container/my-prefix")
.region("us-east-1")
.api_key("my-api-key")
.oauth_config(oauth_config)
.execute()
.await;
match result {
Err(Error::InvalidInput { message })
if message
== "api_key and oauth_config cannot both be set when connecting to LanceDb Cloud" =>
{}
Err(err) => panic!("expected InvalidInput, got {err:?}"),
Ok(_) => panic!("expected api_key and oauth_config to be rejected"),
}
}
#[cfg(feature = "remote")]
#[tokio::test]
async fn test_connect_rejects_header_provider_with_oauth_config() {
#[derive(Debug)]
struct TestHeaderProvider;
#[async_trait::async_trait]
impl crate::remote::HeaderProvider for TestHeaderProvider {
async fn get_headers(&self) -> Result<HashMap<String, String>> {
Ok(HashMap::from([(
"authorization".to_string(),
"Bearer token".to_string(),
)]))
}
}
let oauth_config = crate::remote::OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: crate::remote::OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let client_config = crate::remote::ClientConfig {
header_provider: Some(
Arc::new(TestHeaderProvider) as Arc<dyn crate::remote::HeaderProvider>
),
..Default::default()
};
let result = ConnectBuilder::new("db://my-container/my-prefix")
.region("us-east-1")
.client_config(client_config)
.oauth_config(oauth_config)
.execute()
.await;
match result {
Err(Error::InvalidInput { message })
if message
== "oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud" =>
{}
Err(err) => panic!("expected InvalidInput, got {err:?}"),
Ok(_) => panic!("expected header_provider and oauth_config to be rejected"),
}
}
#[cfg(not(windows))]
#[tokio::test]
async fn test_connect_relative() {

View File

@@ -8,7 +8,6 @@
pub(crate) mod client;
pub(crate) mod db;
pub mod oauth;
mod retry;
pub(crate) mod table;
pub(crate) mod util;
@@ -21,4 +20,3 @@ const JSON_CONTENT_TYPE: &str = "application/json";
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
pub use oauth::{OAuthConfig, OAuthFlow, OAuthHeaderProvider};

View File

@@ -1,907 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use log::debug;
use reqwest::Client;
use serde::Deserialize;
use tokio::sync::RwLock;
use crate::error::{Error, Result};
use crate::remote::client::HeaderProvider;
const DEFAULT_REFRESH_BUFFER_SECS: u64 = 300;
const DEFAULT_TOKEN_TTL_SECS: u64 = 3600;
const AZURE_IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token";
const AZURE_IMDS_API_VERSION: &str = "2018-02-01";
/// OAuth authentication flow configuration.
#[derive(Debug, Clone)]
pub enum OAuthFlow {
/// Client Credentials grant (service-to-service / M2M).
/// Requires `client_secret` in [`OAuthConfig`].
ClientCredentials,
/// Azure Managed Identity via IMDS.
/// Works on Azure VMs, AKS, App Service, and Azure Functions.
/// IMDS requests bypass proxy settings because the endpoint is link-local.
AzureManagedIdentity {
/// Client ID for user-assigned managed identity.
/// Omit for system-assigned managed identity.
client_id: Option<String>,
},
}
/// OAuth configuration for LanceDB authentication.
///
/// All token acquisition and refresh is handled in the Rust layer.
/// Python and TypeScript bindings expose this as a plain config object.
#[derive(Clone)]
pub struct OAuthConfig {
/// OIDC issuer URL or OAuth authority URL.
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
pub issuer_url: String,
/// Application / Client ID.
pub client_id: String,
/// Client secret (required for `ClientCredentials`, optional for others).
pub client_secret: Option<String>,
/// OAuth scopes to request.
/// For Azure managed identity, exactly one scope or resource is required.
/// For example: `["api://{app_id}/.default"]`
pub scopes: Vec<String>,
/// Authentication flow to use.
pub flow: OAuthFlow,
/// Seconds before token expiry to trigger proactive refresh (default: 300).
/// Keep this well below the token TTL; if it is greater than or equal to
/// the TTL, each request refreshes the token.
pub refresh_buffer_secs: Option<u64>,
}
impl std::fmt::Debug for OAuthConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OAuthConfig")
.field("issuer_url", &self.issuer_url)
.field("client_id", &self.client_id)
.field(
"client_secret",
&self.client_secret.as_deref().map(|_| "<redacted>"),
)
.field("scopes", &self.scopes)
.field("flow", &self.flow)
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
.finish()
}
}
// -- OIDC Discovery --
#[derive(Clone, Debug, Deserialize)]
struct OidcDiscovery {
token_endpoint: String,
}
// -- Token Response --
#[derive(Deserialize)]
struct TokenResponse {
access_token: String,
/// Token lifetime in seconds.
/// Some providers (Azure IMDS) return this as a string, so we accept both.
#[serde(default, deserialize_with = "deserialize_optional_u64_or_string")]
expires_in: Option<u64>,
#[serde(default)]
#[allow(dead_code)]
token_type: Option<String>,
}
impl std::fmt::Debug for TokenResponse {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TokenResponse")
.field("access_token", &"<redacted>")
.field("expires_in", &self.expires_in)
.field("token_type", &self.token_type)
.finish()
}
}
fn deserialize_optional_u64_or_string<'de, D>(
deserializer: D,
) -> std::result::Result<Option<u64>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de;
struct U64OrString;
impl<'de> de::Visitor<'de> for U64OrString {
type Value = Option<u64>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("an integer, an integer-valued float, a numeric string, or null")
}
fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
Ok(Some(v))
}
fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
if v < 0 {
return Err(E::custom(format!("invalid expires_in value: {v}")));
}
Ok(Some(v as u64))
}
fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
if !v.is_finite() || v < 0.0 || v.fract() != 0.0 || v > u64::MAX as f64 {
return Err(E::custom(format!("invalid expires_in value: {v}")));
}
Ok(Some(v as u64))
}
fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Self::Value, E> {
v.parse::<u64>().map(Some).map_err(de::Error::custom)
}
fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
Ok(None)
}
fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
Ok(None)
}
}
deserializer.deserialize_any(U64OrString)
}
// -- Internal Token State --
struct TokenState {
access_token: Option<String>,
expires_at: Option<Instant>,
}
impl TokenState {
fn new() -> Self {
Self {
access_token: None,
expires_at: None,
}
}
fn is_expired(&self, buffer: Duration) -> bool {
match (self.access_token.as_ref(), self.expires_at) {
(Some(_), Some(expires_at)) => Instant::now() + buffer >= expires_at,
(None, _) => true,
(Some(_), None) => true,
}
}
fn update(&mut self, resp: &TokenResponse) {
self.access_token = Some(resp.access_token.clone());
let expires_in = resp.expires_in.unwrap_or(DEFAULT_TOKEN_TTL_SECS);
self.expires_at = Some(Instant::now() + Duration::from_secs(expires_in));
}
}
#[async_trait]
trait TokenSource: Send + Sync + std::fmt::Debug {
async fn fetch_token(&self) -> Result<TokenResponse>;
}
struct ClientCredentialsSource {
issuer_url: String,
client_id: String,
client_secret: String,
scopes: Vec<String>,
http_client: Client,
discovery: RwLock<Option<OidcDiscovery>>,
}
impl std::fmt::Debug for ClientCredentialsSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ClientCredentialsSource")
.field("issuer_url", &self.issuer_url)
.field("client_id", &self.client_id)
.field("client_secret", &"<redacted>")
.field("scopes", &self.scopes)
.finish()
}
}
impl ClientCredentialsSource {
fn new(
issuer_url: String,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
) -> Result<Self> {
let client_secret = client_secret.ok_or(Error::InvalidInput {
message: "client_secret is required for ClientCredentials flow".to_string(),
})?;
Self::validate_issuer_transport(&issuer_url)?;
let http_client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| Error::Runtime {
message: format!("Failed to create HTTP client for OAuth: {e}"),
})?;
Ok(Self {
issuer_url,
client_id,
client_secret,
scopes,
http_client,
discovery: RwLock::new(None),
})
}
fn validate_issuer_transport(issuer_url: &str) -> Result<()> {
let issuer = url::Url::parse(issuer_url).map_err(|e| Error::InvalidInput {
message: format!("Invalid OAuth issuer_url: {e}"),
})?;
match issuer.scheme() {
"https" => Ok(()),
"http" if Self::is_loopback_issuer(&issuer) => Ok(()),
_ => Err(Error::InvalidInput {
message:
"ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
.to_string(),
}),
}
}
fn is_loopback_issuer(issuer: &url::Url) -> bool {
let Some(host) = issuer.host_str() else {
return false;
};
host.eq_ignore_ascii_case("localhost")
|| host
.parse::<IpAddr>()
.map(|addr| addr.is_loopback())
.unwrap_or(false)
}
async fn get_discovery(&self) -> Result<OidcDiscovery> {
{
let cached = self.discovery.read().await;
if let Some(ref disc) = *cached {
return Ok(disc.clone());
}
}
let mut cache = self.discovery.write().await;
// Double-check
if let Some(ref disc) = *cache {
return Ok(disc.clone());
}
let discovery_url = format!(
"{}/.well-known/openid-configuration",
self.issuer_url.trim_end_matches('/')
);
debug!("Fetching OIDC discovery from {}", discovery_url);
let resp = self
.http_client
.get(&discovery_url)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to fetch OIDC discovery document: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"OIDC discovery failed with status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
let disc: OidcDiscovery = resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse OIDC discovery document: {e}"),
})?;
let result = disc.clone();
*cache = Some(disc);
Ok(result)
}
async fn get_token_endpoint(&self) -> Result<String> {
self.get_discovery().await.map(|disc| disc.token_endpoint)
}
fn scopes_string(&self) -> String {
self.scopes.join(" ")
}
async fn post_token_request(
&self,
endpoint: &str,
params: &[(&str, &str)],
) -> Result<TokenResponse> {
let resp = self
.http_client
.post(endpoint)
.form(params)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Token request to {endpoint} failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Token request failed with status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse token response: {e}"),
})
}
}
#[async_trait]
impl TokenSource for ClientCredentialsSource {
async fn fetch_token(&self) -> Result<TokenResponse> {
let token_endpoint = self.get_token_endpoint().await?;
let scope = self.scopes_string();
let params = [
("grant_type", "client_credentials"),
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
("scope", scope.as_str()),
];
self.post_token_request(&token_endpoint, &params).await
}
}
struct AzureImdsSource {
client_id: Option<String>,
resource: String,
http_client: Client,
}
impl std::fmt::Debug for AzureImdsSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AzureImdsSource")
.field("client_id", &self.client_id)
.field("resource", &self.resource)
.finish()
}
}
impl AzureImdsSource {
fn new(scopes: Vec<String>, client_id: Option<String>) -> Result<Self> {
let resource = Self::resource_from_scopes(&scopes)?;
let http_client = Client::builder()
.timeout(Duration::from_secs(30))
.no_proxy()
.build()
.map_err(|e| Error::Runtime {
message: format!("Failed to create HTTP client for Azure IMDS OAuth: {e}"),
})?;
Ok(Self {
client_id,
resource,
http_client,
})
}
fn resource_from_scopes(scopes: &[String]) -> Result<String> {
let [scope] = scopes else {
return Err(Error::InvalidInput {
message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource"
.to_string(),
});
};
Ok(scope.strip_suffix("/.default").unwrap_or(scope).to_string())
}
}
#[async_trait]
impl TokenSource for AzureImdsSource {
async fn fetch_token(&self) -> Result<TokenResponse> {
let mut url = format!(
"{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}",
urlencoding::encode(&self.resource),
);
if let Some(cid) = self.client_id.as_deref() {
url.push_str(&format!("&client_id={}", urlencoding::encode(cid)));
}
let resp = self
.http_client
.get(&url)
.header("Metadata", "true")
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Azure IMDS request failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Azure IMDS returned status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse IMDS token response: {e}"),
})
}
}
/// OAuth header provider that manages the full token lifecycle.
///
/// Implements [`HeaderProvider`] to inject `Authorization: Bearer <token>`
/// headers into every LanceDB request, with automatic token refresh.
pub struct OAuthHeaderProvider {
token_source: Box<dyn TokenSource>,
token_state: Arc<RwLock<TokenState>>,
refresh_buffer: Duration,
}
impl std::fmt::Debug for OAuthHeaderProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OAuthHeaderProvider")
.field("token_source", &self.token_source)
.finish()
}
}
impl OAuthHeaderProvider {
/// Create a new OAuth header provider from configuration.
pub fn new(config: OAuthConfig) -> Result<Self> {
let OAuthConfig {
issuer_url,
client_id,
client_secret,
scopes,
flow,
refresh_buffer_secs,
} = config;
if scopes.is_empty() {
return Err(Error::InvalidInput {
message: "At least one OAuth scope is required".to_string(),
});
}
let refresh_buffer =
Duration::from_secs(refresh_buffer_secs.unwrap_or(DEFAULT_REFRESH_BUFFER_SECS));
let token_source: Box<dyn TokenSource> = match flow {
OAuthFlow::ClientCredentials => Box::new(ClientCredentialsSource::new(
issuer_url,
client_id,
client_secret,
scopes,
)?),
OAuthFlow::AzureManagedIdentity { client_id } => {
Box::new(AzureImdsSource::new(scopes, client_id)?)
}
};
Ok(Self {
token_source,
token_state: Arc::new(RwLock::new(TokenState::new())),
refresh_buffer,
})
}
/// Get a valid access token, refreshing if necessary.
async fn get_valid_token(&self) -> Result<String> {
// Fast path: check if current token is still valid
{
let state = self.token_state.read().await;
if !state.is_expired(self.refresh_buffer)
&& let Some(ref token) = state.access_token
{
return Ok(token.clone());
}
}
// Slow path: acquire or refresh token
let mut state = self.token_state.write().await;
// Double-check after acquiring write lock
if !state.is_expired(self.refresh_buffer)
&& let Some(ref token) = state.access_token
{
return Ok(token.clone());
}
debug!("Acquiring new OAuth token via {:?}", self.token_source);
let resp = self.token_source.fetch_token().await?;
state.update(&resp);
Ok(resp.access_token)
}
}
#[async_trait]
impl HeaderProvider for OAuthHeaderProvider {
async fn get_headers(&self) -> Result<HashMap<String, String>> {
let token = self.get_valid_token().await?;
Ok(HashMap::from([(
"authorization".to_string(),
format!("Bearer {token}"),
)]))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::task::JoinHandle;
#[test]
fn test_token_state_expiry() {
let mut state = TokenState::new();
assert!(state.is_expired(Duration::from_secs(0)));
state.access_token = Some("tok".to_string());
state.expires_at = Some(Instant::now() + Duration::from_secs(600));
assert!(!state.is_expired(Duration::from_secs(300)));
assert!(state.is_expired(Duration::from_secs(601)));
state.expires_at = None;
assert!(state.is_expired(Duration::from_secs(0)));
}
#[test]
fn test_token_state_uses_default_expiry() {
let mut state = TokenState::new();
let response = TokenResponse {
access_token: "tok".to_string(),
expires_in: None,
token_type: None,
};
state.update(&response);
assert!(!state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS - 1)));
assert!(state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS + 1)));
}
#[test]
fn test_token_response_accepts_float_expires_in() {
let response: TokenResponse =
serde_json::from_str(r#"{"access_token":"tok","expires_in":3600.0}"#).unwrap();
assert_eq!(response.expires_in, Some(3600));
}
#[test]
fn test_token_response_rejects_negative_expires_in() {
let err =
serde_json::from_str::<TokenResponse>(r#"{"access_token":"tok","expires_in":-1}"#)
.unwrap_err();
assert!(err.to_string().contains("invalid expires_in value: -1"));
}
#[test]
fn test_token_response_debug_redacts_access_token() {
let response = TokenResponse {
access_token: "secret-token".to_string(),
expires_in: Some(3600),
token_type: Some("Bearer".to_string()),
};
let debug = format!("{response:?}");
assert!(!debug.contains("secret-token"));
assert!(debug.contains("access_token: \"<redacted>\""));
}
#[test]
fn test_scopes_string() {
let source = ClientCredentialsSource::new(
"https://login.microsoftonline.com/tenant/v2.0".to_string(),
"app-id".to_string(),
Some("secret".to_string()),
vec!["scope1".to_string(), "scope2".to_string()],
)
.unwrap();
assert_eq!(source.scopes_string(), "scope1 scope2");
}
#[test]
fn test_oauth_config_debug_redacts_client_secret() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("super-secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let debug = format!("{config:?}");
assert!(!debug.contains("super-secret"));
assert!(debug.contains("client_secret: Some(\"<redacted>\")"));
}
#[test]
fn test_oauth_header_provider_debug_redacts_client_secret() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("super-secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let provider = OAuthHeaderProvider::new(config).unwrap();
let debug = format!("{provider:?}");
assert!(!debug.contains("super-secret"));
assert!(debug.contains("client_secret: \"<redacted>\""));
}
#[test]
fn test_managed_identity_resource_from_default_scope() {
assert_eq!(
AzureImdsSource::resource_from_scopes(&["api://test/.default".to_string()]).unwrap(),
"api://test"
);
}
#[test]
fn test_managed_identity_resource_without_default_suffix() {
assert_eq!(
AzureImdsSource::resource_from_scopes(&["api://test".to_string()]).unwrap(),
"api://test"
);
}
#[test]
fn test_managed_identity_rejects_multiple_scopes() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: None,
scopes: vec![
"api://test-a/.default".to_string(),
"api://test-b/.default".to_string(),
],
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
refresh_buffer_secs: None,
};
assert!(OAuthHeaderProvider::new(config).is_err());
}
#[tokio::test]
async fn test_token_endpoint_requires_discovery_success() {
let (issuer_url, server) = spawn_discovery_error_server().await;
let source = ClientCredentialsSource::new(
issuer_url,
"client-id".to_string(),
Some("secret".to_string()),
vec!["scope".to_string()],
)
.unwrap();
let err = source.get_token_endpoint().await.unwrap_err();
assert!(matches!(
err,
Error::Runtime { message }
if message.contains("OIDC discovery failed with status 503")
));
server.await.unwrap();
}
#[test]
fn test_client_credentials_requires_secret() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: None,
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
assert!(OAuthHeaderProvider::new(config).is_err());
}
#[test]
fn test_client_credentials_rejects_insecure_non_loopback_issuer() {
let config = OAuthConfig {
issuer_url: "http://issuer.example.com".to_string(),
client_id: "app-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let err = OAuthHeaderProvider::new(config).unwrap_err();
assert!(matches!(
err,
Error::InvalidInput { message }
if message == "ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
));
}
#[test]
fn test_empty_scopes_rejected() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: None,
scopes: vec![],
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
refresh_buffer_secs: None,
};
assert!(OAuthHeaderProvider::new(config).is_err());
}
#[tokio::test]
async fn test_client_credentials_token_lifecycle() {
let (issuer_url, token_requests, server) = spawn_oauth_server().await;
let config = OAuthConfig {
issuer_url,
client_id: "client-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: Some(0),
};
let provider = OAuthHeaderProvider::new(config).unwrap();
let headers = provider.get_headers().await.unwrap();
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
let headers = provider.get_headers().await.unwrap();
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
provider.token_state.write().await.expires_at =
Some(Instant::now() - Duration::from_secs(1));
let headers = provider.get_headers().await.unwrap();
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-2");
assert_eq!(token_requests.load(Ordering::SeqCst), 2);
server.await.unwrap();
}
async fn spawn_oauth_server() -> (String, Arc<AtomicUsize>, JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let issuer_url = format!("http://{addr}");
let token_requests = Arc::new(AtomicUsize::new(0));
let server_token_requests = Arc::clone(&token_requests);
let server = tokio::spawn(async move {
for _ in 0..3 {
let (mut stream, _) = listener.accept().await.unwrap();
let (request_line, body) = read_http_request(&mut stream).await;
if request_line.starts_with("GET /.well-known/openid-configuration ") {
let discovery = format!(r#"{{"token_endpoint":"http://{addr}/token"}}"#);
write_json_response(&mut stream, "200 OK", &discovery).await;
} else if request_line.starts_with("POST /token ") {
assert!(body.contains("grant_type=client_credentials"));
assert!(body.contains("client_id=client-id"));
assert!(body.contains("client_secret=secret"));
assert!(body.contains("scope=scope"));
let token_num = server_token_requests.fetch_add(1, Ordering::SeqCst) + 1;
let token = format!(
r#"{{"access_token":"token-{token_num}","expires_in":3600,"token_type":"Bearer"}}"#
);
write_json_response(&mut stream, "200 OK", &token).await;
} else {
write_json_response(&mut stream, "404 Not Found", "{}").await;
}
}
});
(issuer_url, token_requests, server)
}
async fn spawn_discovery_error_server() -> (String, JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let issuer_url = format!("http://{addr}");
let server = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let (request_line, _) = read_http_request(&mut stream).await;
assert!(request_line.starts_with("GET /.well-known/openid-configuration "));
write_json_response(&mut stream, "503 Service Unavailable", "{}").await;
});
(issuer_url, server)
}
async fn read_http_request(stream: &mut TcpStream) -> (String, String) {
let mut buffer = Vec::new();
let mut header_end = None;
while header_end.is_none() {
let mut chunk = [0; 1024];
let read = stream.read(&mut chunk).await.unwrap();
assert_ne!(read, 0, "connection closed before request headers");
buffer.extend_from_slice(&chunk[..read]);
header_end = find_subsequence(&buffer, b"\r\n\r\n").map(|pos| pos + 4);
}
let header_end = header_end.unwrap();
let headers = String::from_utf8_lossy(&buffer[..header_end]).to_string();
let request_line = headers.lines().next().unwrap_or_default().to_string();
let content_length = headers
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap_or(0);
while buffer.len() < header_end + content_length {
let mut chunk = [0; 1024];
let read = stream.read(&mut chunk).await.unwrap();
assert_ne!(read, 0, "connection closed before request body");
buffer.extend_from_slice(&chunk[..read]);
}
let body =
String::from_utf8_lossy(&buffer[header_end..header_end + content_length]).to_string();
(request_line, body)
}
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
haystack
.windows(needle.len())
.position(|window| window == needle)
}
async fn write_json_response(stream: &mut TcpStream, status: &str, body: &str) {
let response = format!(
"HTTP/1.1 {status}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
body.len()
);
stream.write_all(response.as_bytes()).await.unwrap();
}
}

View File

@@ -579,45 +579,24 @@ fn array_to_f32_vec(arr: &Arc<dyn arrow_array::Array>) -> Result<Vec<f32>> {
})
}
/// Magic bytes that prefix (and suffix) the Arrow IPC *file* format.
const ARROW_IPC_FILE_MAGIC: &[u8] = b"ARROW1";
/// Parse Arrow IPC response from the namespace server.
///
/// The server may return either the Arrow IPC *file* format or the *stream*
/// format. REST/phalanx returns the file format (it begins with the `ARROW1`
/// magic); reading that with a `StreamReader` fails with "failed to fill whole
/// buffer". Detect the magic and pick the matching reader so both are handled.
async fn parse_arrow_ipc_response(bytes: bytes::Bytes) -> Result<DatasetRecordBatchStream> {
use arrow_ipc::reader::{FileReader, StreamReader};
use arrow_ipc::reader::StreamReader;
use std::io::Cursor;
let (schema, batches) = if bytes.starts_with(ARROW_IPC_FILE_MAGIC) {
let reader = FileReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
message: format!("Failed to parse Arrow IPC file response: {}", e),
let cursor = Cursor::new(bytes);
let reader = StreamReader::try_new(cursor, None).map_err(|e| Error::Runtime {
message: format!("Failed to parse Arrow IPC response: {}", e),
})?;
// Collect all record batches
let schema = reader.schema();
let batches: Vec<_> = reader
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Runtime {
message: format!("Failed to read Arrow IPC batches: {}", e),
})?;
let schema = reader.schema();
let batches = reader
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Runtime {
message: format!("Failed to read Arrow IPC file batches: {}", e),
})?;
(schema, batches)
} else {
let reader =
StreamReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
message: format!("Failed to parse Arrow IPC response: {}", e),
})?;
let schema = reader.schema();
let batches = reader
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Runtime {
message: format!("Failed to read Arrow IPC batches: {}", e),
})?;
(schema, batches)
};
// Create a stream from the batches
let stream = futures::stream::iter(batches.into_iter().map(Ok));
@@ -645,59 +624,6 @@ mod tests {
FixedSizeListArray::try_new_from_values(Float32Array::from(values), dimension).unwrap()
}
#[tokio::test]
async fn test_parse_arrow_ipc_response_handles_file_and_stream() {
use arrow_array::{Int32Array, RecordBatch};
use arrow_ipc::writer::{FileWriter, StreamWriter};
use arrow_schema::{DataType, Field, Schema};
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
)
.unwrap();
// Arrow IPC *file* format -- what REST/phalanx returns. Previously this
// failed with "failed to fill whole buffer" because we used a StreamReader.
let mut file_buf = Vec::new();
{
let mut writer = FileWriter::try_new(&mut file_buf, &schema).unwrap();
writer.write(&batch).unwrap();
writer.finish().unwrap();
}
assert!(file_buf.starts_with(ARROW_IPC_FILE_MAGIC));
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(file_buf))
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap()
.iter()
.map(|b| b.num_rows())
.sum();
assert_eq!(rows, 3);
// Arrow IPC *stream* format must still parse.
let mut stream_buf = Vec::new();
{
let mut writer = StreamWriter::try_new(&mut stream_buf, &schema).unwrap();
writer.write(&batch).unwrap();
writer.finish().unwrap();
}
assert!(!stream_buf.starts_with(ARROW_IPC_FILE_MAGIC));
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(stream_buf))
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap()
.iter()
.map(|b| b.num_rows())
.sum();
assert_eq!(rows, 3);
}
#[test]
fn test_convert_to_namespace_query_vector() {
let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]));