diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 53b61641b..7fc66d813 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -576,6 +576,9 @@ impl Connection { /// For LanceNamespaceDatabase, it is the underlying LanceNamespace. /// For ListingDatabase, it is the equivalent DirectoryNamespace. /// For RemoteDatabase, it is the equivalent RestNamespace. + /// + /// Remote connections using dynamic headers forward them through the + /// namespace client's per-request context provider. pub async fn namespace_client(&self) -> Result> { self.internal.namespace_client().await } @@ -584,6 +587,9 @@ impl Connection { /// Returns (impl_type, properties) where: /// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace /// - properties: configuration properties for the namespace + /// + /// Remote connections using dynamic headers cannot be exported because the + /// namespace client config only carries static headers. pub async fn namespace_client_config( &self, ) -> Result<(String, std::collections::HashMap)> { diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 7c163e0bc..25fcd23b1 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use async_trait::async_trait; use http::StatusCode; use lance_io::object_store::StorageOptions; +use lance_namespace_impls::{DynamicContextProvider, OperationInfo}; use moka::future::Cache; use reqwest::header::CONTENT_TYPE; @@ -26,7 +27,9 @@ use crate::remote::util::stream_as_body; use crate::table::BaseTable; use super::ARROW_STREAM_CONTENT_TYPE; -use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender}; +use super::client::{ + ClientConfig, HeaderProvider, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender, +}; use super::table::RemoteTable; use super::util::parse_server_version; @@ -194,10 +197,66 @@ pub struct RemoteDatabase { uri: String, /// Headers to pass to the namespace client for authentication namespace_headers: HashMap, + namespace_context_provider: Option>, /// TLS configuration for mTLS support tls_config: Option, } +#[derive(Clone)] +struct NamespaceHeaderProviderContext { + header_provider: Arc, +} + +impl std::fmt::Debug for NamespaceHeaderProviderContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NamespaceHeaderProviderContext") + .field("header_provider", &"Some(...)") + .finish() + } +} + +impl DynamicContextProvider for NamespaceHeaderProviderContext { + fn provide_context(&self, _info: &OperationInfo) -> HashMap { + let header_provider = Arc::clone(&self.header_provider); + let handle = match std::thread::Builder::new() + .name("lancedb-namespace-headers".to_string()) + .spawn(move || { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .map_err(|e| Error::Runtime { + message: format!( + "Failed to create runtime for namespace header provider: {e}" + ), + })? + .block_on(header_provider.get_headers()) + }) { + Ok(handle) => handle, + Err(err) => { + log::warn!("Failed to spawn dynamic namespace header provider thread: {err}"); + return HashMap::new(); + } + }; + + let headers = handle.join(); + + match headers { + Ok(Ok(headers)) => headers + .into_iter() + .map(|(key, value)| (format!("headers.{key}"), value)) + .collect(), + Ok(Err(err)) => { + log::warn!("Failed to get dynamic namespace headers: {err}"); + HashMap::new() + } + Err(_) => { + log::warn!("Dynamic namespace header provider panicked"); + HashMap::new() + } + } + } +} + impl RemoteDatabase { pub fn try_new( uri: &str, @@ -228,6 +287,16 @@ impl RemoteDatabase { }) .collect(); + let namespace_context_provider = + client_config + .header_provider + .as_ref() + .map(|header_provider| { + Arc::new(NamespaceHeaderProviderContext { + header_provider: Arc::clone(header_provider), + }) as Arc + }); + let client = RestfulLanceDbClient::try_new( &parsed, region, @@ -247,6 +316,7 @@ impl RemoteDatabase { table_cache, uri: uri.to_owned(), namespace_headers, + namespace_context_provider, tls_config: client_config.tls_config, }) } @@ -271,6 +341,7 @@ mod test_utils { table_cache: Cache::new(0), uri: "http://localhost".to_string(), namespace_headers: HashMap::new(), + namespace_context_provider: None, tls_config: None, } } @@ -281,11 +352,18 @@ mod test_utils { T: Into, { let client = client_with_handler_and_config(handler, config.clone()); + let namespace_context_provider = + config.header_provider.as_ref().map(|header_provider| { + Arc::new(NamespaceHeaderProviderContext { + header_provider: Arc::clone(header_provider), + }) as Arc + }); Self { client, table_cache: Cache::new(0), uri: "http://localhost".to_string(), namespace_headers: config.extra_headers.clone(), + namespace_context_provider, tls_config: config.tls_config.clone(), } } @@ -759,9 +837,12 @@ impl Database for RemoteDatabase { // Create a RestNamespace pointing to the same remote host with the same authentication headers let mut builder = lance_namespace_impls::RestNamespaceBuilder::new(self.client.host()) .delimiter(&self.client.id_delimiter) - // TODO: support header provider .headers(self.namespace_headers.clone()); + if let Some(context_provider) = &self.namespace_context_provider { + builder = builder.context_provider(Arc::clone(context_provider)); + } + // Apply mTLS configuration if present if let Some(tls_config) = &self.tls_config { if let Some(cert_file) = &tls_config.cert_file { @@ -781,6 +862,14 @@ impl Database for RemoteDatabase { } async fn namespace_client_config(&self) -> Result<(String, HashMap)> { + if self.namespace_context_provider.is_some() { + return Err(Error::NotSupported { + message: + "Cannot export a namespace client config when dynamic headers are configured; use LanceDB connection namespace methods instead" + .to_string(), + }); + } + let mut properties = HashMap::new(); properties.insert("uri".to_string(), self.client.host().to_string()); properties.insert("delimiter".to_string(), self.client.id_delimiter.clone()); @@ -832,12 +921,13 @@ impl From for RemoteOptions { #[cfg(test)] mod tests { - use super::build_cache_key; + use super::{NamespaceHeaderProviderContext, build_cache_key}; use std::collections::HashMap; use std::sync::{Arc, OnceLock}; use arrow_array::{Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; + use lance_namespace_impls::{DynamicContextProvider, OperationInfo}; use crate::connection::ConnectBuilder; use crate::{ @@ -1702,6 +1792,75 @@ mod tests { assert!(namespace_client.is_ok()); } + #[test] + fn test_namespace_header_provider_context_maps_headers() { + #[derive(Debug)] + struct TestHeaderProvider; + + #[async_trait::async_trait] + impl HeaderProvider for TestHeaderProvider { + async fn get_headers(&self) -> crate::Result> { + Ok(HashMap::from([( + "authorization".to_string(), + "Bearer token".to_string(), + )])) + } + } + + let context_provider = NamespaceHeaderProviderContext { + header_provider: Arc::new(TestHeaderProvider) as Arc, + }; + + let context = + context_provider.provide_context(&OperationInfo::new("list_tables", "namespace")); + + assert_eq!( + context.get("headers.authorization"), + Some(&"Bearer token".to_string()) + ); + } + + #[tokio::test] + async fn test_namespace_client_supports_dynamic_headers() { + #[derive(Debug)] + struct TestHeaderProvider; + + #[async_trait::async_trait] + impl HeaderProvider for TestHeaderProvider { + async fn get_headers(&self) -> crate::Result> { + Ok(HashMap::from([( + "authorization".to_string(), + "Bearer token".to_string(), + )])) + } + } + + let client_config = ClientConfig { + header_provider: Some(Arc::new(TestHeaderProvider) as Arc), + ..Default::default() + }; + + let conn = Connection::new_with_handler_and_config( + |_| { + http::Response::builder() + .status(200) + .body(r#"{"tables": []}"#) + .unwrap() + }, + client_config, + ); + + let namespace_client = conn.namespace_client().await; + assert!(namespace_client.is_ok()); + + match conn.namespace_client_config().await { + Err(Error::NotSupported { message }) + if message.contains("dynamic headers are configured") => {} + Err(err) => panic!("expected NotSupported, got {err:?}"), + Ok(_) => panic!("expected namespace_client_config to reject dynamic headers"), + } + } + /// Integration tests using RestAdapter to run RemoteDatabase against a real namespace server mod rest_adapter_integration { use super::*;