From 9391ad1450b678c21486e7567b9e054a16ca936d Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Tue, 9 Sep 2025 21:04:46 -0700 Subject: [PATCH] feat: support mTLS for remote database (#2638) This PR adds mTLS (mutual TLS) configuration support for the LanceDB remote HTTP client, allowing users to authenticate with client certificates and configure custom CA certificates for server verification. --------- Co-authored-by: Claude --- nodejs/__test__/remote.test.ts | 92 +++++++++++++++++++- nodejs/lancedb/index.ts | 1 + nodejs/src/remote.rs | 27 ++++++ python/python/lancedb/remote/__init__.py | 28 +++++- python/src/connection.rs | 22 +++++ rust/lancedb/src/remote.rs | 2 +- rust/lancedb/src/remote/client.rs | 105 +++++++++++++++++++++++ 7 files changed, 274 insertions(+), 3 deletions(-) diff --git a/nodejs/__test__/remote.test.ts b/nodejs/__test__/remote.test.ts index 96a45986..b86024de 100644 --- a/nodejs/__test__/remote.test.ts +++ b/nodejs/__test__/remote.test.ts @@ -3,7 +3,13 @@ import * as http from "http"; import { RequestListener } from "http"; -import { Connection, ConnectionOptions, connect } from "../lancedb"; +import { + ClientConfig, + Connection, + ConnectionOptions, + TlsConfig, + connect, +} from "../lancedb"; async function withMockDatabase( listener: RequestListener, @@ -148,4 +154,88 @@ describe("remote connection", () => { }, ); }); + + describe("TlsConfig", () => { + it("should create TlsConfig with all fields", () => { + const tlsConfig: TlsConfig = { + certFile: "/path/to/cert.pem", + keyFile: "/path/to/key.pem", + sslCaCert: "/path/to/ca.pem", + assertHostname: false, + }; + + expect(tlsConfig.certFile).toBe("/path/to/cert.pem"); + expect(tlsConfig.keyFile).toBe("/path/to/key.pem"); + expect(tlsConfig.sslCaCert).toBe("/path/to/ca.pem"); + expect(tlsConfig.assertHostname).toBe(false); + }); + + it("should create TlsConfig with partial fields", () => { + const tlsConfig: TlsConfig = { + certFile: "/path/to/cert.pem", + keyFile: "/path/to/key.pem", + }; + 
+ expect(tlsConfig.certFile).toBe("/path/to/cert.pem"); + expect(tlsConfig.keyFile).toBe("/path/to/key.pem"); + expect(tlsConfig.sslCaCert).toBeUndefined(); + expect(tlsConfig.assertHostname).toBeUndefined(); + }); + + it("should create ClientConfig with TlsConfig", () => { + const tlsConfig: TlsConfig = { + certFile: "/path/to/cert.pem", + keyFile: "/path/to/key.pem", + sslCaCert: "/path/to/ca.pem", + assertHostname: true, + }; + + const clientConfig: ClientConfig = { + userAgent: "test-agent", + tlsConfig: tlsConfig, + }; + + expect(clientConfig.userAgent).toBe("test-agent"); + expect(clientConfig.tlsConfig).toBeDefined(); + expect(clientConfig.tlsConfig?.certFile).toBe("/path/to/cert.pem"); + expect(clientConfig.tlsConfig?.keyFile).toBe("/path/to/key.pem"); + expect(clientConfig.tlsConfig?.sslCaCert).toBe("/path/to/ca.pem"); + expect(clientConfig.tlsConfig?.assertHostname).toBe(true); + }); + + it("should handle empty TlsConfig", () => { + const tlsConfig: TlsConfig = {}; + + expect(tlsConfig.certFile).toBeUndefined(); + expect(tlsConfig.keyFile).toBeUndefined(); + expect(tlsConfig.sslCaCert).toBeUndefined(); + expect(tlsConfig.assertHostname).toBeUndefined(); + }); + + it("should accept TlsConfig in connection options", () => { + const tlsConfig: TlsConfig = { + certFile: "/path/to/cert.pem", + keyFile: "/path/to/key.pem", + sslCaCert: "/path/to/ca.pem", + assertHostname: false, + }; + + // Just verify that the ClientConfig accepts the TlsConfig + const clientConfig: ClientConfig = { + tlsConfig: tlsConfig, + }; + + const connectionOptions: ConnectionOptions = { + apiKey: "fake", + clientConfig: clientConfig, + }; + + // Verify the configuration structure is correct + expect(connectionOptions.clientConfig).toBeDefined(); + expect(connectionOptions.clientConfig?.tlsConfig).toBeDefined(); + expect(connectionOptions.clientConfig?.tlsConfig?.certFile).toBe( + "/path/to/cert.pem", + ); + }); + }); }); diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts 
index e27eb414..54ef67c6 100644 --- a/nodejs/lancedb/index.ts +++ b/nodejs/lancedb/index.ts @@ -21,6 +21,7 @@ export { ClientConfig, TimeoutConfig, RetryConfig, + TlsConfig, OptimizeStats, CompactionStats, RemovalStats, diff --git a/nodejs/src/remote.rs b/nodejs/src/remote.rs index 818734e7..ed2ce830 100644 --- a/nodejs/src/remote.rs +++ b/nodejs/src/remote.rs @@ -69,6 +69,20 @@ pub struct RetryConfig { pub statuses: Option>, } +/// TLS/mTLS configuration for the remote HTTP client. +#[napi(object)] +#[derive(Debug, Default)] +pub struct TlsConfig { + /// Path to the client certificate file (PEM format) for mTLS authentication. + pub cert_file: Option, + /// Path to the client private key file (PEM format) for mTLS authentication. + pub key_file: Option, + /// Path to the CA certificate file (PEM format) for server verification. + pub ssl_ca_cert: Option, + /// Whether to verify the hostname in the server's certificate. + pub assert_hostname: Option, +} + #[napi(object)] #[derive(Debug, Default)] pub struct ClientConfig { @@ -77,6 +91,7 @@ pub struct ClientConfig { pub timeout_config: Option, pub extra_headers: Option>, pub id_delimiter: Option, + pub tls_config: Option, } impl From for lancedb::remote::TimeoutConfig { @@ -107,6 +122,17 @@ impl From for lancedb::remote::RetryConfig { } } +impl From for lancedb::remote::TlsConfig { + fn from(config: TlsConfig) -> Self { + Self { + cert_file: config.cert_file, + key_file: config.key_file, + ssl_ca_cert: config.ssl_ca_cert, + assert_hostname: config.assert_hostname.unwrap_or(true), + } + } +} + impl From for lancedb::remote::ClientConfig { fn from(config: ClientConfig) -> Self { Self { @@ -117,6 +143,7 @@ impl From for lancedb::remote::ClientConfig { timeout_config: config.timeout_config.map(Into::into).unwrap_or_default(), extra_headers: config.extra_headers.unwrap_or_default(), id_delimiter: config.id_delimiter, + tls_config: config.tls_config.map(Into::into), } } } diff --git 
a/python/python/lancedb/remote/__init__.py b/python/python/lancedb/remote/__init__.py index 754febf0..0a6234ea 100644 --- a/python/python/lancedb/remote/__init__.py +++ b/python/python/lancedb/remote/__init__.py @@ -8,7 +8,7 @@ from typing import List, Optional from lancedb import __version__ -__all__ = ["TimeoutConfig", "RetryConfig", "ClientConfig"] +__all__ = ["TimeoutConfig", "RetryConfig", "TlsConfig", "ClientConfig"] @dataclass @@ -112,6 +112,29 @@ class RetryConfig: statuses: Optional[List[int]] = None +@dataclass +class TlsConfig: + """TLS/mTLS configuration for the remote HTTP client. + + Attributes + ---------- + cert_file: Optional[str] + Path to the client certificate file (PEM format) for mTLS authentication. + key_file: Optional[str] + Path to the client private key file (PEM format) for mTLS authentication. + ssl_ca_cert: Optional[str] + Path to the CA certificate file (PEM format) for server verification. + assert_hostname: bool + Whether to verify the hostname in the server's certificate. Default is True. + Set to False to disable hostname verification (use with caution). 
+ """ + + cert_file: Optional[str] = None + key_file: Optional[str] = None + ssl_ca_cert: Optional[str] = None + assert_hostname: bool = True + + @dataclass class ClientConfig: user_agent: str = f"LanceDB-Python-Client/{__version__}" @@ -119,9 +142,12 @@ class ClientConfig: timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig) extra_headers: Optional[dict] = None id_delimiter: Optional[str] = None + tls_config: Optional[TlsConfig] = None def __post_init__(self): if isinstance(self.retry_config, dict): self.retry_config = RetryConfig(**self.retry_config) if isinstance(self.timeout_config, dict): self.timeout_config = TimeoutConfig(**self.timeout_config) + if isinstance(self.tls_config, dict): + self.tls_config = TlsConfig(**self.tls_config) diff --git a/python/src/connection.rs b/python/src/connection.rs index 8e0507f9..1d6a32a5 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -301,6 +301,7 @@ pub struct PyClientConfig { timeout_config: Option, extra_headers: Option>, id_delimiter: Option, + tls_config: Option, } #[derive(FromPyObject)] @@ -321,6 +322,14 @@ pub struct PyClientTimeoutConfig { pool_idle_timeout: Option, } +#[derive(FromPyObject)] +pub struct PyClientTlsConfig { + cert_file: Option, + key_file: Option, + ssl_ca_cert: Option, + assert_hostname: bool, +} + #[cfg(feature = "remote")] impl From for lancedb::remote::RetryConfig { fn from(value: PyClientRetryConfig) -> Self { @@ -347,6 +356,18 @@ impl From for lancedb::remote::TimeoutConfig { } } +#[cfg(feature = "remote")] +impl From for lancedb::remote::TlsConfig { + fn from(value: PyClientTlsConfig) -> Self { + Self { + cert_file: value.cert_file, + key_file: value.key_file, + ssl_ca_cert: value.ssl_ca_cert, + assert_hostname: value.assert_hostname, + } + } +} + #[cfg(feature = "remote")] impl From for lancedb::remote::ClientConfig { fn from(value: PyClientConfig) -> Self { @@ -356,6 +377,7 @@ impl From for lancedb::remote::ClientConfig { timeout_config: 
value.timeout_config.map(Into::into).unwrap_or_default(),
             extra_headers: value.extra_headers.unwrap_or_default(),
             id_delimiter: value.id_delimiter,
+            tls_config: value.tls_config.map(Into::into),
         }
     }
 }
diff --git a/rust/lancedb/src/remote.rs b/rust/lancedb/src/remote.rs
index b8bc96b1..dee549e1 100644
--- a/rust/lancedb/src/remote.rs
+++ b/rust/lancedb/src/remote.rs
@@ -18,5 +18,5 @@ const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
 #[cfg(test)]
 const JSON_CONTENT_TYPE: &str = "application/json";
 
-pub use client::{ClientConfig, RetryConfig, TimeoutConfig};
+pub use client::{ClientConfig, RetryConfig, TimeoutConfig, TlsConfig};
 pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs
index a5fcfd1f..14077111 100644
--- a/rust/lancedb/src/remote/client.rs
+++ b/rust/lancedb/src/remote/client.rs
@@ -15,6 +15,36 @@ use crate::remote::retry::{ResolvedRetryConfig, RetryCounter};
 
 const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id");
 
+/// Configuration for TLS/mTLS settings.
+///
+/// `cert_file` and `key_file` must be set together to enable client
+/// certificate (mTLS) authentication.
+#[derive(Clone, Debug)]
+pub struct TlsConfig {
+    /// Path to the client certificate file (PEM format)
+    pub cert_file: Option<String>,
+    /// Path to the client private key file (PEM format)
+    pub key_file: Option<String>,
+    /// Path to the CA certificate file for server verification (PEM format)
+    pub ssl_ca_cert: Option<String>,
+    /// Whether to verify the hostname in the server's certificate.
+    /// Defaults to `true`; disabling it weakens TLS and should be done with caution.
+    pub assert_hostname: bool,
+}
+
+impl Default for TlsConfig {
+    /// Returns a configuration with no client identity, no extra CA
+    /// certificate, and hostname verification enabled.
+    fn default() -> Self {
+        Self {
+            cert_file: None,
+            key_file: None,
+            ssl_ca_cert: None,
+            assert_hostname: true,
+        }
+    }
+}
+
 /// Configuration for the LanceDB Cloud HTTP client.
 #[derive(Clone, Debug)]
 pub struct ClientConfig {
@@ -28,6 +58,8 @@ pub struct ClientConfig {
     /// The delimiter to use when constructing object identifiers.
     /// If not default, passes as query parameter.
     pub id_delimiter: Option<String>,
+    /// TLS configuration for mTLS support
+    pub tls_config: Option<TlsConfig>,
 }
 
 impl Default for ClientConfig {
@@ -38,6 +70,7 @@ impl Default for ClientConfig {
             user_agent: concat!("LanceDB-Rust-Client/", env!("CARGO_PKG_VERSION")).into(),
             extra_headers: HashMap::new(),
             id_delimiter: None,
+            tls_config: None,
         }
     }
 }
@@ -245,6 +278,63 @@ impl RestfulLanceDbClient {
         if let Some(timeout) = timeout {
             client_builder = client_builder.timeout(timeout);
         }
+
+        // Configure mTLS if TlsConfig is provided
+        if let Some(tls_config) = &client_config.tls_config {
+            // Load the client certificate and key for mTLS. Both files are needed to
+            // build an identity, so a lone cert or key is reported as a
+            // misconfiguration instead of being silently ignored.
+            match (&tls_config.cert_file, &tls_config.key_file) {
+                (Some(cert_file), Some(key_file)) => {
+                    let cert = std::fs::read(cert_file).map_err(|err| Error::Other {
+                        message: format!("Failed to read certificate file: {}", cert_file),
+                        source: Some(Box::new(err)),
+                    })?;
+                    let key = std::fs::read(key_file).map_err(|err| Error::Other {
+                        message: format!("Failed to read key file: {}", key_file),
+                        source: Some(Box::new(err)),
+                    })?;
+
+                    // `Identity::from_pem` takes one PEM buffer holding both items. The
+                    // explicit newline keeps the two PEM blocks apart when the
+                    // certificate file does not end with one.
+                    let pem = [&cert[..], &b"\n"[..], &key[..]].concat();
+                    let identity = reqwest::Identity::from_pem(&pem).map_err(|err| Error::Other {
+                        message: "Failed to create client identity from certificate and key"
+                            .into(),
+                        source: Some(Box::new(err)),
+                    })?;
+                    client_builder = client_builder.identity(identity);
+                }
+                (None, None) => {}
+                _ => {
+                    return Err(Error::Other {
+                        message: "Both cert_file and key_file are required for mTLS".into(),
+                        source: None,
+                    });
+                }
+            }
+
+            // Load CA certificate for server verification
+            if let Some(ca_cert_file) = &tls_config.ssl_ca_cert {
+                let ca_cert = std::fs::read(ca_cert_file).map_err(|err| Error::Other {
+                    message: format!("Failed to read CA certificate file: {}", ca_cert_file),
+                    source: Some(Box::new(err)),
+                })?;
+
+                let ca_cert =
+                    reqwest::Certificate::from_pem(&ca_cert).map_err(|err| Error::Other {
+                        message: "Failed to create CA certificate from PEM".into(),
+                        source: Some(Box::new(err)),
+                    })?;
+                client_builder = client_builder.add_root_certificate(ca_cert);
+            }
+
+            // Configure hostname verification
+            client_builder =
+                client_builder.danger_accept_invalid_hostnames(!tls_config.assert_hostname);
+        }
+
         let client = client_builder
             .default_headers(Self::default_headers(
                 api_key,
@@ -661,4 +751,51 @@ mod tests {
             Some(Duration::from_secs(120))
         );
     }
+
+    #[test]
+    fn test_tls_config_default() {
+        let config = TlsConfig::default();
+        assert!(config.cert_file.is_none());
+        assert!(config.key_file.is_none());
+        assert!(config.ssl_ca_cert.is_none());
+        // Hostname verification must be on unless explicitly disabled.
+        assert!(config.assert_hostname);
+    }
+
+    #[test]
+    fn test_tls_config_with_mtls() {
+        let tls_config = TlsConfig {
+            cert_file: Some("/path/to/cert.pem".to_string()),
+            key_file: Some("/path/to/key.pem".to_string()),
+            ssl_ca_cert: Some("/path/to/ca.pem".to_string()),
+            assert_hostname: true,
+        };
+
+        assert_eq!(tls_config.cert_file, Some("/path/to/cert.pem".to_string()));
+        assert_eq!(tls_config.key_file, Some("/path/to/key.pem".to_string()));
+        assert_eq!(tls_config.ssl_ca_cert, Some("/path/to/ca.pem".to_string()));
+        assert!(tls_config.assert_hostname);
+    }
+
+    #[test]
+    fn test_client_config_with_tls() {
+        let tls_config = TlsConfig {
+            cert_file: Some("/path/to/cert.pem".to_string()),
+            key_file: Some("/path/to/key.pem".to_string()),
+            ssl_ca_cert: None,
+            assert_hostname: false,
+        };
+
+        let client_config = ClientConfig {
+            tls_config: Some(tls_config),
+            ..Default::default()
+        };
+
+        assert!(client_config.tls_config.is_some());
+        let config_tls = client_config.tls_config.unwrap();
+        assert_eq!(config_tls.cert_file, Some("/path/to/cert.pem".to_string()));
+        assert_eq!(config_tls.key_file, Some("/path/to/key.pem".to_string()));
+        assert!(config_tls.ssl_ca_cert.is_none());
+        assert!(!config_tls.assert_hostname);
+    }
 }