mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-24 22:50:40 +00:00
Compare commits
2 Commits
rpgreen/fi
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe287dc98c | ||
|
|
411568b72c |
@@ -576,6 +576,9 @@ impl Connection {
|
||||
/// For LanceNamespaceDatabase, it is the underlying LanceNamespace.
|
||||
/// For ListingDatabase, it is the equivalent DirectoryNamespace.
|
||||
/// For RemoteDatabase, it is the equivalent RestNamespace.
|
||||
///
|
||||
/// Remote connections using dynamic headers forward them through the
|
||||
/// namespace client's per-request context provider.
|
||||
pub async fn namespace_client(&self) -> Result<Arc<dyn lance_namespace::LanceNamespace>> {
|
||||
self.internal.namespace_client().await
|
||||
}
|
||||
@@ -584,6 +587,9 @@ impl Connection {
|
||||
/// Returns (impl_type, properties) where:
|
||||
/// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace
|
||||
/// - properties: configuration properties for the namespace
|
||||
///
|
||||
/// Remote connections using dynamic headers cannot be exported because the
|
||||
/// namespace client config only carries static headers.
|
||||
pub async fn namespace_client_config(
|
||||
&self,
|
||||
) -> Result<(String, std::collections::HashMap<String, String>)> {
|
||||
|
||||
@@ -459,12 +459,14 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
config: &ClientConfig,
|
||||
) -> Result<HeaderMap> {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-api-key"),
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
||||
message: "non-ascii api key provided".to_string(),
|
||||
})?,
|
||||
);
|
||||
if !api_key.is_empty() {
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-api-key"),
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
||||
message: "non-ascii api key provided".to_string(),
|
||||
})?,
|
||||
);
|
||||
}
|
||||
if region == "local" {
|
||||
let host = format!("{}.local.api.lancedb.com", db_name);
|
||||
headers.insert(
|
||||
@@ -1005,6 +1007,33 @@ mod tests {
|
||||
assert!(!config_tls.assert_hostname);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_headers_skip_empty_api_key() {
|
||||
let headers = RestfulLanceDbClient::<Sender>::default_headers(
|
||||
"",
|
||||
"us-east-1",
|
||||
"db-name",
|
||||
false,
|
||||
&RemoteOptions::default(),
|
||||
None,
|
||||
&ClientConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
assert!(!headers.contains_key("x-api-key"));
|
||||
|
||||
let headers = RestfulLanceDbClient::<Sender>::default_headers(
|
||||
"api-key",
|
||||
"us-east-1",
|
||||
"db-name",
|
||||
false,
|
||||
&RemoteOptions::default(),
|
||||
None,
|
||||
&ClientConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(headers.get("x-api-key").unwrap(), "api-key");
|
||||
}
|
||||
|
||||
// Test implementation of HeaderProvider
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestHeaderProvider {
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use http::StatusCode;
|
||||
use lance_io::object_store::StorageOptions;
|
||||
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
|
||||
use moka::future::Cache;
|
||||
use reqwest::header::CONTENT_TYPE;
|
||||
|
||||
@@ -26,7 +27,9 @@ use crate::remote::util::stream_as_body;
|
||||
use crate::table::BaseTable;
|
||||
|
||||
use super::ARROW_STREAM_CONTENT_TYPE;
|
||||
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
|
||||
use super::client::{
|
||||
ClientConfig, HeaderProvider, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender,
|
||||
};
|
||||
use super::table::RemoteTable;
|
||||
use super::util::parse_server_version;
|
||||
|
||||
@@ -194,10 +197,66 @@ pub struct RemoteDatabase<S: HttpSend = Sender> {
|
||||
uri: String,
|
||||
/// Headers to pass to the namespace client for authentication
|
||||
namespace_headers: HashMap<String, String>,
|
||||
namespace_context_provider: Option<Arc<dyn DynamicContextProvider>>,
|
||||
/// TLS configuration for mTLS support
|
||||
tls_config: Option<super::client::TlsConfig>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct NamespaceHeaderProviderContext {
|
||||
header_provider: Arc<dyn HeaderProvider>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for NamespaceHeaderProviderContext {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("NamespaceHeaderProviderContext")
|
||||
.field("header_provider", &"Some(...)")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl DynamicContextProvider for NamespaceHeaderProviderContext {
|
||||
fn provide_context(&self, _info: &OperationInfo) -> HashMap<String, String> {
|
||||
let header_provider = Arc::clone(&self.header_provider);
|
||||
let handle = match std::thread::Builder::new()
|
||||
.name("lancedb-namespace-headers".to_string())
|
||||
.spawn(move || {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!(
|
||||
"Failed to create runtime for namespace header provider: {e}"
|
||||
),
|
||||
})?
|
||||
.block_on(header_provider.get_headers())
|
||||
}) {
|
||||
Ok(handle) => handle,
|
||||
Err(err) => {
|
||||
log::warn!("Failed to spawn dynamic namespace header provider thread: {err}");
|
||||
return HashMap::new();
|
||||
}
|
||||
};
|
||||
|
||||
let headers = handle.join();
|
||||
|
||||
match headers {
|
||||
Ok(Ok(headers)) => headers
|
||||
.into_iter()
|
||||
.map(|(key, value)| (format!("headers.{key}"), value))
|
||||
.collect(),
|
||||
Ok(Err(err)) => {
|
||||
log::warn!("Failed to get dynamic namespace headers: {err}");
|
||||
HashMap::new()
|
||||
}
|
||||
Err(_) => {
|
||||
log::warn!("Dynamic namespace header provider panicked");
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteDatabase {
|
||||
pub fn try_new(
|
||||
uri: &str,
|
||||
@@ -228,6 +287,16 @@ impl RemoteDatabase {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let namespace_context_provider =
|
||||
client_config
|
||||
.header_provider
|
||||
.as_ref()
|
||||
.map(|header_provider| {
|
||||
Arc::new(NamespaceHeaderProviderContext {
|
||||
header_provider: Arc::clone(header_provider),
|
||||
}) as Arc<dyn DynamicContextProvider>
|
||||
});
|
||||
|
||||
let client = RestfulLanceDbClient::try_new(
|
||||
&parsed,
|
||||
region,
|
||||
@@ -247,6 +316,7 @@ impl RemoteDatabase {
|
||||
table_cache,
|
||||
uri: uri.to_owned(),
|
||||
namespace_headers,
|
||||
namespace_context_provider,
|
||||
tls_config: client_config.tls_config,
|
||||
})
|
||||
}
|
||||
@@ -271,6 +341,7 @@ mod test_utils {
|
||||
table_cache: Cache::new(0),
|
||||
uri: "http://localhost".to_string(),
|
||||
namespace_headers: HashMap::new(),
|
||||
namespace_context_provider: None,
|
||||
tls_config: None,
|
||||
}
|
||||
}
|
||||
@@ -281,11 +352,18 @@ mod test_utils {
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let client = client_with_handler_and_config(handler, config.clone());
|
||||
let namespace_context_provider =
|
||||
config.header_provider.as_ref().map(|header_provider| {
|
||||
Arc::new(NamespaceHeaderProviderContext {
|
||||
header_provider: Arc::clone(header_provider),
|
||||
}) as Arc<dyn DynamicContextProvider>
|
||||
});
|
||||
Self {
|
||||
client,
|
||||
table_cache: Cache::new(0),
|
||||
uri: "http://localhost".to_string(),
|
||||
namespace_headers: config.extra_headers.clone(),
|
||||
namespace_context_provider,
|
||||
tls_config: config.tls_config.clone(),
|
||||
}
|
||||
}
|
||||
@@ -759,9 +837,12 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
// Create a RestNamespace pointing to the same remote host with the same authentication headers
|
||||
let mut builder = lance_namespace_impls::RestNamespaceBuilder::new(self.client.host())
|
||||
.delimiter(&self.client.id_delimiter)
|
||||
// TODO: support header provider
|
||||
.headers(self.namespace_headers.clone());
|
||||
|
||||
if let Some(context_provider) = &self.namespace_context_provider {
|
||||
builder = builder.context_provider(Arc::clone(context_provider));
|
||||
}
|
||||
|
||||
// Apply mTLS configuration if present
|
||||
if let Some(tls_config) = &self.tls_config {
|
||||
if let Some(cert_file) = &tls_config.cert_file {
|
||||
@@ -781,6 +862,14 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
}
|
||||
|
||||
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)> {
|
||||
if self.namespace_context_provider.is_some() {
|
||||
return Err(Error::NotSupported {
|
||||
message:
|
||||
"Cannot export a namespace client config when dynamic headers are configured; use LanceDB connection namespace methods instead"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut properties = HashMap::new();
|
||||
properties.insert("uri".to_string(), self.client.host().to_string());
|
||||
properties.insert("delimiter".to_string(), self.client.id_delimiter.clone());
|
||||
@@ -832,12 +921,13 @@ impl From<StorageOptions> for RemoteOptions {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::build_cache_key;
|
||||
use super::{NamespaceHeaderProviderContext, build_cache_key};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use arrow_array::{Int32Array, RecordBatch};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
|
||||
|
||||
use crate::connection::ConnectBuilder;
|
||||
use crate::{
|
||||
@@ -1702,6 +1792,75 @@ mod tests {
|
||||
assert!(namespace_client.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_namespace_header_provider_context_maps_headers() {
|
||||
#[derive(Debug)]
|
||||
struct TestHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
"Bearer token".to_string(),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
let context_provider = NamespaceHeaderProviderContext {
|
||||
header_provider: Arc::new(TestHeaderProvider) as Arc<dyn HeaderProvider>,
|
||||
};
|
||||
|
||||
let context =
|
||||
context_provider.provide_context(&OperationInfo::new("list_tables", "namespace"));
|
||||
|
||||
assert_eq!(
|
||||
context.get("headers.authorization"),
|
||||
Some(&"Bearer token".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_namespace_client_supports_dynamic_headers() {
|
||||
#[derive(Debug)]
|
||||
struct TestHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
"Bearer token".to_string(),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
let client_config = ClientConfig {
|
||||
header_provider: Some(Arc::new(TestHeaderProvider) as Arc<dyn HeaderProvider>),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let conn = Connection::new_with_handler_and_config(
|
||||
|_| {
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"tables": []}"#)
|
||||
.unwrap()
|
||||
},
|
||||
client_config,
|
||||
);
|
||||
|
||||
let namespace_client = conn.namespace_client().await;
|
||||
assert!(namespace_client.is_ok());
|
||||
|
||||
match conn.namespace_client_config().await {
|
||||
Err(Error::NotSupported { message })
|
||||
if message.contains("dynamic headers are configured") => {}
|
||||
Err(err) => panic!("expected NotSupported, got {err:?}"),
|
||||
Ok(_) => panic!("expected namespace_client_config to reject dynamic headers"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Integration tests using RestAdapter to run RemoteDatabase against a real namespace server
|
||||
mod rest_adapter_integration {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user