feat: mysql and pg server support tls (#641)

* feat: mysql and pg server support tls

* chore: replace opensrv-mysql to original

* chore: TlsOption is required but supply default value

* feat: mysql server support force tls

* chore: move TlsOption to servers

* test: mysql server disable / prefer / required tls mode

* test: pg server disable / prefer / required tls mode

* chore: add doc and remove no used code

* chore: add TODO and restore cargo linker config
This commit is contained in:
SSebo
2022-11-30 12:46:15 +08:00
committed by GitHub
parent a17dcbc511
commit 68c2de8e45
18 changed files with 674 additions and 45 deletions

View File

@@ -14,20 +14,27 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use std::time::{Duration, SystemTime};
use common_runtime::Builder as RuntimeBuilder;
use rand::rngs::StdRng;
use rand::Rng;
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, Error, ServerName};
use servers::error::Result;
use servers::postgres::PostgresServer;
use servers::server::Server;
use servers::tls::TlsOption;
use table::test_util::MemTable;
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
use crate::create_testing_sql_query_handler;
fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Server>> {
fn create_postgres_server(
table: MemTable,
check_pwd: bool,
tls: Arc<TlsOption>,
) -> Result<Box<dyn Server>> {
let query_handler = create_testing_sql_query_handler(table);
let io_runtime = Arc::new(
RuntimeBuilder::default()
@@ -39,6 +46,7 @@ fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Se
Ok(Box::new(PostgresServer::new(
query_handler,
check_pwd,
tls,
io_runtime,
)))
}
@@ -47,7 +55,7 @@ fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Se
pub async fn test_start_postgres_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false)?;
let pg_server = create_postgres_server(table, false, Default::default())?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let result = pg_server.start(listening).await;
assert!(result.is_ok());
@@ -73,7 +81,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
let table = MemTable::default_numbers_table();
let postgres_server = create_postgres_server(table, with_pwd)?;
let postgres_server = create_postgres_server(table, with_pwd, Default::default())?;
let result = postgres_server.shutdown().await;
assert!(result
.unwrap_err()
@@ -88,7 +96,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
for _ in 0..2 {
join_handles.push(tokio::spawn(async move {
for _ in 0..1000 {
match create_connection(server_port, with_pwd).await {
match create_plain_connection(server_port, with_pwd).await {
Ok(connection) => {
match connection
.simple_query("SELECT uint32s FROM numbers LIMIT 1")
@@ -132,7 +140,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false)?;
let pg_server = create_postgres_server(table, false, Default::default())?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
let server_port = server_addr.port();
@@ -144,7 +152,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
join_handles.push(tokio::spawn(async move {
let mut rand: StdRng = rand::SeedableRng::from_entropy();
let mut client = create_connection(server_port, false).await.unwrap();
let mut client = create_plain_connection(server_port, false).await.unwrap();
for _k in 0..expect_executed_queries_per_worker {
let expected: u32 = rand.gen_range(0..100);
@@ -165,7 +173,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
// 1/100 chance to reconnect
let should_recreate_conn = expected == 1;
if should_recreate_conn {
client = create_connection(server_port, false).await.unwrap();
client = create_plain_connection(server_port, false).await.unwrap();
}
}
expect_executed_queries_per_worker
@@ -179,7 +187,109 @@ async fn test_query_pg_concurrently() -> Result<()> {
Ok(())
}
async fn create_connection(port: u16, with_pwd: bool) -> std::result::Result<Client, PgError> {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_server_secure_prefer_client_plain() -> Result<()> {
common_telemetry::init_default_ut_logging();
let server_tls = Arc::new(TlsOption {
mode: servers::tls::TlsMode::Prefer,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server.key".to_owned(),
});
let client_tls = false;
do_simple_query(server_tls, client_tls).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_server_secure_require_client_plain() -> Result<()> {
common_telemetry::init_default_ut_logging();
let server_tls = Arc::new(TlsOption {
mode: servers::tls::TlsMode::Require,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server.key".to_owned(),
});
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false, server_tls)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
let server_port = server_addr.port();
let r = create_plain_connection(server_port, false).await;
assert!(r.is_err());
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_server_secure_require_client_secure() -> Result<()> {
common_telemetry::init_default_ut_logging();
let server_tls = Arc::new(TlsOption {
mode: servers::tls::TlsMode::Require,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server.key".to_owned(),
});
let client_tls = true;
do_simple_query(server_tls, client_tls).await?;
Ok(())
}
async fn do_simple_query(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false, server_tls)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
let server_port = server_addr.port();
if !client_tls {
let client = create_plain_connection(server_port, false).await.unwrap();
let result = client.simple_query("SELECT uint32s FROM numbers").await;
assert!(result.is_ok());
} else {
let client = create_secure_connection(server_port, false).await.unwrap();
let result = client.simple_query("SELECT uint32s FROM numbers").await;
assert!(result.is_ok());
}
Ok(())
}
async fn create_secure_connection(
port: u16,
with_pwd: bool,
) -> std::result::Result<Client, PgError> {
let url = if with_pwd {
format!(
"sslmode=require host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
port
)
} else {
format!("host=127.0.0.1 port={} connect_timeout=2", port)
};
let mut config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(rustls::RootCertStore::empty())
.with_no_client_auth();
config
.dangerous()
.set_certificate_verifier(Arc::new(AcceptAllVerifier {}));
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config);
let (client, conn) = tokio_postgres::connect(&url, tls).await.expect("connect");
tokio::spawn(conn);
Ok(client)
}
async fn create_plain_connection(
port: u16,
with_pwd: bool,
) -> std::result::Result<Client, PgError> {
let url = if with_pwd {
format!(
"host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
@@ -203,3 +313,18 @@ fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> {
fn unwrap_results(resp: &[SimpleQueryMessage]) -> Vec<&str> {
resp.iter().filter_map(|m| resolve_result(m, 0)).collect()
}
struct AcceptAllVerifier {}
impl ServerCertVerifier for AcceptAllVerifier {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: SystemTime,
) -> std::result::Result<ServerCertVerified, Error> {
Ok(ServerCertVerified::assertion())
}
}