mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-09 05:12:58 +00:00
feat: support per-request header override (#2631)
## Summary This PR introduces a `HeaderProvider`, which is called for every remote HTTP call to get the latest headers to inject. This is useful for features such as adding the latest auth tokens: the header provider can auto-refresh tokens internally, and each request always sets the refreshed token. --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -998,6 +998,23 @@ mod test_utils {
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_handler_and_config<T>(
|
||||
handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
|
||||
config: crate::remote::ClientConfig,
|
||||
) -> Self
|
||||
where
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock_with_config(
|
||||
handler, config,
|
||||
));
|
||||
Self {
|
||||
internal,
|
||||
uri: "db://test".to_string(),
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,5 +18,5 @@ const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
|
||||
#[cfg(test)]
|
||||
const JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
pub use client::{ClientConfig, RetryConfig, TimeoutConfig, TlsConfig};
|
||||
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
|
||||
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
|
||||
|
||||
@@ -7,7 +7,7 @@ use reqwest::{
|
||||
header::{HeaderMap, HeaderValue},
|
||||
Body, Request, RequestBuilder, Response,
|
||||
};
|
||||
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};
|
||||
use std::{collections::HashMap, future::Future, str::FromStr, sync::Arc, time::Duration};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::remote::db::RemoteOptions;
|
||||
@@ -28,8 +28,15 @@ pub struct TlsConfig {
|
||||
pub assert_hostname: bool,
|
||||
}
|
||||
|
||||
/// Trait for providing custom headers for each request
|
||||
#[async_trait::async_trait]
|
||||
pub trait HeaderProvider: Send + Sync + std::fmt::Debug {
|
||||
/// Get the latest headers to be added to the request
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>>;
|
||||
}
|
||||
|
||||
/// Configuration for the LanceDB Cloud HTTP client.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone)]
|
||||
pub struct ClientConfig {
|
||||
pub timeout_config: TimeoutConfig,
|
||||
pub retry_config: RetryConfig,
|
||||
@@ -43,6 +50,25 @@ pub struct ClientConfig {
|
||||
pub id_delimiter: Option<String>,
|
||||
/// TLS configuration for mTLS support
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
/// Provider for custom headers to be added to each request
|
||||
pub header_provider: Option<Arc<dyn HeaderProvider>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ClientConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ClientConfig")
|
||||
.field("timeout_config", &self.timeout_config)
|
||||
.field("retry_config", &self.retry_config)
|
||||
.field("user_agent", &self.user_agent)
|
||||
.field("extra_headers", &self.extra_headers)
|
||||
.field("id_delimiter", &self.id_delimiter)
|
||||
.field("tls_config", &self.tls_config)
|
||||
.field(
|
||||
"header_provider",
|
||||
&self.header_provider.as_ref().map(|_| "Some(...)"),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientConfig {
|
||||
@@ -54,6 +80,7 @@ impl Default for ClientConfig {
|
||||
extra_headers: HashMap::new(),
|
||||
id_delimiter: None,
|
||||
tls_config: None,
|
||||
header_provider: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -159,13 +186,29 @@ pub struct RetryConfig {
|
||||
// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that
|
||||
// we can mock responses in tests. Based on the patterns from this blog post:
|
||||
// https://write.as/balrogboogie/testing-reqwest-based-clients
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone)]
|
||||
pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
|
||||
client: reqwest::Client,
|
||||
host: String,
|
||||
pub(crate) retry_config: ResolvedRetryConfig,
|
||||
pub(crate) sender: S,
|
||||
pub(crate) id_delimiter: String,
|
||||
pub(crate) header_provider: Option<Arc<dyn HeaderProvider>>,
|
||||
}
|
||||
|
||||
impl<S: HttpSend> std::fmt::Debug for RestfulLanceDbClient<S> {
    /// Format the client without its inner `reqwest::Client`; the header
    /// provider is shown as a fixed placeholder (or `None`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only record whether a provider is configured, not what it contains.
        let provider_marker = self.header_provider.as_ref().map(|_| "Some(...)");

        let mut out = f.debug_struct("RestfulLanceDbClient");
        out.field("host", &self.host);
        out.field("retry_config", &self.retry_config);
        out.field("sender", &self.sender);
        out.field("id_delimiter", &self.id_delimiter);
        out.field("header_provider", &provider_marker);
        out.finish()
    }
}
|
||||
|
||||
pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static {
|
||||
@@ -326,13 +369,17 @@ impl RestfulLanceDbClient<Sender> {
|
||||
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
|
||||
};
|
||||
debug!("Created client for host: {}", host);
|
||||
let retry_config = client_config.retry_config.try_into()?;
|
||||
let retry_config = client_config.retry_config.clone().try_into()?;
|
||||
Ok(Self {
|
||||
client,
|
||||
host,
|
||||
retry_config,
|
||||
sender: Sender,
|
||||
id_delimiter: client_config.id_delimiter.unwrap_or("$".to_string()),
|
||||
id_delimiter: client_config
|
||||
.id_delimiter
|
||||
.clone()
|
||||
.unwrap_or("$".to_string()),
|
||||
header_provider: client_config.header_provider,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -439,10 +486,34 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply dynamic headers from the header provider if configured
|
||||
async fn apply_dynamic_headers(&self, mut request: Request) -> Result<Request> {
|
||||
if let Some(ref provider) = self.header_provider {
|
||||
let headers = provider.get_headers().await?;
|
||||
let request_headers = request.headers_mut();
|
||||
for (key, value) in headers {
|
||||
if let Ok(header_name) = HeaderName::from_str(&key) {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&value) {
|
||||
request_headers.insert(header_name, header_value);
|
||||
} else {
|
||||
debug!("Invalid header value for key {}: {}", key, value);
|
||||
}
|
||||
} else {
|
||||
debug!("Invalid header name: {}", key);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
pub async fn send(&self, req: RequestBuilder) -> Result<(String, Response)> {
|
||||
let (client, request) = req.build_split();
|
||||
let mut request = request.unwrap();
|
||||
let request_id = self.extract_request_id(&mut request);
|
||||
|
||||
// Apply dynamic headers before sending
|
||||
request = self.apply_dynamic_headers(request).await?;
|
||||
|
||||
self.log_request(&request, &request_id);
|
||||
|
||||
let response = self
|
||||
@@ -498,6 +569,10 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
let (c, request) = req_builder.build_split();
|
||||
let mut request = request.unwrap();
|
||||
self.set_request_id(&mut request, &request_id.clone());
|
||||
|
||||
// Apply dynamic headers before each retry attempt
|
||||
request = self.apply_dynamic_headers(request).await?;
|
||||
|
||||
self.log_request(&request, &request_id);
|
||||
|
||||
let response = self.sender.send(&c, request).await.map(|r| (r.status(), r));
|
||||
@@ -625,6 +700,7 @@ impl<T> RequestResultExt for reqwest::Result<T> {
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test_utils {
|
||||
use std::convert::TryInto;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
@@ -670,6 +746,31 @@ pub mod test_utils {
|
||||
f: Arc::new(wrapper),
|
||||
},
|
||||
id_delimiter: "$".to_string(),
|
||||
header_provider: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn client_with_handler_and_config<T>(
|
||||
handler: impl Fn(reqwest::Request) -> http::response::Response<T> + Send + Sync + 'static,
|
||||
config: ClientConfig,
|
||||
) -> RestfulLanceDbClient<MockSender>
|
||||
where
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let wrapper = move |req: reqwest::Request| {
|
||||
let response = handler(req);
|
||||
response.into()
|
||||
};
|
||||
|
||||
RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "http://localhost".to_string(),
|
||||
retry_config: config.retry_config.try_into().unwrap(),
|
||||
sender: MockSender {
|
||||
f: Arc::new(wrapper),
|
||||
},
|
||||
id_delimiter: config.id_delimiter.unwrap_or_else(|| "$".to_string()),
|
||||
header_provider: config.header_provider,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -766,4 +867,159 @@ mod tests {
|
||||
assert!(config_tls.ssl_ca_cert.is_none());
|
||||
assert!(!config_tls.assert_hostname);
|
||||
}
|
||||
|
||||
// Test implementation of HeaderProvider
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestHeaderProvider {
|
||||
headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl TestHeaderProvider {
|
||||
fn new(headers: HashMap<String, String>) -> Self {
|
||||
Self { headers }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>> {
|
||||
Ok(self.headers.clone())
|
||||
}
|
||||
}
|
||||
|
||||
// Test implementation that returns an error
|
||||
#[derive(Debug)]
|
||||
struct ErrorHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for ErrorHeaderProvider {
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>> {
|
||||
Err(Error::Runtime {
|
||||
message: "Failed to get headers".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_config_with_header_provider() {
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("X-API-Key".to_string(), "secret-key".to_string());
|
||||
|
||||
let provider = TestHeaderProvider::new(headers);
|
||||
let client_config = ClientConfig {
|
||||
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert!(client_config.header_provider.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_dynamic_headers() {
|
||||
// Create a mock client with header provider
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("X-Dynamic".to_string(), "dynamic-value".to_string());
|
||||
|
||||
let provider = TestHeaderProvider::new(headers);
|
||||
|
||||
// Create a simple request
|
||||
let request = reqwest::Request::new(
|
||||
reqwest::Method::GET,
|
||||
"https://example.com/test".parse().unwrap(),
|
||||
);
|
||||
|
||||
// Create client with header provider
|
||||
let client = RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "https://example.com".to_string(),
|
||||
retry_config: RetryConfig::default().try_into().unwrap(),
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
|
||||
};
|
||||
|
||||
// Apply dynamic headers
|
||||
let updated_request = client.apply_dynamic_headers(request).await.unwrap();
|
||||
|
||||
// Check that the header was added
|
||||
assert_eq!(
|
||||
updated_request.headers().get("X-Dynamic").unwrap(),
|
||||
"dynamic-value"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_dynamic_headers_merge() {
|
||||
// Test that dynamic headers override existing headers
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("Authorization".to_string(), "Bearer new-token".to_string());
|
||||
headers.insert("X-Custom".to_string(), "custom-value".to_string());
|
||||
|
||||
let provider = TestHeaderProvider::new(headers);
|
||||
|
||||
// Create request with existing Authorization header
|
||||
let mut request_builder = reqwest::Client::new().get("https://example.com/test");
|
||||
request_builder = request_builder.header("Authorization", "Bearer old-token");
|
||||
request_builder = request_builder.header("X-Existing", "existing-value");
|
||||
let request = request_builder.build().unwrap();
|
||||
|
||||
// Create client with header provider
|
||||
let client = RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "https://example.com".to_string(),
|
||||
retry_config: RetryConfig::default().try_into().unwrap(),
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
|
||||
};
|
||||
|
||||
// Apply dynamic headers
|
||||
let updated_request = client.apply_dynamic_headers(request).await.unwrap();
|
||||
|
||||
// Check that dynamic headers override existing ones
|
||||
assert_eq!(
|
||||
updated_request.headers().get("Authorization").unwrap(),
|
||||
"Bearer new-token"
|
||||
);
|
||||
assert_eq!(
|
||||
updated_request.headers().get("X-Custom").unwrap(),
|
||||
"custom-value"
|
||||
);
|
||||
// Existing headers should still be present
|
||||
assert_eq!(
|
||||
updated_request.headers().get("X-Existing").unwrap(),
|
||||
"existing-value"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_dynamic_headers_with_error_provider() {
|
||||
let provider = ErrorHeaderProvider;
|
||||
|
||||
let request = reqwest::Request::new(
|
||||
reqwest::Method::GET,
|
||||
"https://example.com/test".parse().unwrap(),
|
||||
);
|
||||
|
||||
let client = RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "https://example.com".to_string(),
|
||||
retry_config: RetryConfig::default().try_into().unwrap(),
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
|
||||
};
|
||||
|
||||
// Header provider errors should fail the request
|
||||
// This is important for security - if auth headers can't be fetched, don't proceed
|
||||
let result = client.apply_dynamic_headers(request).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result.unwrap_err() {
|
||||
Error::Runtime { message } => {
|
||||
assert_eq!(message, "Failed to get headers");
|
||||
}
|
||||
_ => panic!("Expected Runtime error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,8 +212,9 @@ impl RemoteDatabase {
|
||||
#[cfg(all(test, feature = "remote"))]
|
||||
mod test_utils {
|
||||
use super::*;
|
||||
use crate::remote::client::test_utils::client_with_handler;
|
||||
use crate::remote::client::test_utils::MockSender;
|
||||
use crate::remote::client::test_utils::{client_with_handler, client_with_handler_and_config};
|
||||
use crate::remote::ClientConfig;
|
||||
|
||||
impl RemoteDatabase<MockSender> {
|
||||
pub fn new_mock<F, T>(handler: F) -> Self
|
||||
@@ -227,6 +228,18 @@ mod test_utils {
|
||||
table_cache: Cache::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_mock_with_config<F, T>(handler: F, config: ClientConfig) -> Self
|
||||
where
|
||||
F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static,
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let client = client_with_handler_and_config(handler, config);
|
||||
Self {
|
||||
client,
|
||||
table_cache: Cache::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -587,6 +600,7 @@ impl From<StorageOptions> for RemoteOptions {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::build_cache_key;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
|
||||
@@ -595,7 +609,7 @@ mod tests {
|
||||
use crate::connection::ConnectBuilder;
|
||||
use crate::{
|
||||
database::CreateTableMode,
|
||||
remote::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE},
|
||||
remote::{ClientConfig, HeaderProvider, ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE},
|
||||
Connection, Error,
|
||||
};
|
||||
|
||||
@@ -1112,4 +1126,99 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_header_provider_in_request() {
|
||||
// Test HeaderProvider implementation that adds custom headers
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestHeaderProvider {
|
||||
headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
|
||||
Ok(self.headers.clone())
|
||||
}
|
||||
}
|
||||
|
||||
// Create a test header provider with custom headers
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("X-Custom-Auth".to_string(), "test-token".to_string());
|
||||
headers.insert("X-Request-Id".to_string(), "test-123".to_string());
|
||||
let provider = Arc::new(TestHeaderProvider { headers }) as Arc<dyn HeaderProvider>;
|
||||
|
||||
// Create client config with the header provider
|
||||
let client_config = ClientConfig {
|
||||
header_provider: Some(provider),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create connection with handler that verifies the headers are present
|
||||
let conn = Connection::new_with_handler_and_config(
|
||||
move |request| {
|
||||
// Verify that our custom headers are present
|
||||
assert_eq!(
|
||||
request.headers().get("X-Custom-Auth").unwrap(),
|
||||
"test-token"
|
||||
);
|
||||
assert_eq!(request.headers().get("X-Request-Id").unwrap(), "test-123");
|
||||
|
||||
// Also check standard headers are still there
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/table/");
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"tables": ["table1", "table2"]}"#)
|
||||
.unwrap()
|
||||
},
|
||||
client_config,
|
||||
);
|
||||
|
||||
// Make a request that should include the custom headers
|
||||
let names = conn.table_names().execute().await.unwrap();
|
||||
assert_eq!(names, vec!["table1", "table2"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_header_provider_error_handling() {
|
||||
// Test HeaderProvider that returns an error
|
||||
#[derive(Debug)]
|
||||
struct ErrorHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for ErrorHeaderProvider {
|
||||
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
|
||||
Err(crate::Error::Runtime {
|
||||
message: "Failed to fetch auth token".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let provider = Arc::new(ErrorHeaderProvider) as Arc<dyn HeaderProvider>;
|
||||
let client_config = ClientConfig {
|
||||
header_provider: Some(provider),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create connection - handler won't be called because header provider fails
|
||||
let conn = Connection::new_with_handler_and_config(
|
||||
move |_request| -> http::Response<&'static str> {
|
||||
panic!("Handler should not be called when header provider fails");
|
||||
},
|
||||
client_config,
|
||||
);
|
||||
|
||||
// Request should fail due to header provider error
|
||||
let result = conn.table_names().execute().await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result.unwrap_err() {
|
||||
crate::Error::Runtime { message } => {
|
||||
assert_eq!(message, "Failed to fetch auth token");
|
||||
}
|
||||
_ => panic!("Expected Runtime error from header provider"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user