mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-23 00:10:38 +00:00
feat: mysql and pg server support tls (#641)
* feat: mysql and pg server support tls * chore: replace opensrv-mysql to original * chore: TlsOption is required but supply default value * feat: mysql server support force tls * chore: move TlsOption to servers * test: mysql server disable / prefer / required tls mode * test: pg server disable / prefer / required tls mode * chore: add doc and remove no used code * chore: add TODO and restore cargo linker config
This commit is contained in:
@@ -14,20 +14,27 @@
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use rustls::client::{ServerCertVerified, ServerCertVerifier};
|
||||
use rustls::{Certificate, Error, ServerName};
|
||||
use servers::error::Result;
|
||||
use servers::postgres::PostgresServer;
|
||||
use servers::server::Server;
|
||||
use servers::tls::TlsOption;
|
||||
use table::test_util::MemTable;
|
||||
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
|
||||
|
||||
use crate::create_testing_sql_query_handler;
|
||||
|
||||
fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Server>> {
|
||||
fn create_postgres_server(
|
||||
table: MemTable,
|
||||
check_pwd: bool,
|
||||
tls: Arc<TlsOption>,
|
||||
) -> Result<Box<dyn Server>> {
|
||||
let query_handler = create_testing_sql_query_handler(table);
|
||||
let io_runtime = Arc::new(
|
||||
RuntimeBuilder::default()
|
||||
@@ -39,6 +46,7 @@ fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Se
|
||||
Ok(Box::new(PostgresServer::new(
|
||||
query_handler,
|
||||
check_pwd,
|
||||
tls,
|
||||
io_runtime,
|
||||
)))
|
||||
}
|
||||
@@ -47,7 +55,7 @@ fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Se
|
||||
pub async fn test_start_postgres_server() -> Result<()> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let pg_server = create_postgres_server(table, false)?;
|
||||
let pg_server = create_postgres_server(table, false, Default::default())?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let result = pg_server.start(listening).await;
|
||||
assert!(result.is_ok());
|
||||
@@ -73,7 +81,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let postgres_server = create_postgres_server(table, with_pwd)?;
|
||||
let postgres_server = create_postgres_server(table, with_pwd, Default::default())?;
|
||||
let result = postgres_server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -88,7 +96,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
|
||||
for _ in 0..2 {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
match create_connection(server_port, with_pwd).await {
|
||||
match create_plain_connection(server_port, with_pwd).await {
|
||||
Ok(connection) => {
|
||||
match connection
|
||||
.simple_query("SELECT uint32s FROM numbers LIMIT 1")
|
||||
@@ -132,7 +140,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let pg_server = create_postgres_server(table, false)?;
|
||||
let pg_server = create_postgres_server(table, false, Default::default())?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
@@ -144,7 +152,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let mut rand: StdRng = rand::SeedableRng::from_entropy();
|
||||
|
||||
let mut client = create_connection(server_port, false).await.unwrap();
|
||||
let mut client = create_plain_connection(server_port, false).await.unwrap();
|
||||
|
||||
for _k in 0..expect_executed_queries_per_worker {
|
||||
let expected: u32 = rand.gen_range(0..100);
|
||||
@@ -165,7 +173,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
// 1/100 chance to reconnect
|
||||
let should_recreate_conn = expected == 1;
|
||||
if should_recreate_conn {
|
||||
client = create_connection(server_port, false).await.unwrap();
|
||||
client = create_plain_connection(server_port, false).await.unwrap();
|
||||
}
|
||||
}
|
||||
expect_executed_queries_per_worker
|
||||
@@ -179,7 +187,109 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_connection(port: u16, with_pwd: bool) -> std::result::Result<Client, PgError> {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_server_secure_prefer_client_plain() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let server_tls = Arc::new(TlsOption {
|
||||
mode: servers::tls::TlsMode::Prefer,
|
||||
cert_path: "tests/ssl/server.crt".to_owned(),
|
||||
key_path: "tests/ssl/server.key".to_owned(),
|
||||
});
|
||||
|
||||
let client_tls = false;
|
||||
do_simple_query(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_server_secure_require_client_plain() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let server_tls = Arc::new(TlsOption {
|
||||
mode: servers::tls::TlsMode::Require,
|
||||
cert_path: "tests/ssl/server.crt".to_owned(),
|
||||
key_path: "tests/ssl/server.key".to_owned(),
|
||||
});
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
let pg_server = create_postgres_server(table, false, server_tls)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
|
||||
let r = create_plain_connection(server_port, false).await;
|
||||
assert!(r.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_server_secure_require_client_secure() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let server_tls = Arc::new(TlsOption {
|
||||
mode: servers::tls::TlsMode::Require,
|
||||
cert_path: "tests/ssl/server.crt".to_owned(),
|
||||
key_path: "tests/ssl/server.key".to_owned(),
|
||||
});
|
||||
|
||||
let client_tls = true;
|
||||
do_simple_query(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn do_simple_query(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
let pg_server = create_postgres_server(table, false, server_tls)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
|
||||
if !client_tls {
|
||||
let client = create_plain_connection(server_port, false).await.unwrap();
|
||||
let result = client.simple_query("SELECT uint32s FROM numbers").await;
|
||||
assert!(result.is_ok());
|
||||
} else {
|
||||
let client = create_secure_connection(server_port, false).await.unwrap();
|
||||
let result = client.simple_query("SELECT uint32s FROM numbers").await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_secure_connection(
|
||||
port: u16,
|
||||
with_pwd: bool,
|
||||
) -> std::result::Result<Client, PgError> {
|
||||
let url = if with_pwd {
|
||||
format!(
|
||||
"sslmode=require host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
|
||||
port
|
||||
)
|
||||
} else {
|
||||
format!("host=127.0.0.1 port={} connect_timeout=2", port)
|
||||
};
|
||||
|
||||
let mut config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(rustls::RootCertStore::empty())
|
||||
.with_no_client_auth();
|
||||
config
|
||||
.dangerous()
|
||||
.set_certificate_verifier(Arc::new(AcceptAllVerifier {}));
|
||||
|
||||
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config);
|
||||
let (client, conn) = tokio_postgres::connect(&url, tls).await.expect("connect");
|
||||
|
||||
tokio::spawn(conn);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
async fn create_plain_connection(
|
||||
port: u16,
|
||||
with_pwd: bool,
|
||||
) -> std::result::Result<Client, PgError> {
|
||||
let url = if with_pwd {
|
||||
format!(
|
||||
"host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
|
||||
@@ -203,3 +313,18 @@ fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> {
|
||||
fn unwrap_results(resp: &[SimpleQueryMessage]) -> Vec<&str> {
|
||||
resp.iter().filter_map(|m| resolve_result(m, 0)).collect()
|
||||
}
|
||||
|
||||
struct AcceptAllVerifier {}
|
||||
impl ServerCertVerifier for AcceptAllVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &Certificate,
|
||||
_intermediates: &[Certificate],
|
||||
_server_name: &ServerName,
|
||||
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: SystemTime,
|
||||
) -> std::result::Result<ServerCertVerified, Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user