From cee9c726d2a6f38b9d692dcc134ebc493ebce6e1 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Fri, 31 Mar 2023 21:19:20 +0300 Subject: [PATCH] Implement proper parsing --- proxy/src/certs.rs | 69 ++++++++++++++++++++++++++++++++++++++++++---- proxy/src/main.rs | 12 +++++--- 2 files changed, 71 insertions(+), 10 deletions(-) diff --git a/proxy/src/certs.rs b/proxy/src/certs.rs index 9d5104c00f..d7f96b87df 100644 --- a/proxy/src/certs.rs +++ b/proxy/src/certs.rs @@ -2,17 +2,20 @@ use rustls::{ server::{ClientHello, ResolvesServerCert}, sign::CertifiedKey, }; -use std::sync::Arc; +use std::{io, sync::Arc}; pub mod config { + use super::*; use serde::Deserialize; use std::path::Path; - /// TODO: explain. + /// Collection of TLS-related configurations of virtual proxy servers. #[derive(Debug, Default, Clone, Deserialize)] + #[serde(transparent)] pub struct TlsServers(Vec); impl TlsServers { + /// Load [`Self`] config from a file. pub fn from_config_file(path: impl AsRef) -> anyhow::Result { let config = serde_dhall::from_file(path).parse()?; Ok(config) @@ -26,22 +29,76 @@ pub mod config { } } + #[derive(Debug, Clone, Deserialize)] + #[serde(transparent)] + pub struct TlsCert( + /// The wrapped rustls certificate. + #[serde(deserialize_with = "deserialize_certs")] + pub Vec, + ); + + fn deserialize_certs<'de, D>(des: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + let text = String::deserialize(des)?; + parse_certs(&mut text.as_bytes()).map_err(serde::de::Error::custom) + } + + #[derive(Debug, Clone, Deserialize)] + #[serde(transparent)] + pub struct TlsKey( + /// The wrapped rustls private key. + #[serde(deserialize_with = "deserialize_key")] + pub rustls::PrivateKey, + ); + + fn deserialize_key<'de, D>(des: D) -> Result + where + D: serde::Deserializer<'de>, + { + let text = String::deserialize(des)?; + parse_key(&mut text.as_bytes()).map_err(serde::de::Error::custom) + } + /// TODO: explain. #[derive(Debug, Clone, Deserialize)] pub struct TlsServer { - server_name: Box, - certificate: Box, - private_key: Box, + pub certificate: TlsCert, + pub private_key: TlsKey, } } +fn parse_certs(buf: &mut impl io::BufRead) -> io::Result> { + let chain = rustls_pemfile::certs(buf)? + .into_iter() + .map(rustls::Certificate) + .collect(); + + Ok(chain) +} + +fn parse_key(buf: &mut impl io::BufRead) -> io::Result { + let mut keys = rustls_pemfile::pkcs8_private_keys(buf)?; + + // We expect to see only 1 key. + if keys.len() != 1 { + return Err(io::Error::new( + io::ErrorKind::Other, + "there should be exactly one TLS key in buffer", + )); + } + + Ok(rustls::PrivateKey(keys.pop().unwrap())) +} + pub struct CertResolver { resolver: rustls::server::ResolvesServerCertUsingSni, } impl CertResolver { pub fn new() -> Self { - todo!() + todo!("CertResolver ctor") } } diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 90e96238f8..403526ded3 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -133,12 +133,13 @@ fn build_tls_config(args: &clap::ArgMatches) -> anyhow::Result let tls_config = args.get_one::("tls-config"); let main = tls_config.map(TlsServers::from_config_file).transpose()?; + tracing::info!(?main, "config"); let tls_cert = args.get_one::("tls-cert"); let tls_key = args.get_one::("tls-key"); let aux = match (tls_cert, tls_key) { - (Some(key_path), Some(cert_path)) => todo!("implement legacy TLS setup"), + (Some(_key), Some(_cert)) => todo!("implement legacy TLS setup"), (None, None) => None::<()>, _ => bail!("either both or neither tls-key and tls-cert must be specified"), }; @@ -264,19 +265,22 @@ fn cli() -> clap::Command { .short('k') .long("tls-key") .alias("ssl-key") // backwards compatibility - .help("path to TLS key for client postgres connections"), + .help("path to TLS key for client postgres connections") + .value_parser(clap::builder::PathBufValueParser::new()), ) .arg( Arg::new("tls-cert") .short('c') .long("tls-cert") .alias("ssl-cert") // backwards compatibility - .help("path to TLS cert for client postgres connections"), + .help("path to TLS cert for client postgres connections") + .value_parser(clap::builder::PathBufValueParser::new()), ) .arg( Arg::new("tls-config") .long("tls-config") - .help("path to the TLS config file"), + .help("path to the TLS config file") + .value_parser(clap::builder::PathBufValueParser::new()), ) .arg( Arg::new("metric-collection-endpoint")