feat: make tls certificates/keys reloadable (part 1) (#3335)

* feat: make tls certificates/keys reloadable (part 1)

* feat: add notify watcher for cert/key files

* test: add unit test for watcher

* fix: correct usage of watcher

* fix: skip watch when tls disabled
This commit is contained in:
Ning Sun
2024-02-26 17:37:54 +08:00
committed by GitHub
parent e859f0e67d
commit 3887d207b6
10 changed files with 309 additions and 42 deletions

70
Cargo.lock generated
View File

@@ -3585,6 +3585,15 @@ dependencies = [
"winapi",
]
[[package]]
name = "fsevent-sys"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2"
dependencies = [
"libc",
]
[[package]]
name = "fst"
version = "0.4.7"
@@ -4369,6 +4378,26 @@ dependencies = [
"snafu",
]
[[package]]
name = "inotify"
version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff"
dependencies = [
"bitflags 1.3.2",
"inotify-sys",
"libc",
]
[[package]]
name = "inotify-sys"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
dependencies = [
"libc",
]
[[package]]
name = "instant"
version = "0.1.12"
@@ -4566,6 +4595,26 @@ dependencies = [
"indexmap 2.1.0",
]
[[package]]
name = "kqueue"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7447f1ca1b7b563588a205fe93dea8df60fd981423a768bc1c0ded35ed147d0c"
dependencies = [
"kqueue-sys",
"libc",
]
[[package]]
name = "kqueue-sys"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b"
dependencies = [
"bitflags 1.3.2",
"libc",
]
[[package]]
name = "lalrpop"
version = "0.19.12"
@@ -5655,6 +5704,25 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be"
[[package]]
name = "notify"
version = "6.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d"
dependencies = [
"bitflags 2.4.1",
"crossbeam-channel",
"filetime",
"fsevent-sys",
"inotify",
"kqueue",
"libc",
"log",
"mio",
"walkdir",
"windows-sys 0.48.0",
]
[[package]]
name = "ntapi"
version = "0.4.1"
@@ -8995,6 +9063,7 @@ dependencies = [
"lazy_static",
"mime_guess",
"mysql_async",
"notify",
"once_cell",
"openmetrics-parser",
"opensrv-mysql",
@@ -9027,6 +9096,7 @@ dependencies = [
"sql",
"strum 0.25.0",
"table",
"tempfile",
"tikv-jemalloc-ctl",
"tokio",
"tokio-postgres",

View File

@@ -18,7 +18,6 @@ use std::sync::Arc;
use auth::UserProviderRef;
use common_base::Plugins;
use common_runtime::Builder as RuntimeBuilder;
use servers::error::InternalIoSnafu;
use servers::grpc::builder::GrpcServerBuilder;
use servers::grpc::greptime_handler::GreptimeRequestHandler;
use servers::grpc::{GrpcServer, GrpcServerConfig};
@@ -30,6 +29,7 @@ use servers::postgres::PostgresServer;
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter;
use servers::query_handler::sql::ServerSqlQueryHandlerAdapter;
use servers::server::{Server, ServerHandlers};
use servers::tls::{watch_tls_config, ReloadableTlsServerConfig};
use snafu::ResultExt;
use crate::error::{self, Result, StartServerSnafu};
@@ -195,6 +195,12 @@ where
let opts = &opts.mysql;
let mysql_addr = parse_addr(&opts.addr)?;
let tls_server_config = Arc::new(
ReloadableTlsServerConfig::try_new(opts.tls.clone()).context(StartServerSnafu)?,
);
watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
let mysql_io_runtime = Arc::new(
RuntimeBuilder::default()
.worker_threads(opts.runtime_size)
@@ -210,11 +216,7 @@ where
)),
Arc::new(MysqlSpawnConfig::new(
opts.tls.should_force_tls(),
opts.tls
.setup()
.context(InternalIoSnafu)
.context(StartServerSnafu)?
.map(Arc::new),
tls_server_config,
opts.reject_no_database.unwrap_or(false),
)),
);
@@ -226,6 +228,12 @@ where
let opts = &opts.postgres;
let pg_addr = parse_addr(&opts.addr)?;
let tls_server_config = Arc::new(
ReloadableTlsServerConfig::try_new(opts.tls.clone()).context(StartServerSnafu)?,
);
watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
let pg_io_runtime = Arc::new(
RuntimeBuilder::default()
.worker_threads(opts.runtime_size)
@@ -236,7 +244,8 @@ where
let pg_server = Box::new(PostgresServer::new(
ServerSqlQueryHandlerAdapter::arc(instance.clone()),
opts.tls.clone(),
opts.tls.should_force_tls(),
tls_server_config,
pg_io_runtime,
user_provider.clone(),
)) as Box<dyn Server>;

View File

@@ -59,6 +59,7 @@ influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", bran
itertools.workspace = true
lazy_static.workspace = true
mime_guess = "2.0"
notify = "6.1"
once_cell.workspace = true
openmetrics-parser = "0.4"
opensrv-mysql = "0.7.0"
@@ -121,6 +122,7 @@ script = { workspace = true, features = ["python"] }
serde_json.workspace = true
session = { workspace = true, features = ["testing"] }
table.workspace = true
tempfile = "3.0.0"
tokio-postgres = "0.7"
tokio-postgres-rustls = "0.11"
tokio-test = "0.4"

View File

@@ -441,6 +441,12 @@ pub enum Error {
"Invalid parameter, physical_table is not expected when metric engine is disabled"
))]
UnexpectedPhysicalTable { location: Location },
#[snafu(display("Failed to initialize a watcher for file"))]
FileWatch {
#[snafu(source)]
error: notify::Error,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -462,7 +468,8 @@ impl ErrorExt for Error {
| CatalogError { .. }
| GrpcReflectionService { .. }
| BuildHttpResponse { .. }
| Arrow { .. } => StatusCode::Internal,
| Arrow { .. }
| FileWatch { .. } => StatusCode::Internal,
UnsupportedDataType { .. } => StatusCode::Unsupported,

View File

@@ -33,6 +33,7 @@ use crate::error::{Error, Result};
use crate::mysql::handler::MysqlInstanceShim;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::server::{AbortableStream, BaseTcpServer, Server};
use crate::tls::ReloadableTlsServerConfig;
// Default size of ResultSet write buffer: 100KB
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
@@ -68,7 +69,7 @@ impl MysqlSpawnRef {
pub struct MysqlSpawnConfig {
// tls config
force_tls: bool,
tls: Option<Arc<ServerConfig>>,
tls: Arc<ReloadableTlsServerConfig>,
// other shim config
reject_no_database: bool,
}
@@ -76,7 +77,7 @@ pub struct MysqlSpawnConfig {
impl MysqlSpawnConfig {
pub fn new(
force_tls: bool,
tls: Option<Arc<ServerConfig>>,
tls: Arc<ReloadableTlsServerConfig>,
reject_no_database: bool,
) -> MysqlSpawnConfig {
MysqlSpawnConfig {
@@ -87,7 +88,7 @@ impl MysqlSpawnConfig {
}
fn tls(&self) -> Option<Arc<ServerConfig>> {
self.tls.clone()
self.tls.get_server_config()
}
}

View File

@@ -29,19 +29,20 @@ use super::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
use crate::error::Result;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::server::{AbortableStream, BaseTcpServer, Server};
use crate::tls::TlsOption;
use crate::tls::ReloadableTlsServerConfig;
pub struct PostgresServer {
base_server: BaseTcpServer,
make_handler: Arc<MakePostgresServerHandler>,
tls: TlsOption,
tls_server_config: Arc<ReloadableTlsServerConfig>,
}
impl PostgresServer {
/// Creates a new Postgres server with provided query_handler and async runtime
pub fn new(
query_handler: ServerSqlQueryHandlerRef,
tls: TlsOption,
force_tls: bool,
tls_server_config: Arc<ReloadableTlsServerConfig>,
io_runtime: Arc<Runtime>,
user_provider: Option<UserProviderRef>,
) -> PostgresServer {
@@ -49,14 +50,14 @@ impl PostgresServer {
MakePostgresServerHandlerBuilder::default()
.query_handler(query_handler.clone())
.user_provider(user_provider.clone())
.force_tls(tls.should_force_tls())
.force_tls(force_tls)
.build()
.unwrap(),
);
PostgresServer {
base_server: BaseTcpServer::create_server("Postgres", io_runtime),
make_handler,
tls,
tls_server_config,
}
}
@@ -64,12 +65,16 @@ impl PostgresServer {
&self,
io_runtime: Arc<Runtime>,
accepting_stream: AbortableStream,
tls_acceptor: Option<Arc<TlsAcceptor>>,
) -> impl Future<Output = ()> {
let handler_maker = self.make_handler.clone();
let tls_server_config = self.tls_server_config.clone();
accepting_stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let tls_acceptor = tls_acceptor.clone();
let tls_acceptor = tls_server_config
.get_server_config()
.map(|server_config| Arc::new(TlsAcceptor::from(server_config)));
let handler_maker = handler_maker.clone();
async move {
@@ -119,14 +124,8 @@ impl Server for PostgresServer {
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
let (stream, addr) = self.base_server.bind(listening).await?;
debug!("Starting PostgreSQL with TLS option: {:?}", self.tls);
let tls_acceptor = self
.tls
.setup()?
.map(|server_conf| Arc::new(TlsAcceptor::from(Arc::new(server_conf))));
let io_runtime = self.base_server.io_runtime();
let join_handle = common_runtime::spawn_read(self.accept(io_runtime, stream, tls_acceptor));
let join_handle = common_runtime::spawn_read(self.accept(io_runtime, stream));
self.base_server.start_with(join_handle).await?;
Ok(addr)

View File

@@ -13,14 +13,23 @@
// limitations under the License.
use std::fs::File;
use std::io::{BufReader, Error, ErrorKind};
use std::io::{BufReader, Error as IoError, ErrorKind};
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::channel;
use std::sync::{Arc, RwLock};
use common_telemetry::{error, info};
use notify::{EventKind, RecursiveMode, Watcher};
use rustls::ServerConfig;
use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use serde::{Deserialize, Serialize};
use snafu::ResultExt;
use strum::EnumString;
use crate::error::{FileWatchSnafu, InternalIoSnafu, Result};
/// TlsMode is used for Mysql and Postgres server start up.
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
#[serde(rename_all = "snake_case")]
@@ -73,27 +82,38 @@ impl TlsOption {
tls_option
}
pub fn setup(&self) -> Result<Option<ServerConfig>, Error> {
pub fn setup(&self) -> Result<Option<ServerConfig>> {
if let TlsMode::Disable = self.mode {
return Ok(None);
}
let cert = certs(&mut BufReader::new(File::open(&self.cert_path)?))
.collect::<Result<Vec<CertificateDer>, Error>>()?;
let cert = certs(&mut BufReader::new(
File::open(&self.cert_path).context(InternalIoSnafu)?,
))
.collect::<std::result::Result<Vec<CertificateDer>, IoError>>()
.context(InternalIoSnafu)?;
let key = {
let mut pkcs8 = pkcs8_private_keys(&mut BufReader::new(File::open(&self.key_path)?))
.map(|key| key.map(PrivateKeyDer::from))
.collect::<Result<Vec<PrivateKeyDer>, Error>>()?;
let mut pkcs8 = pkcs8_private_keys(&mut BufReader::new(
File::open(&self.key_path).context(InternalIoSnafu)?,
))
.map(|key| key.map(PrivateKeyDer::from))
.collect::<std::result::Result<Vec<PrivateKeyDer>, IoError>>()
.context(InternalIoSnafu)?;
if !pkcs8.is_empty() {
pkcs8.remove(0)
} else {
let mut rsa = rsa_private_keys(&mut BufReader::new(File::open(&self.key_path)?))
.map(|key| key.map(PrivateKeyDer::from))
.collect::<Result<Vec<PrivateKeyDer>, Error>>()?;
let mut rsa = rsa_private_keys(&mut BufReader::new(
File::open(&self.key_path).context(InternalIoSnafu)?,
))
.map(|key| key.map(PrivateKeyDer::from))
.collect::<std::result::Result<Vec<PrivateKeyDer>, IoError>>()
.context(InternalIoSnafu)?;
if !rsa.is_empty() {
rsa.remove(0)
} else {
return Err(Error::new(ErrorKind::InvalidInput, "invalid key"));
return Err(IoError::new(ErrorKind::InvalidInput, "invalid key"))
.context(InternalIoSnafu);
}
}
};
@@ -110,6 +130,104 @@ impl TlsOption {
/// Reports whether clients must use TLS.
///
/// Returns `false` for `TlsMode::Disable` and `TlsMode::Prefer`, and `true`
/// for every other mode.
pub fn should_force_tls(&self) -> bool {
    let tls_is_optional = matches!(self.mode, TlsMode::Disable | TlsMode::Prefer);
    !tls_is_optional
}
/// Borrows the configured certificate file location as a `Path`.
pub fn cert_path(&self) -> &Path {
    self.cert_path.as_ref()
}
/// Borrows the configured private key file location as a `Path`.
pub fn key_path(&self) -> &Path {
    self.key_path.as_ref()
}
}
/// A mutable container for TLS server config
///
/// This struct allows dynamic reloading of server certificates and keys
pub struct ReloadableTlsServerConfig {
// The TLS option this container was created from; its `setup()` is called
// again on every reload.
tls_option: TlsOption,
// The most recently loaded server config. It is `None` when
// `TlsOption::setup` yields no config (e.g. TLS mode is `Disable`).
config: RwLock<Option<Arc<ServerConfig>>>,
// Starts at 0 and is incremented by one after each successful reload.
version: AtomicUsize,
}
impl ReloadableTlsServerConfig {
    /// Builds the container by performing the initial load described by
    /// `tls_option`.
    ///
    /// The version counter starts at 0.
    ///
    /// # Errors
    /// Propagates any error returned by `TlsOption::setup`.
    pub fn try_new(tls_option: TlsOption) -> Result<ReloadableTlsServerConfig> {
        let initial = tls_option.setup()?.map(Arc::new);
        Ok(Self {
            tls_option,
            config: RwLock::new(initial),
            version: AtomicUsize::new(0),
        })
    }

    /// Re-reads the server certificates and keys from the file system and
    /// replaces the stored config with the result.
    ///
    /// # Errors
    /// Propagates any error returned by `TlsOption::setup`; in that case the
    /// stored config and the version are left untouched.
    pub fn reload(&self) -> Result<()> {
        // Load everything before taking the write lock so the lock is only
        // held for the swap itself.
        let fresh = self.tls_option.setup()?.map(Arc::new);
        {
            let mut slot = self.config.write().unwrap();
            *slot = fresh;
        }
        self.version.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    /// Returns a shared handle to the server config currently held by this
    /// container, or `None` if no config is loaded.
    pub fn get_server_config(&self) -> Option<Arc<ServerConfig>> {
        self.config.read().unwrap().as_ref().map(Arc::clone)
    }

    /// Returns the `TlsOption` this container was created from.
    pub fn get_tls_option(&self) -> &TlsOption {
        &self.tls_option
    }

    /// Returns the version of the current config.
    ///
    /// The version is incremented each time the server config is
    /// successfully reloaded.
    pub fn get_version(&self) -> usize {
        self.version.load(Ordering::Relaxed)
    }
}
pub fn watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
if tls_server_config.get_tls_option().mode == TlsMode::Disable {
return Ok(());
}
let tls_server_config_for_watcher = tls_server_config.clone();
let (tx, rx) = channel::<notify::Result<notify::Event>>();
let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu)?;
watcher
.watch(
tls_server_config.get_tls_option().cert_path(),
RecursiveMode::NonRecursive,
)
.context(FileWatchSnafu)?;
watcher
.watch(
tls_server_config.get_tls_option().key_path(),
RecursiveMode::NonRecursive,
)
.context(FileWatchSnafu)?;
std::thread::spawn(move || {
let _watcher = watcher;
while let Ok(res) = rx.recv() {
if let Ok(event) = res {
match event.kind {
EventKind::Modify(_) | EventKind::Create(_) => {
info!("Detected TLS cert/key file change: {:?}", event);
if let Err(err) = tls_server_config_for_watcher.reload() {
error!(err; "Failed to reload TLS server config");
}
}
_ => {}
}
}
}
});
Ok(())
}
#[cfg(test)]
@@ -237,4 +355,44 @@ mod tests {
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
}
/// Verifies that overwriting a watched key file triggers a reload of the
/// server config.
#[test]
fn test_tls_file_change_watch() {
    let dir = tempfile::tempdir().unwrap();
    let cert_path = dir.path().join("server.crt");
    let key_path = dir.path().join("server.key");

    std::fs::copy("tests/ssl/server.crt", &cert_path).expect("failed to copy cert to tmpdir");
    std::fs::copy("tests/ssl/server-rsa.key", &key_path).expect("failed to copy key to tmpdir");

    let server_tls = TlsOption {
        mode: TlsMode::Require,
        cert_path: cert_path
            .to_str()
            .expect("failed to convert path to string")
            .to_owned(),
        key_path: key_path
            .to_str()
            .expect("failed to convert path to string")
            .to_owned(),
    };

    let server_config = Arc::new(
        ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
    );
    watch_tls_config(server_config.clone()).expect("failed to watch server config");

    assert_eq!(0, server_config.get_version());
    assert!(server_config.get_server_config().is_some());

    std::fs::copy("tests/ssl/server-pkcs8.key", &key_path)
        .expect("failed to copy key to tmpdir");

    // The reload happens on the watcher thread, so poll with a deadline
    // instead of relying on a single fixed sleep being long enough.
    let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
    while server_config.get_version() == 0 && std::time::Instant::now() < deadline {
        std::thread::sleep(std::time::Duration::from_millis(10));
    }

    // One successful reload is enough to prove the watcher works; how many
    // change events a single copy produces is platform dependent.
    assert!(server_config.get_version() > 0);
    assert!(server_config.get_server_config().is_some());
}
}

View File

@@ -30,7 +30,7 @@ use rand::Rng;
use servers::error::Result;
use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef};
use servers::server::Server;
use servers::tls::TlsOption;
use servers::tls::{ReloadableTlsServerConfig, TlsOption};
use table::test_util::MemTable;
use table::TableRef;
@@ -59,12 +59,17 @@ fn create_mysql_server(table: TableRef, opts: MysqlOpts<'_>) -> Result<Box<dyn S
provider.set_authorization_info(auth_info);
}
let tls_server_config = Arc::new(
ReloadableTlsServerConfig::try_new(opts.tls.clone())
.expect("Failed to load certificates and keys"),
);
Ok(MysqlServer::create_server(
io_runtime,
Arc::new(MysqlSpawnRef::new(query_handler, Some(Arc::new(provider)))),
Arc::new(MysqlSpawnConfig::new(
opts.tls.should_force_tls(),
opts.tls.setup()?.map(Arc::new),
tls_server_config,
opts.reject_no_database,
)),
))

View File

@@ -29,7 +29,7 @@ use rustls_pki_types::{CertificateDer, ServerName};
use servers::error::Result;
use servers::postgres::PostgresServer;
use servers::server::Server;
use servers::tls::TlsOption;
use servers::tls::{ReloadableTlsServerConfig, TlsOption};
use table::test_util::MemTable;
use table::TableRef;
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
@@ -60,9 +60,15 @@ fn create_postgres_server(
None
};
let tls_server_config = Arc::new(
ReloadableTlsServerConfig::try_new(tls.clone())
.expect("Failed to load certificates and keys"),
);
Ok(Box::new(PostgresServer::new(
instance,
tls,
tls.should_force_tls(),
tls_server_config,
io_runtime,
user_provider,
)))

View File

@@ -50,6 +50,7 @@ use servers::postgres::PostgresServer;
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter;
use servers::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};
use servers::server::Server;
use servers::tls::ReloadableTlsServerConfig;
use servers::Mode;
use session::context::QueryContext;
@@ -568,7 +569,10 @@ pub async fn setup_mysql_server_with_user_provider(
)),
Arc::new(MysqlSpawnConfig::new(
false,
opts.tls.setup().unwrap().map(Arc::new),
Arc::new(
ReloadableTlsServerConfig::try_new(opts.tls.clone())
.expect("Failed to load certificates and keys"),
),
opts.reject_no_database.unwrap_or(false),
)),
));
@@ -614,9 +618,15 @@ pub async fn setup_pg_server_with_user_provider(
addr: fe_pg_addr.clone(),
..Default::default()
};
let tls_server_config = Arc::new(
ReloadableTlsServerConfig::try_new(opts.tls.clone())
.expect("Failed to load certificates and keys"),
);
let fe_pg_server = Arc::new(Box::new(PostgresServer::new(
ServerSqlQueryHandlerAdapter::arc(fe_instance_ref),
opts.tls.clone(),
opts.tls.should_force_tls(),
tls_server_config,
runtime,
user_provider,
)) as Box<dyn Server>);