feat: mysql and pg server support tls (#641)

* feat: mysql and pg server support tls

* chore: replace opensrv-mysql to original

* chore: TlsOption is required but supply default value

* feat: mysql server support force tls

* chore: move TlsOption to servers

* test: mysql server disable / prefer / required tls mode

* test: pg server disable / prefer / required tls mode

* chore: add doc and remove no used code

* chore: add TODO and restore cargo linker config
This commit is contained in:
SSebo
2022-11-30 12:46:15 +08:00
committed by GitHub
parent a17dcbc511
commit 68c2de8e45
18 changed files with 674 additions and 45 deletions

29
Cargo.lock generated
View File

@@ -2218,6 +2218,7 @@ dependencies = [
"openmetrics-parser",
"prost 0.11.0",
"query",
"rustls",
"serde",
"serde_json",
"servers",
@@ -3717,16 +3718,18 @@ dependencies = [
[[package]]
name = "opensrv-mysql"
version = "0.2.0"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4c24c12fd688cb5aa5b1a54c6ccb2e30fb9b5132debb0e89fcb432b3f73db8f"
checksum = "ac5d68ae914b1317d874ce049e52d386b1209d8835d4e6e094f2e90bfb49eccc"
dependencies = [
"async-trait",
"byteorder",
"chrono",
"mysql_common",
"nom",
"pin-project-lite",
"tokio",
"tokio-rustls",
]
[[package]]
@@ -4923,9 +4926,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.20.6"
version = "0.20.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033"
checksum = "539a2bfe908f471bfa933876bd1eb6a19cf2176d375f82ef7f99530a40e48c2c"
dependencies = [
"log",
"ring",
@@ -5507,6 +5510,8 @@ dependencies = [
"query",
"rand 0.8.5",
"regex",
"rustls",
"rustls-pemfile 1.0.1",
"schemars",
"script",
"serde",
@@ -5516,6 +5521,8 @@ dependencies = [
"table",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls",
"tokio-stream",
"tokio-test",
"tonic",
@@ -6418,6 +6425,20 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "tokio-postgres-rustls"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "606f2b73660439474394432239c82249c0d45eb5f23d91f401be1e33590444a7"
dependencies = [
"futures",
"ring",
"rustls",
"tokio",
"tokio-postgres",
"tokio-rustls",
]
[[package]]
name = "tokio-rustls"
version = "0.23.4"

View File

@@ -11,3 +11,4 @@ common-error = { path = "../error" }
paste = "1.0"
serde = { version = "1.0", features = ["derive"] }
snafu = { version = "0.7", features = ["backtraces"] }

View File

@@ -62,6 +62,7 @@ impl Services {
Some(MysqlServer::create_server(
instance.clone(),
mysql_io_runtime,
Default::default(),
))
}
};

View File

@@ -45,6 +45,7 @@ sql = { path = "../sql" }
store-api = { path = "../store-api" }
table = { path = "../table" }
tokio = { version = "1.18", features = ["full"] }
rustls = "0.20"
[dev-dependencies]
datanode = { path = "../datanode" }

View File

@@ -12,12 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use servers::tls::TlsOption;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MysqlOptions {
pub addr: String,
pub runtime_size: usize,
#[serde(default = "Default::default")]
pub tls: Arc<TlsOption>,
}
impl Default for MysqlOptions {
@@ -25,6 +30,7 @@ impl Default for MysqlOptions {
Self {
addr: "127.0.0.1:4002".to_string(),
runtime_size: 2,
tls: Arc::new(TlsOption::default()),
}
}
}

View File

@@ -12,13 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use servers::tls::TlsOption;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PostgresOptions {
pub addr: String,
pub runtime_size: usize,
pub check_pwd: bool,
#[serde(default = "Default::default")]
pub tls: Arc<TlsOption>,
}
impl Default for PostgresOptions {
@@ -27,6 +32,7 @@ impl Default for PostgresOptions {
addr: "127.0.0.1:4003".to_string(),
runtime_size: 2,
check_pwd: false,
tls: Default::default(),
}
}
}

View File

@@ -69,7 +69,8 @@ impl Services {
.context(error::RuntimeResourceSnafu)?,
);
let mysql_server = MysqlServer::create_server(instance.clone(), mysql_io_runtime);
let mysql_server =
MysqlServer::create_server(instance.clone(), mysql_io_runtime, opts.tls.clone());
Some((mysql_server, mysql_addr))
} else {
@@ -90,6 +91,7 @@ impl Services {
let pg_server = Box::new(PostgresServer::new(
instance.clone(),
opts.check_pwd,
opts.tls.clone(),
pg_io_runtime,
)) as Box<dyn Server>;

View File

@@ -30,7 +30,7 @@ metrics = "0.20"
num_cpus = "1.13"
once_cell = "1.16"
openmetrics-parser = "0.4"
opensrv-mysql = "0.2"
opensrv-mysql = "0.3"
pgwire = "0.5"
prost = "0.11"
regex = "1.6"
@@ -47,6 +47,9 @@ tonic = "0.8"
tonic-reflection = "0.5"
tower = { version = "0.4", features = ["full"] }
tower-http = { version = "0.3", features = ["full"] }
tokio-rustls = "0.23"
rustls = "0.20"
rustls-pemfile = "1.0"
[dev-dependencies]
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
@@ -56,6 +59,8 @@ mysql_async = { git = "https://github.com/Morranto/mysql_async.git", rev = "127b
query = { path = "../query" }
rand = "0.8"
script = { path = "../script", features = ["python"] }
serde_json = "1.0"
table = { path = "../table" }
tokio-postgres = "0.7"
tokio-postgres-rustls = "0.9"
tokio-test = "0.4"

View File

@@ -192,6 +192,9 @@ pub enum Error {
err_msg: String,
backtrace: Backtrace,
},
#[snafu(display("Tls is required for {}, plain connection is rejected", server))]
TlsRequired { server: String },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -234,6 +237,7 @@ impl ErrorExt for Error {
InfluxdbLinesWrite { source, .. } => source.status_code(),
Hyper { .. } => StatusCode::Unknown,
TlsRequired { .. } => StatusCode::Unknown,
StartFrontend { source, .. } => source.status_code(),
}
}

View File

@@ -28,6 +28,8 @@ pub mod postgres;
pub mod prometheus;
pub mod query_handler;
pub mod server;
pub mod tls;
mod shutdown;
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]

View File

@@ -20,15 +20,19 @@ use async_trait::async_trait;
use common_runtime::Runtime;
use common_telemetry::logging::{error, info};
use futures::StreamExt;
use opensrv_mysql::AsyncMysqlIntermediary;
use opensrv_mysql::{
plain_run_with_options, secure_run_with_options, AsyncMysqlIntermediary, IntermediaryOptions,
};
use tokio;
use tokio::io::BufWriter;
use tokio::net::TcpStream;
use tokio_rustls::rustls::ServerConfig;
use crate::error::Result;
use crate::error::{Error, Result};
use crate::mysql::handler::MysqlInstanceShim;
use crate::query_handler::SqlQueryHandlerRef;
use crate::server::{AbortableStream, BaseTcpServer, Server};
use crate::tls::TlsOption;
// Default size of ResultSet write buffer: 100KB
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
@@ -36,16 +40,19 @@ const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
pub struct MysqlServer {
base_server: BaseTcpServer,
query_handler: SqlQueryHandlerRef,
tls: Arc<TlsOption>,
}
impl MysqlServer {
pub fn create_server(
query_handler: SqlQueryHandlerRef,
io_runtime: Arc<Runtime>,
tls: Arc<TlsOption>,
) -> Box<dyn Server> {
Box::new(MysqlServer {
base_server: BaseTcpServer::create_server("MySQL", io_runtime),
query_handler,
tls,
})
}
@@ -53,16 +60,22 @@ impl MysqlServer {
&self,
io_runtime: Arc<Runtime>,
stream: AbortableStream,
tls_conf: Option<Arc<ServerConfig>>,
) -> impl Future<Output = ()> {
let query_handler = self.query_handler.clone();
let force_tls = self.tls.should_force_tls();
stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let query_handler = query_handler.clone();
let tls_conf = tls_conf.clone();
async move {
match tcp_stream {
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
Ok(io_stream) => {
if let Err(error) = Self::handle(io_stream, io_runtime, query_handler).await
if let Err(error) =
Self::handle(io_stream, io_runtime, query_handler, tls_conf, force_tls)
.await
{
error!(error; "Unexpected error when handling TcpStream");
};
@@ -76,22 +89,49 @@ impl MysqlServer {
stream: TcpStream,
io_runtime: Arc<Runtime>,
query_handler: SqlQueryHandlerRef,
tls_conf: Option<Arc<ServerConfig>>,
force_tls: bool,
) -> Result<()> {
info!("MySQL connection coming from: {}", stream.peer_addr()?);
let shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string());
let (r, w) = stream.into_split();
let w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
io_runtime.spawn(async move {
io_runtime .spawn(async move {
// TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client.
if let Err(e) = AsyncMysqlIntermediary::run_on(shim, r, w).await {
// TODO(LFC): Write this error to client as well, in MySQL text protocol.
// Looks like we have to expose opensrv-mysql's `PacketWriter`?
error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.")
}
});
if let Err(e) = Self::do_handle(stream, query_handler, tls_conf, force_tls).await {
// TODO(LFC): Write this error to client as well, in MySQL text protocol.
// Looks like we have to expose opensrv-mysql's `PacketWriter`?
error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.")
}
});
Ok(())
}
async fn do_handle(
stream: TcpStream,
query_handler: SqlQueryHandlerRef,
tls_conf: Option<Arc<ServerConfig>>,
force_tls: bool,
) -> Result<()> {
let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string());
let (mut r, w) = stream.into_split();
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
let ops = IntermediaryOptions::default();
let (client_tls, init_params) =
AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &tls_conf).await?;
if force_tls && !client_tls {
return Err(Error::TlsRequired {
server: "mysql".to_owned(),
});
}
match tls_conf {
Some(tls_conf) if client_tls => {
secure_run_with_options(shim, w, ops, tls_conf, init_params).await
}
_ => plain_run_with_options(shim, w, ops, init_params).await,
}
}
}
#[async_trait]
@@ -104,7 +144,10 @@ impl Server for MysqlServer {
let (stream, addr) = self.base_server.bind(listening).await?;
let io_runtime = self.base_server.io_runtime();
let join_handle = tokio::spawn(self.accept(io_runtime, stream));
let tls_conf = self.tls.setup()?.map(Arc::new);
let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_conf));
self.base_server.start_with(join_handle).await?;
Ok(addr)
}

View File

@@ -63,14 +63,16 @@ pub struct PgAuthStartupHandler {
verifier: PgPwdVerifier,
param_provider: GreptimeDBStartupParameters,
with_pwd: bool,
force_tls: bool,
}
impl PgAuthStartupHandler {
pub fn new(with_pwd: bool) -> Self {
pub fn new(with_pwd: bool, force_tls: bool) -> Self {
PgAuthStartupHandler {
verifier: PgPwdVerifier,
param_provider: GreptimeDBStartupParameters::new(),
with_pwd,
force_tls,
}
}
}
@@ -89,6 +91,20 @@ impl StartupHandler for PgAuthStartupHandler {
{
match message {
PgWireFrontendMessage::Startup(ref startup) => {
if !client.is_secure() && self.force_tls {
let error_info = ErrorInfo::new(
"FATAL".to_owned(),
"28000".to_owned(),
"No encryption".to_owned(),
);
let error = ErrorResponse::from(error_info);
client
.feed(PgWireBackendMessage::ErrorResponse(error))
.await?;
client.close().await?;
return Ok(());
}
auth::save_startup_parameters_to_metadata(client, startup);
if self.with_pwd {
client.set_state(PgWireConnectionState::AuthenticationInProgress);

View File

@@ -22,17 +22,20 @@ use common_telemetry::logging::error;
use futures::StreamExt;
use pgwire::tokio::process_socket;
use tokio;
use tokio_rustls::TlsAcceptor;
use crate::error::Result;
use crate::postgres::auth_handler::PgAuthStartupHandler;
use crate::postgres::handler::PostgresServerHandler;
use crate::query_handler::SqlQueryHandlerRef;
use crate::server::{AbortableStream, BaseTcpServer, Server};
use crate::tls::TlsOption;
pub struct PostgresServer {
base_server: BaseTcpServer,
auth_handler: Arc<PgAuthStartupHandler>,
query_handler: Arc<PostgresServerHandler>,
tls: Arc<TlsOption>,
}
impl PostgresServer {
@@ -40,14 +43,17 @@ impl PostgresServer {
pub fn new(
query_handler: SqlQueryHandlerRef,
check_pwd: bool,
tls: Arc<TlsOption>,
io_runtime: Arc<Runtime>,
) -> PostgresServer {
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
let startup_handler = Arc::new(PgAuthStartupHandler::new(check_pwd));
let startup_handler =
Arc::new(PgAuthStartupHandler::new(check_pwd, tls.should_force_tls()));
PostgresServer {
base_server: BaseTcpServer::create_server("Postgres", io_runtime),
auth_handler: startup_handler,
query_handler: postgres_handler,
tls,
}
}
@@ -55,6 +61,7 @@ impl PostgresServer {
&self,
io_runtime: Arc<Runtime>,
accepting_stream: AbortableStream,
tls_acceptor: Option<Arc<TlsAcceptor>>,
) -> impl Future<Output = ()> {
let auth_handler = self.auth_handler.clone();
let query_handler = self.query_handler.clone();
@@ -63,6 +70,7 @@ impl PostgresServer {
let io_runtime = io_runtime.clone();
let auth_handler = auth_handler.clone();
let query_handler = query_handler.clone();
let tls_acceptor = tls_acceptor.clone();
async move {
match tcp_stream {
@@ -70,7 +78,7 @@ impl PostgresServer {
Ok(io_stream) => {
io_runtime.spawn(process_socket(
io_stream,
None,
tls_acceptor.clone(),
auth_handler.clone(),
query_handler.clone(),
query_handler.clone(),
@@ -91,8 +99,14 @@ impl Server for PostgresServer {
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
let (stream, addr) = self.base_server.bind(listening).await?;
let tls_acceptor = self
.tls
.setup()?
.map(|server_conf| Arc::new(TlsAcceptor::from(Arc::new(server_conf))));
let io_runtime = self.base_server.io_runtime();
let join_handle = tokio::spawn(self.accept(io_runtime, stream));
let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_acceptor));
self.base_server.start_with(join_handle).await?;
Ok(addr)
}

177
src/servers/src/tls.rs Normal file
View File

@@ -0,0 +1,177 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fs::File;
use std::io::{BufReader, Error, ErrorKind};
use rustls::{Certificate, PrivateKey, ServerConfig};
use rustls_pemfile::{certs, pkcs8_private_keys};
use serde::{Deserialize, Serialize};
/// TlsMode is used for Mysql and Postgres server start up.
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum TlsMode {
#[default]
Disable,
Prefer,
Require,
// TODO(SSebo): Implement the following 2 TSL mode described in
// ["34.19.3. Protection Provided in Different Modes"](https://www.postgresql.org/docs/current/libpq-ssl.html)
VerifyCa,
VerifyFull,
}
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub struct TlsOption {
pub mode: TlsMode,
#[serde(default)]
pub cert_path: String,
#[serde(default)]
pub key_path: String,
}
impl TlsOption {
pub fn setup(&self) -> Result<Option<ServerConfig>, Error> {
if let TlsMode::Disable = self.mode {
return Ok(None);
}
let cert = certs(&mut BufReader::new(File::open(&self.cert_path)?))
.map_err(|_| Error::new(ErrorKind::InvalidInput, "invalid cert"))
.map(|mut certs| certs.drain(..).map(Certificate).collect())?;
// TODO(SSebo): support more private key types
let key = pkcs8_private_keys(&mut BufReader::new(File::open(&self.key_path)?))
.map_err(|_| Error::new(ErrorKind::InvalidInput, "invalid key"))
.map(|mut keys| keys.drain(..).map(PrivateKey).next())?
.ok_or_else(|| Error::new(ErrorKind::InvalidInput, "invalid key"))?;
// TODO(SSebo): with_client_cert_verifier if TlsMode is Required.
let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, key)
.map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
Ok(Some(config))
}
pub fn should_force_tls(&self) -> bool {
!matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tls_option_disable() {
let s = r#"
{
"mode": "disable"
}
"#;
let t: TlsOption = serde_json::from_str(s).unwrap();
assert!(!t.should_force_tls());
assert!(matches!(t.mode, TlsMode::Disable));
assert!(t.key_path.is_empty());
assert!(t.cert_path.is_empty());
let setup = t.setup();
assert!(setup.is_ok());
let setup = setup.unwrap();
assert!(setup.is_none());
}
#[test]
fn test_tls_option_prefer() {
let s = r#"
{
"mode": "prefer",
"cert_path": "/some_dir/some.crt",
"key_path": "/some_dir/some.key"
}
"#;
let t: TlsOption = serde_json::from_str(s).unwrap();
assert!(!t.should_force_tls());
assert!(matches!(t.mode, TlsMode::Prefer));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
}
#[test]
fn test_tls_option_require() {
let s = r#"
{
"mode": "require",
"cert_path": "/some_dir/some.crt",
"key_path": "/some_dir/some.key"
}
"#;
let t: TlsOption = serde_json::from_str(s).unwrap();
assert!(t.should_force_tls());
assert!(matches!(t.mode, TlsMode::Require));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
}
#[test]
fn test_tls_option_verifiy_ca() {
let s = r#"
{
"mode": "verify_ca",
"cert_path": "/some_dir/some.crt",
"key_path": "/some_dir/some.key"
}
"#;
let t: TlsOption = serde_json::from_str(s).unwrap();
assert!(t.should_force_tls());
assert!(matches!(t.mode, TlsMode::VerifyCa));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
}
#[test]
fn test_tls_option_verifiy_full() {
let s = r#"
{
"mode": "verify_full",
"cert_path": "/some_dir/some.crt",
"key_path": "/some_dir/some.key"
}
"#;
let t: TlsOption = serde_json::from_str(s).unwrap();
assert!(t.should_force_tls());
assert!(matches!(t.mode, TlsMode::VerifyFull));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
}
}

View File

@@ -20,17 +20,19 @@ use common_recordbatch::RecordBatch;
use common_runtime::Builder as RuntimeBuilder;
use datatypes::schema::Schema;
use mysql_async::prelude::*;
use mysql_async::SslOpts;
use rand::rngs::StdRng;
use rand::Rng;
use servers::error::Result;
use servers::mysql::server::MysqlServer;
use servers::server::Server;
use servers::tls::TlsOption;
use table::test_util::MemTable;
use crate::create_testing_sql_query_handler;
use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData};
fn create_mysql_server(table: MemTable) -> Result<Box<dyn Server>> {
fn create_mysql_server(table: MemTable, tls: Arc<TlsOption>) -> Result<Box<dyn Server>> {
let query_handler = create_testing_sql_query_handler(table);
let io_runtime = Arc::new(
RuntimeBuilder::default()
@@ -39,14 +41,14 @@ fn create_mysql_server(table: MemTable) -> Result<Box<dyn Server>> {
.build()
.unwrap(),
);
Ok(MysqlServer::create_server(query_handler, io_runtime))
Ok(MysqlServer::create_server(query_handler, io_runtime, tls))
}
#[tokio::test]
async fn test_start_mysql_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table)?;
let mysql_server = create_mysql_server(table, Default::default())?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let result = mysql_server.start(listening).await;
assert!(result.is_ok());
@@ -65,7 +67,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table)?;
let mysql_server = create_mysql_server(table, Default::default())?;
let result = mysql_server.shutdown().await;
assert!(result
.unwrap_err()
@@ -80,7 +82,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
for index in 0..2 {
join_handles.push(tokio::spawn(async move {
for _ in 0..1000 {
match create_connection(server_port, index == 1).await {
match create_connection(server_port, index == 1, false).await {
Ok(mut connection) => {
let result: u32 = connection
.query_first("SELECT uint32s FROM numbers LIMIT 1")
@@ -114,6 +116,63 @@ async fn test_shutdown_mysql_server() -> Result<()> {
async fn test_query_all_datatypes() -> Result<()> {
common_telemetry::init_default_ut_logging();
let server_tls = Arc::new(TlsOption::default());
let client_tls = false;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_server_prefer_secure_client_plain() -> Result<()> {
let server_tls = Arc::new(TlsOption {
mode: servers::tls::TlsMode::Prefer,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server.key".to_owned(),
});
let client_tls = false;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_server_prefer_secure_client_secure() -> Result<()> {
let server_tls = Arc::new(TlsOption {
mode: servers::tls::TlsMode::Prefer,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server.key".to_owned(),
});
let client_tls = true;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_server_require_secure_client_secure() -> Result<()> {
let server_tls = Arc::new(TlsOption {
mode: servers::tls::TlsMode::Require,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server.key".to_owned(),
});
let client_tls = true;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_server_required_secure_client_plain() -> Result<()> {
let server_tls = Arc::new(TlsOption {
mode: servers::tls::TlsMode::Require,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server.key".to_owned(),
});
let client_tls = false;
#[allow(unused)]
let TestingData {
column_schemas,
mysql_columns_def,
@@ -124,11 +183,41 @@ async fn test_query_all_datatypes() -> Result<()> {
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::new("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(table)?;
let mysql_server = create_mysql_server(table, server_tls)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let mut connection = create_connection(server_addr.port(), false).await.unwrap();
let r = create_connection(server_addr.port(), client_tls, false).await;
assert!(r.is_err());
Ok(())
}
async fn do_test_query_all_datatypes(
server_tls: Arc<TlsOption>,
with_pwd: bool,
client_tls: bool,
) -> Result<()> {
common_telemetry::init_default_ut_logging();
let TestingData {
column_schemas,
mysql_columns_def,
columns,
mysql_text_output_rows,
} = all_datatype_testing_data();
let schema = Arc::new(Schema::new(column_schemas.clone()));
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::new("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(table, server_tls)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let mut connection = create_connection(server_addr.port(), client_tls, with_pwd)
.await
.unwrap();
let mut result = connection
.query_iter("SELECT * FROM all_datatypes LIMIT 3")
.await
@@ -155,7 +244,7 @@ async fn test_query_concurrently() -> Result<()> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table)?;
let mysql_server = create_mysql_server(table, Default::default())?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let server_port = server_addr.port();
@@ -167,7 +256,7 @@ async fn test_query_concurrently() -> Result<()> {
join_handles.push(tokio::spawn(async move {
let mut rand: StdRng = rand::SeedableRng::from_entropy();
let mut connection = create_connection(server_port, index % 2 == 0)
let mut connection = create_connection(server_port, index % 2 == 0, false)
.await
.unwrap();
for _ in 0..expect_executed_queries_per_worker {
@@ -184,7 +273,7 @@ async fn test_query_concurrently() -> Result<()> {
let should_recreate_conn = expected == 1;
if should_recreate_conn {
connection = create_connection(server_port, index % 2 == 0)
connection = create_connection(server_port, index % 2 == 0, false)
.await
.unwrap();
}
@@ -200,13 +289,24 @@ async fn test_query_concurrently() -> Result<()> {
Ok(())
}
async fn create_connection(port: u16, with_pwd: bool) -> mysql_async::Result<mysql_async::Conn> {
async fn create_connection(
port: u16,
with_pwd: bool,
ssl: bool,
) -> mysql_async::Result<mysql_async::Conn> {
let mut opts = mysql_async::OptsBuilder::default()
.ip_or_hostname("127.0.0.1")
.tcp_port(port)
.prefer_socket(false)
.wait_timeout(Some(1000));
if ssl {
let ssl_opts = SslOpts::default()
.with_danger_skip_domain_validation(true)
.with_danger_accept_invalid_certs(true);
opts = opts.ssl_opts(ssl_opts)
}
if with_pwd {
opts = opts.pass(Some("default_pwd".to_string()));
}

View File

@@ -14,20 +14,27 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use std::time::{Duration, SystemTime};
use common_runtime::Builder as RuntimeBuilder;
use rand::rngs::StdRng;
use rand::Rng;
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, Error, ServerName};
use servers::error::Result;
use servers::postgres::PostgresServer;
use servers::server::Server;
use servers::tls::TlsOption;
use table::test_util::MemTable;
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
use crate::create_testing_sql_query_handler;
fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Server>> {
fn create_postgres_server(
table: MemTable,
check_pwd: bool,
tls: Arc<TlsOption>,
) -> Result<Box<dyn Server>> {
let query_handler = create_testing_sql_query_handler(table);
let io_runtime = Arc::new(
RuntimeBuilder::default()
@@ -39,6 +46,7 @@ fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Se
Ok(Box::new(PostgresServer::new(
query_handler,
check_pwd,
tls,
io_runtime,
)))
}
@@ -47,7 +55,7 @@ fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Se
pub async fn test_start_postgres_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false)?;
let pg_server = create_postgres_server(table, false, Default::default())?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let result = pg_server.start(listening).await;
assert!(result.is_ok());
@@ -73,7 +81,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
let table = MemTable::default_numbers_table();
let postgres_server = create_postgres_server(table, with_pwd)?;
let postgres_server = create_postgres_server(table, with_pwd, Default::default())?;
let result = postgres_server.shutdown().await;
assert!(result
.unwrap_err()
@@ -88,7 +96,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
for _ in 0..2 {
join_handles.push(tokio::spawn(async move {
for _ in 0..1000 {
match create_connection(server_port, with_pwd).await {
match create_plain_connection(server_port, with_pwd).await {
Ok(connection) => {
match connection
.simple_query("SELECT uint32s FROM numbers LIMIT 1")
@@ -132,7 +140,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false)?;
let pg_server = create_postgres_server(table, false, Default::default())?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
let server_port = server_addr.port();
@@ -144,7 +152,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
join_handles.push(tokio::spawn(async move {
let mut rand: StdRng = rand::SeedableRng::from_entropy();
let mut client = create_connection(server_port, false).await.unwrap();
let mut client = create_plain_connection(server_port, false).await.unwrap();
for _k in 0..expect_executed_queries_per_worker {
let expected: u32 = rand.gen_range(0..100);
@@ -165,7 +173,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
// 1/100 chance to reconnect
let should_recreate_conn = expected == 1;
if should_recreate_conn {
client = create_connection(server_port, false).await.unwrap();
client = create_plain_connection(server_port, false).await.unwrap();
}
}
expect_executed_queries_per_worker
@@ -179,7 +187,109 @@ async fn test_query_pg_concurrently() -> Result<()> {
Ok(())
}
async fn create_connection(port: u16, with_pwd: bool) -> std::result::Result<Client, PgError> {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_server_secure_prefer_client_plain() -> Result<()> {
common_telemetry::init_default_ut_logging();
let server_tls = Arc::new(TlsOption {
mode: servers::tls::TlsMode::Prefer,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server.key".to_owned(),
});
let client_tls = false;
do_simple_query(server_tls, client_tls).await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_server_secure_require_client_plain() -> Result<()> {
common_telemetry::init_default_ut_logging();
let server_tls = Arc::new(TlsOption {
mode: servers::tls::TlsMode::Require,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server.key".to_owned(),
});
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false, server_tls)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
let server_port = server_addr.port();
let r = create_plain_connection(server_port, false).await;
assert!(r.is_err());
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_server_secure_require_client_secure() -> Result<()> {
common_telemetry::init_default_ut_logging();
let server_tls = Arc::new(TlsOption {
mode: servers::tls::TlsMode::Require,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server.key".to_owned(),
});
let client_tls = true;
do_simple_query(server_tls, client_tls).await?;
Ok(())
}
async fn do_simple_query(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false, server_tls)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
let server_port = server_addr.port();
if !client_tls {
let client = create_plain_connection(server_port, false).await.unwrap();
let result = client.simple_query("SELECT uint32s FROM numbers").await;
assert!(result.is_ok());
} else {
let client = create_secure_connection(server_port, false).await.unwrap();
let result = client.simple_query("SELECT uint32s FROM numbers").await;
assert!(result.is_ok());
}
Ok(())
}
async fn create_secure_connection(
port: u16,
with_pwd: bool,
) -> std::result::Result<Client, PgError> {
let url = if with_pwd {
format!(
"sslmode=require host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
port
)
} else {
format!("host=127.0.0.1 port={} connect_timeout=2", port)
};
let mut config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(rustls::RootCertStore::empty())
.with_no_client_auth();
config
.dangerous()
.set_certificate_verifier(Arc::new(AcceptAllVerifier {}));
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config);
let (client, conn) = tokio_postgres::connect(&url, tls).await.expect("connect");
tokio::spawn(conn);
Ok(client)
}
async fn create_plain_connection(
port: u16,
with_pwd: bool,
) -> std::result::Result<Client, PgError> {
let url = if with_pwd {
format!(
"host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
@@ -203,3 +313,18 @@ fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> {
fn unwrap_results(resp: &[SimpleQueryMessage]) -> Vec<&str> {
resp.iter().filter_map(|m| resolve_result(m, 0)).collect()
}
struct AcceptAllVerifier {}
impl ServerCertVerifier for AcceptAllVerifier {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: SystemTime,
) -> std::result::Result<ServerCertVerified, Error> {
Ok(ServerCertVerified::assertion())
}
}

View File

@@ -0,0 +1,77 @@
Certificate:
Data:
Version: 3 (0x2)
Serial Number:
1e:a1:44:88:27:3d:5c:c8:ff:ef:06:2e:da:21:05:29:30:a5:ce:2c
Signature Algorithm: sha256WithRSAEncryption
Issuer: CN = localhost
Validity
Not Before: Oct 11 07:36:01 2022 GMT
Not After : Oct 8 07:36:01 2032 GMT
Subject: CN = localhost
Subject Public Key Info:
Public Key Algorithm: rsaEncryption
RSA Public-Key: (2048 bit)
Modulus:
00:d5:b0:29:38:63:13:5e:1e:1d:ae:1f:47:88:b4:
44:96:21:d8:d7:03:a3:d8:f9:03:2f:4e:79:66:e6:
db:19:55:1d:85:9b:f1:78:2d:87:f3:72:91:13:dc:
ff:00:cb:ab:fd:a1:c8:3a:56:26:e3:88:1d:ec:98:
4a:af:eb:f9:60:80:27:e1:06:ba:c0:0d:c3:09:0e:
fe:d8:86:1e:25:b4:04:62:a5:75:46:8e:11:e8:61:
59:aa:97:17:ea:c7:4c:c6:13:8c:6d:54:2a:b9:78:
86:54:a9:6f:d6:31:96:c6:41:76:a3:c7:67:40:6f:
f2:1a:4c:0d:77:05:bb:3d:0b:16:f8:c7:de:6c:de:
7b:2e:b6:29:85:4b:a8:36:d3:f2:84:75:e0:85:17:
ce:22:84:4b:94:02:17:8a:36:2b:13:ee:2f:aa:55:
6b:ff:8b:df:d3:e0:23:8d:fd:c3:f8:e2:c8:a7:d5:
76:a6:73:7d:a8:5f:6a:49:02:78:a2:c5:66:14:ee:
86:50:3b:d1:67:7f:1b:0c:27:0d:84:ec:44:0d:39:
08:ba:69:65:e0:35:a4:67:aa:19:e7:fe:0e:4b:9f:
23:1e:4e:38:ed:d7:93:57:6e:94:31:05:d3:ae:f7:
6c:01:3c:30:69:19:f4:7b:b5:48:95:71:c9:9c:30:
43:9d
Exponent: 65537 (0x10001)
X509v3 extensions:
X509v3 Subject Key Identifier:
8E:81:0B:60:B1:F9:7D:D8:64:91:BB:30:86:E5:3D:CD:B7:82:D8:31
X509v3 Authority Key Identifier:
keyid:8E:81:0B:60:B1:F9:7D:D8:64:91:BB:30:86:E5:3D:CD:B7:82:D8:31
X509v3 Basic Constraints: critical
CA:TRUE
Signature Algorithm: sha256WithRSAEncryption
6c:ae:ee:3e:e3:d4:5d:29:37:62:b0:32:ce:a4:36:c7:25:b4:
6a:9f:ba:b4:f0:2f:0a:96:2f:dc:6d:df:7d:92:e7:f0:ee:f7:
de:44:9d:52:36:ff:0c:98:ef:8b:7f:27:df:6e:fe:64:11:7c:
01:5d:7f:c8:73:a3:24:24:ba:81:fd:a8:ae:28:4f:93:bb:92:
ff:86:d6:48:a2:ca:a5:1f:ea:1c:0d:02:22:e8:71:23:27:22:
4f:0f:37:58:9a:d9:fd:70:c5:4c:93:7d:47:1c:b6:ea:1b:4f:
4e:7c:eb:9d:9a:d3:28:78:67:27:e9:b1:ea:f6:93:68:76:e5:
2e:52:c6:29:91:ba:0a:96:2e:14:33:69:35:d7:b5:e0:c0:ef:
05:77:09:9b:a1:cc:7b:b2:f0:6a:cb:5c:5f:a1:27:69:b0:2c:
6e:93:eb:37:98:cd:97:8d:9e:78:a8:f5:99:12:66:86:48:cf:
b2:e0:68:6f:77:98:06:13:24:55:d1:c3:80:1d:59:53:1f:44:
85:bc:5d:29:aa:2a:a1:06:17:6b:e7:2b:11:0b:fd:e3:f8:88:
89:32:57:a3:70:f7:1b:6c:c1:66:c7:3c:a4:2d:e8:5f:00:1c:
55:2f:72:ed:d4:3a:3f:d0:95:de:6c:a4:96:6e:b4:63:0e:80:
08:b2:25:d5
-----BEGIN CERTIFICATE-----
MIIDCTCCAfGgAwIBAgIUHqFEiCc9XMj/7wYu2iEFKTClziwwDQYJKoZIhvcNAQEL
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMTAxMTA3MzYwMVoXDTMyMTAw
ODA3MzYwMVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
AAOCAQ8AMIIBCgKCAQEA1bApOGMTXh4drh9HiLREliHY1wOj2PkDL055ZubbGVUd
hZvxeC2H83KRE9z/AMur/aHIOlYm44gd7JhKr+v5YIAn4Qa6wA3DCQ7+2IYeJbQE
YqV1Ro4R6GFZqpcX6sdMxhOMbVQquXiGVKlv1jGWxkF2o8dnQG/yGkwNdwW7PQsW
+MfebN57LrYphUuoNtPyhHXghRfOIoRLlAIXijYrE+4vqlVr/4vf0+Ajjf3D+OLI
p9V2pnN9qF9qSQJ4osVmFO6GUDvRZ38bDCcNhOxEDTkIumll4DWkZ6oZ5/4OS58j
Hk447deTV26UMQXTrvdsATwwaRn0e7VIlXHJnDBDnQIDAQABo1MwUTAdBgNVHQ4E
FgQUjoELYLH5fdhkkbswhuU9zbeC2DEwHwYDVR0jBBgwFoAUjoELYLH5fdhkkbsw
huU9zbeC2DEwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbK7u
PuPUXSk3YrAyzqQ2xyW0ap+6tPAvCpYv3G3ffZLn8O733kSdUjb/DJjvi38n327+
ZBF8AV1/yHOjJCS6gf2orihPk7uS/4bWSKLKpR/qHA0CIuhxIyciTw83WJrZ/XDF
TJN9Rxy26htPTnzrnZrTKHhnJ+mx6vaTaHblLlLGKZG6CpYuFDNpNde14MDvBXcJ
m6HMe7LwastcX6EnabAsbpPrN5jNl42eeKj1mRJmhkjPsuBob3eYBhMkVdHDgB1Z
Ux9EhbxdKaoqoQYXa+crEQv94/iIiTJXo3D3G2zBZsc8pC3oXwAcVS9y7dQ6P9CV
3myklm60Yw6ACLIl1Q==
-----END CERTIFICATE-----

View File

@@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQDVsCk4YxNeHh2u
H0eItESWIdjXA6PY+QMvTnlm5tsZVR2Fm/F4LYfzcpET3P8Ay6v9ocg6VibjiB3s
mEqv6/lggCfhBrrADcMJDv7Yhh4ltARipXVGjhHoYVmqlxfqx0zGE4xtVCq5eIZU
qW/WMZbGQXajx2dAb/IaTA13Bbs9Cxb4x95s3nsutimFS6g20/KEdeCFF84ihEuU
AheKNisT7i+qVWv/i9/T4CON/cP44sin1Xamc32oX2pJAniixWYU7oZQO9FnfxsM
Jw2E7EQNOQi6aWXgNaRnqhnn/g5LnyMeTjjt15NXbpQxBdOu92wBPDBpGfR7tUiV
ccmcMEOdAgMBAAECggEBAMMCIJv0zpf1o+Bja0S2PmFEQj72c3Buzxk85E2kIA7e
PjLQPW0PICJrSzp1U8HGHQ85tSCHvrWmYqin0oD5OHt4eOxC1+qspHB/3tJ6ksiV
n+rmVEAvJuiK7ulfOdRoTQf2jxC23saj1vMsLYOrfY0v8LVGJFQJ1UdqYF9eO6FX
8i6eQekV0n8u+DMUysYXfePDXEwpunKrlZwZtThgBY31gAIOdNo/FOAFe1yBJdPl
rUFZes1IrE0c4CNxodajuRNCjtNWoX8TK1cXQVUpPprdFLBcYG2P9mPZ7SkZWJc7
rkyPX6Wkb7q3laUCBxuKL1iOJIwaVBYaKfv4HS7VuYECgYEA9H7VB8+whWx2cTFb
9oYbcaU3HtbKRh6KQP8eB4IWeKV/c/ceWVAxtU9Hx2QU1zZ2fLl+KkaOGeECNNqD
BP1O5qk2qmkjJcP4kzh1K+p7zkqAkrhHqB36y/gwptB8v7JbCchQq9cnBeYsXNIa
j13KvteprRSnanKu18d2aC43cNMCgYEA3746ITtqy1g6AQ0Q/MXN/axsXixKfVjf
kgN/lpjy6oeoEIWKqiNrOQpwy4NeBo6ZN+cwjUUr9SY/BKsZqMGErO8Xuu+QtJYD
ioW/My9rTrTElbpsLpSvZDLc9IRepV4k+5PpXTIRBqp7Q3BZnTjbRMc8x/owG23G
eXnfVKlWM88CgYEA5HBQuMCrzK3/qFkW9Kpun+tfKfhD++nzATGcrCU2u7jd8cr1
1zsfhqkxhrIS6tYfNP/XSsarZLCgcCOuAQ5wFwIJaoVbaqDE80Dv8X1f+eoQYYW+
peyE9OjLBEGOHUoW13gLL9ORyWg7EOraGBPpKBC2n1nJ5qKKjF/4WPS9pjMCgYEA
3UuUyxGtivn0RN3bk2dBWkmT1YERG/EvD4gORbF5caZDADRU9fqaLoy5C1EfSnT3
7mbnipKD67CsW72vX04oH7NLUUVpZnOJhRTMC6A3Dl2UolMEdP3yi7QS/nV99ymq
gnnFMrw2QtWTnRweRnbZyKkW4OP/eOGWkMeNsHrcG9kCgYEAz/09cKumk349AIXV
g6Jw64gCTjWh157wnD3ZSPPEcr/09/fZwf1W0gkY/tbCVrVPJHWb3K5t2nRXjLlz
HMnQXmcMxMlY3Ufvm2H3ov1ODPKwpcBWUZqnpFTZX7rC58lO/wvgiKpgtHA3pDdw
oYDaaozVP4EnnByxhmHaM7ce07U=
-----END PRIVATE KEY-----