diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index f58aba4ea..9cde8e8f4 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -921,8 +921,15 @@ pub struct ConnectBuilder { } #[cfg(feature = "remote")] -const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 1] = - [("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name")]; +const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 4] = [ + ("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name"), + ("AZURE_TENANT_ID", "azure_tenant_id"), + ("AZURE_CLIENT_ID", "azure_client_id"), + ( + "AZURE_FEDERATED_TOKEN_FILE", + "azure_federated_token_file", + ), +]; impl ConnectBuilder { /// Create a new [`ConnectOptions`] with the given database URI. diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 46e51569a..aca7d3790 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -438,21 +438,26 @@ impl RestfulLanceDbClient { ); } - if let Some(v) = options.0.get("account_name") { - headers.insert( - HeaderName::from_static("x-azure-storage-account-name"), - HeaderValue::from_str(v).map_err(|_| Error::InvalidInput { - message: format!("non-ascii storage account name '{}' provided", db_name), - })?, - ); - } - if let Some(v) = options.0.get("azure_storage_account_name") { - headers.insert( - HeaderName::from_static("x-azure-storage-account-name"), - HeaderValue::from_str(v).map_err(|_| Error::InvalidInput { - message: format!("non-ascii storage account name '{}' provided", db_name), - })?, - ); + // Map storage options to HTTP headers for Azure configuration. + const OPTION_TO_HEADER: &[(&str, &str)] = &[ + ("account_name", "x-azure-storage-account-name"), + ("azure_storage_account_name", "x-azure-storage-account-name"), + ("azure_tenant_id", "x-azure-tenant-id"), + ("azure_client_id", "x-azure-client-id"), + ( + "azure_federated_token_file", + "x-azure-federated-token-file", + ), + ]; + for (opt_key, header_name) in OPTION_TO_HEADER { + if let Some(v) = options.get(opt_key) { + headers.insert( + HeaderName::from_static(header_name), + HeaderValue::from_str(v).map_err(|_| Error::InvalidInput { + message: format!("non-ascii value for '{}' provided", opt_key), + })?, + ); + } } for (key, value) in &config.extra_headers { diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 66736a872..2bf172139 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -792,11 +792,21 @@ impl RemoteOptions { pub fn new(options: HashMap) -> Self { Self(options) } + + pub fn get(&self, key: &str) -> Option<&String> { + self.0.get(key) + } } impl From for RemoteOptions { fn from(options: StorageOptions) -> Self { - let supported_opts = vec!["account_name", "azure_storage_account_name"]; + let supported_opts = vec![ + "account_name", + "azure_storage_account_name", + "azure_tenant_id", + "azure_client_id", + "azure_federated_token_file", + ]; let mut filtered = HashMap::new(); for opt in supported_opts { if let Some(v) = options.0.get(opt) {