feat: support per-request header override (#2631)

## Summary

This PR introduces a `HeaderProvider`, which is called before every remote
HTTP call to get the latest headers to inject. This is useful for
features such as injecting the latest auth tokens, where the header provider
can auto-refresh tokens internally so that each request always carries the
refreshed token.

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Jack Ye
2025-09-10 13:44:00 -07:00
committed by GitHub
parent 3c7419b392
commit 8da74dcb37
31 changed files with 2639 additions and 49 deletions

View File

@@ -998,6 +998,23 @@ mod test_utils {
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
/// Test-only constructor: builds a connection backed by a mock remote
/// database that answers every request with `handler` and is set up from
/// `config`.
pub fn new_with_handler_and_config<T>(
    handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
    config: crate::remote::ClientConfig,
) -> Self
where
    T: Into<reqwest::Body>,
{
    let mock_db = crate::remote::db::RemoteDatabase::new_mock_with_config(handler, config);
    Self {
        internal: Arc::new(mock_db),
        uri: "db://test".to_string(),
        embedding_registry: Arc::new(MemoryRegistry::new()),
    }
}
}
}

View File

@@ -18,5 +18,5 @@ const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
#[cfg(test)]
const JSON_CONTENT_TYPE: &str = "application/json";
pub use client::{ClientConfig, RetryConfig, TimeoutConfig, TlsConfig};
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};

View File

@@ -7,7 +7,7 @@ use reqwest::{
header::{HeaderMap, HeaderValue},
Body, Request, RequestBuilder, Response,
};
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};
use std::{collections::HashMap, future::Future, str::FromStr, sync::Arc, time::Duration};
use crate::error::{Error, Result};
use crate::remote::db::RemoteOptions;
@@ -28,8 +28,15 @@ pub struct TlsConfig {
pub assert_hostname: bool,
}
/// Trait for providing custom headers for each request.
///
/// When a provider is configured on the client it is asked for headers
/// before a request is sent, and again before each retry attempt, so an
/// implementation can hand back freshly refreshed values.
///
/// Returned headers are inserted into the outgoing request and replace any
/// header already present under the same name. Entries whose name or value
/// is not a valid HTTP header are skipped.
#[async_trait::async_trait]
pub trait HeaderProvider: Send + Sync + std::fmt::Debug {
/// Get the latest headers to be added to the request.
///
/// # Errors
///
/// An error returned here is propagated to the caller of the request, and
/// the request itself is not sent.
async fn get_headers(&self) -> Result<HashMap<String, String>>;
}
/// Configuration for the LanceDB Cloud HTTP client.
#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct ClientConfig {
pub timeout_config: TimeoutConfig,
pub retry_config: RetryConfig,
@@ -43,6 +50,25 @@ pub struct ClientConfig {
pub id_delimiter: Option<String>,
/// TLS configuration for mTLS support
pub tls_config: Option<TlsConfig>,
/// Provider for custom headers to be added to each request
pub header_provider: Option<Arc<dyn HeaderProvider>>,
}
/// Manual `Debug` implementation for [`ClientConfig`].
///
/// All plain fields are printed as usual. For `header_provider` only its
/// presence is reported (`Some(...)` or `None`); the provider itself is never
/// formatted.
impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the marker through `format_args!` so it is printed verbatim.
        // Formatting an `Option<&str>` here would yield the doubled, quoted
        // `Some("Some(...)")`.
        let header_provider = if self.header_provider.is_some() {
            "Some(...)"
        } else {
            "None"
        };
        f.debug_struct("ClientConfig")
            .field("timeout_config", &self.timeout_config)
            .field("retry_config", &self.retry_config)
            .field("user_agent", &self.user_agent)
            .field("extra_headers", &self.extra_headers)
            .field("id_delimiter", &self.id_delimiter)
            .field("tls_config", &self.tls_config)
            .field("header_provider", &format_args!("{}", header_provider))
            .finish()
    }
}
impl Default for ClientConfig {
@@ -54,6 +80,7 @@ impl Default for ClientConfig {
extra_headers: HashMap::new(),
id_delimiter: None,
tls_config: None,
header_provider: None,
}
}
}
@@ -159,13 +186,29 @@ pub struct RetryConfig {
// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that
// we can mock responses in tests. Based on the patterns from this blog post:
// https://write.as/balrogboogie/testing-reqwest-based-clients
#[derive(Clone, Debug)]
#[derive(Clone)]
/// HTTP client used to talk to a remote LanceDB service.
///
/// The client is generic over [`HttpSend`] so that the component which
/// actually sends requests can be replaced by a mock in tests.
pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
/// Underlying `reqwest` client.
client: reqwest::Client,
/// Host the client was created for.
host: String,
/// Retry settings, resolved from the user-facing retry configuration.
pub(crate) retry_config: ResolvedRetryConfig,
/// Sender through which built requests are dispatched.
pub(crate) sender: S,
/// Identifier delimiter; `"$"` is used when the configuration does not set one.
pub(crate) id_delimiter: String,
/// Optional provider that is queried for headers to add to each outgoing request.
pub(crate) header_provider: Option<Arc<dyn HeaderProvider>>,
}
/// Manual `Debug` implementation for [`RestfulLanceDbClient`].
///
/// Prints the host, retry configuration, sender and id delimiter. For
/// `header_provider` only its presence is reported (`Some(...)` or `None`).
/// The inner `client` field is not printed; the trailing `..` in the output
/// marks that omission.
impl<S: HttpSend> std::fmt::Debug for RestfulLanceDbClient<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the marker through `format_args!` so it is printed verbatim
        // instead of as the doubled, quoted `Some("Some(...)")`.
        let header_provider = if self.header_provider.is_some() {
            "Some(...)"
        } else {
            "None"
        };
        f.debug_struct("RestfulLanceDbClient")
            .field("host", &self.host)
            .field("retry_config", &self.retry_config)
            .field("sender", &self.sender)
            .field("id_delimiter", &self.id_delimiter)
            .field("header_provider", &format_args!("{}", header_provider))
            // Not every field is listed above, so say so in the output.
            .finish_non_exhaustive()
    }
}
pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static {
@@ -326,13 +369,17 @@ impl RestfulLanceDbClient<Sender> {
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
};
debug!("Created client for host: {}", host);
let retry_config = client_config.retry_config.try_into()?;
let retry_config = client_config.retry_config.clone().try_into()?;
Ok(Self {
client,
host,
retry_config,
sender: Sender,
id_delimiter: client_config.id_delimiter.unwrap_or("$".to_string()),
id_delimiter: client_config
.id_delimiter
.clone()
.unwrap_or("$".to_string()),
header_provider: client_config.header_provider,
})
}
}
@@ -439,10 +486,34 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
}
}
/// Apply dynamic headers from the header provider, if one is configured.
///
/// Every header returned by the provider is inserted into `request`,
/// replacing any existing header with the same name. Entries whose name or
/// value is not a valid HTTP header are skipped and noted at debug level.
/// Without a provider the request is returned unchanged.
///
/// # Errors
///
/// Propagates any error returned by the provider's `get_headers`.
async fn apply_dynamic_headers(&self, mut request: Request) -> Result<Request> {
    if let Some(provider) = self.header_provider.as_ref() {
        let headers = provider.get_headers().await?;
        let request_headers = request.headers_mut();
        for (key, value) in headers {
            let header_name = match HeaderName::from_str(&key) {
                Ok(name) => name,
                Err(_) => {
                    debug!("Invalid header name: {}", key);
                    continue;
                }
            };
            match HeaderValue::from_str(&value) {
                Ok(header_value) => {
                    request_headers.insert(header_name, header_value);
                }
                // Deliberately log only the key: the rejected value may be a
                // credential (e.g. an auth token) and must not end up in logs.
                Err(_) => debug!("Invalid header value for key {}", key),
            }
        }
    }
    Ok(request)
}
pub async fn send(&self, req: RequestBuilder) -> Result<(String, Response)> {
let (client, request) = req.build_split();
let mut request = request.unwrap();
let request_id = self.extract_request_id(&mut request);
// Apply dynamic headers before sending
request = self.apply_dynamic_headers(request).await?;
self.log_request(&request, &request_id);
let response = self
@@ -498,6 +569,10 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
let (c, request) = req_builder.build_split();
let mut request = request.unwrap();
self.set_request_id(&mut request, &request_id.clone());
// Apply dynamic headers before each retry attempt
request = self.apply_dynamic_headers(request).await?;
self.log_request(&request, &request_id);
let response = self.sender.send(&c, request).await.map(|r| (r.status(), r));
@@ -625,6 +700,7 @@ impl<T> RequestResultExt for reqwest::Result<T> {
#[cfg(test)]
pub mod test_utils {
use std::convert::TryInto;
use std::sync::Arc;
use super::*;
@@ -670,6 +746,31 @@ pub mod test_utils {
f: Arc::new(wrapper),
},
id_delimiter: "$".to_string(),
header_provider: None,
}
}
pub fn client_with_handler_and_config<T>(
handler: impl Fn(reqwest::Request) -> http::response::Response<T> + Send + Sync + 'static,
config: ClientConfig,
) -> RestfulLanceDbClient<MockSender>
where
T: Into<reqwest::Body>,
{
let wrapper = move |req: reqwest::Request| {
let response = handler(req);
response.into()
};
RestfulLanceDbClient {
client: reqwest::Client::new(),
host: "http://localhost".to_string(),
retry_config: config.retry_config.try_into().unwrap(),
sender: MockSender {
f: Arc::new(wrapper),
},
id_delimiter: config.id_delimiter.unwrap_or_else(|| "$".to_string()),
header_provider: config.header_provider,
}
}
}
@@ -766,4 +867,159 @@ mod tests {
assert!(config_tls.ssl_ca_cert.is_none());
assert!(!config_tls.assert_hostname);
}
// HeaderProvider used by the tests: hands back a fixed set of headers.
#[derive(Debug, Clone)]
struct TestHeaderProvider {
    headers: HashMap<String, String>,
}

impl TestHeaderProvider {
    // Wraps the given header map.
    fn new(headers: HashMap<String, String>) -> Self {
        TestHeaderProvider { headers }
    }
}

#[async_trait::async_trait]
impl HeaderProvider for TestHeaderProvider {
    // Always succeeds with a copy of the stored headers.
    async fn get_headers(&self) -> Result<HashMap<String, String>> {
        let snapshot = self.headers.clone();
        Ok(snapshot)
    }
}
// Test implementation that returns an error
#[derive(Debug)]
struct ErrorHeaderProvider;
#[async_trait::async_trait]
impl HeaderProvider for ErrorHeaderProvider {
async fn get_headers(&self) -> Result<HashMap<String, String>> {
Err(Error::Runtime {
message: "Failed to get headers".to_string(),
})
}
}
// A provider placed in `ClientConfig` must be stored and stay usable through
// the config: calling it has to yield the headers it was created with.
#[tokio::test]
async fn test_client_config_with_header_provider() {
    let mut headers = HashMap::new();
    headers.insert("X-API-Key".to_string(), "secret-key".to_string());
    let provider = TestHeaderProvider::new(headers);

    let client_config = ClientConfig {
        header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
        ..Default::default()
    };

    // Checking `is_some()` alone would not show that the stored provider is
    // the one we passed in, so call it and inspect the result.
    let stored = client_config
        .header_provider
        .as_ref()
        .expect("header provider should be stored in the config");
    let returned = stored.get_headers().await.unwrap();
    assert_eq!(
        returned.get("X-API-Key").map(String::as_str),
        Some("secret-key")
    );
    assert_eq!(returned.len(), 1);
}
// A header supplied by the provider must end up on the outgoing request.
#[tokio::test]
async fn test_apply_dynamic_headers() {
    // Provider that always yields a single fixed header.
    let mut dynamic = HashMap::new();
    dynamic.insert("X-Dynamic".to_string(), "dynamic-value".to_string());
    let header_provider = Some(Arc::new(TestHeaderProvider::new(dynamic)) as Arc<dyn HeaderProvider>);

    // Client wired up with that provider.
    let client = RestfulLanceDbClient {
        client: reqwest::Client::new(),
        host: "https://example.com".to_string(),
        retry_config: RetryConfig::default().try_into().unwrap(),
        sender: Sender,
        id_delimiter: "+".to_string(),
        header_provider,
    };

    // Start from a request that carries no headers at all.
    let bare_request = reqwest::Request::new(
        reqwest::Method::GET,
        "https://example.com/test".parse().unwrap(),
    );
    let updated_request = client.apply_dynamic_headers(bare_request).await.unwrap();

    // The provider's header is now present.
    let dynamic_header = updated_request.headers().get("X-Dynamic");
    assert_eq!(dynamic_header.unwrap(), "dynamic-value");
}
// Provider headers replace same-named request headers, add new ones, and
// leave unrelated request headers alone.
#[tokio::test]
async fn test_apply_dynamic_headers_merge() {
    // Provider returns a replacement Authorization header plus a new one.
    let mut dynamic = HashMap::new();
    dynamic.insert("Authorization".to_string(), "Bearer new-token".to_string());
    dynamic.insert("X-Custom".to_string(), "custom-value".to_string());
    let header_provider = Some(Arc::new(TestHeaderProvider::new(dynamic)) as Arc<dyn HeaderProvider>);

    // Request that already carries an Authorization header and one more.
    let request = reqwest::Client::new()
        .get("https://example.com/test")
        .header("Authorization", "Bearer old-token")
        .header("X-Existing", "existing-value")
        .build()
        .unwrap();

    let client = RestfulLanceDbClient {
        client: reqwest::Client::new(),
        host: "https://example.com".to_string(),
        retry_config: RetryConfig::default().try_into().unwrap(),
        sender: Sender,
        id_delimiter: "+".to_string(),
        header_provider,
    };

    let updated_request = client.apply_dynamic_headers(request).await.unwrap();
    let merged = updated_request.headers();

    // Overridden by the provider.
    assert_eq!(merged.get("Authorization").unwrap(), "Bearer new-token");
    // Added by the provider.
    assert_eq!(merged.get("X-Custom").unwrap(), "custom-value");
    // Untouched because the provider did not mention it.
    assert_eq!(merged.get("X-Existing").unwrap(), "existing-value");
}
#[tokio::test]
async fn test_apply_dynamic_headers_with_error_provider() {
let provider = ErrorHeaderProvider;
let request = reqwest::Request::new(
reqwest::Method::GET,
"https://example.com/test".parse().unwrap(),
);
let client = RestfulLanceDbClient {
client: reqwest::Client::new(),
host: "https://example.com".to_string(),
retry_config: RetryConfig::default().try_into().unwrap(),
sender: Sender,
id_delimiter: "+".to_string(),
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
};
// Header provider errors should fail the request
// This is important for security - if auth headers can't be fetched, don't proceed
let result = client.apply_dynamic_headers(request).await;
assert!(result.is_err());
match result.unwrap_err() {
Error::Runtime { message } => {
assert_eq!(message, "Failed to get headers");
}
_ => panic!("Expected Runtime error"),
}
}
}

View File

@@ -212,8 +212,9 @@ impl RemoteDatabase {
#[cfg(all(test, feature = "remote"))]
mod test_utils {
use super::*;
use crate::remote::client::test_utils::client_with_handler;
use crate::remote::client::test_utils::MockSender;
use crate::remote::client::test_utils::{client_with_handler, client_with_handler_and_config};
use crate::remote::ClientConfig;
impl RemoteDatabase<MockSender> {
pub fn new_mock<F, T>(handler: F) -> Self
@@ -227,6 +228,18 @@ mod test_utils {
table_cache: Cache::new(0),
}
}
/// Creates a mock remote database whose client is built from `handler` and
/// `config`. The table cache is created with a size of 0.
pub fn new_mock_with_config<F, T>(handler: F, config: ClientConfig) -> Self
where
    F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static,
    T: Into<reqwest::Body>,
{
    Self {
        client: client_with_handler_and_config(handler, config),
        table_cache: Cache::new(0),
    }
}
}
}
@@ -587,6 +600,7 @@ impl From<StorageOptions> for RemoteOptions {
#[cfg(test)]
mod tests {
use super::build_cache_key;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
@@ -595,7 +609,7 @@ mod tests {
use crate::connection::ConnectBuilder;
use crate::{
database::CreateTableMode,
remote::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE},
remote::{ClientConfig, HeaderProvider, ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE},
Connection, Error,
};
@@ -1112,4 +1126,99 @@ mod tests {
.await
.unwrap();
}
// End-to-end check: headers from a configured provider are present on the
// request that reaches the mock handler.
#[tokio::test]
async fn test_header_provider_in_request() {
    // Provider that returns whatever header map it was built with.
    #[derive(Debug, Clone)]
    struct TestHeaderProvider {
        headers: HashMap<String, String>,
    }

    #[async_trait::async_trait]
    impl HeaderProvider for TestHeaderProvider {
        async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
            Ok(self.headers.clone())
        }
    }

    // Two custom headers the handler will look for.
    let mut custom = HashMap::new();
    custom.insert("X-Custom-Auth".to_string(), "test-token".to_string());
    custom.insert("X-Request-Id".to_string(), "test-123".to_string());

    let client_config = ClientConfig {
        header_provider: Some(
            Arc::new(TestHeaderProvider { headers: custom }) as Arc<dyn HeaderProvider>
        ),
        ..Default::default()
    };

    let conn = Connection::new_with_handler_and_config(
        move |request| {
            // The provider's headers must have been applied.
            let seen = request.headers();
            assert_eq!(seen.get("X-Custom-Auth").unwrap(), "test-token");
            assert_eq!(seen.get("X-Request-Id").unwrap(), "test-123");

            // The request itself is still the expected table listing call.
            assert_eq!(request.method(), &reqwest::Method::GET);
            assert_eq!(request.url().path(), "/v1/table/");

            http::Response::builder()
                .status(200)
                .body(r#"{"tables": ["table1", "table2"]}"#)
                .unwrap()
        },
        client_config,
    );

    // Listing tables goes through the handler above.
    let names = conn.table_names().execute().await.unwrap();
    assert_eq!(names, vec!["table1", "table2"]);
}
#[tokio::test]
async fn test_header_provider_error_handling() {
// Test HeaderProvider that returns an error
#[derive(Debug)]
struct ErrorHeaderProvider;
#[async_trait::async_trait]
impl HeaderProvider for ErrorHeaderProvider {
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
Err(crate::Error::Runtime {
message: "Failed to fetch auth token".to_string(),
})
}
}
let provider = Arc::new(ErrorHeaderProvider) as Arc<dyn HeaderProvider>;
let client_config = ClientConfig {
header_provider: Some(provider),
..Default::default()
};
// Create connection - handler won't be called because header provider fails
let conn = Connection::new_with_handler_and_config(
move |_request| -> http::Response<&'static str> {
panic!("Handler should not be called when header provider fails");
},
client_config,
);
// Request should fail due to header provider error
let result = conn.table_names().execute().await;
assert!(result.is_err());
match result.unwrap_err() {
crate::Error::Runtime { message } => {
assert_eq!(message, "Failed to fetch auth token");
}
_ => panic!("Expected Runtime error from header provider"),
}
}
}