mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-04 05:50:38 +00:00
TLS for postgres_backend and proxy
Add TLS support to `postgres_backend`. Implement this support in `proxy`. Other applications must opt-in and provide a `rustls::ServerConfig`.
This commit is contained in:
47
Cargo.lock
generated
47
Cargo.lock
generated
@@ -1443,6 +1443,7 @@ dependencies = [
|
||||
"hex",
|
||||
"md5",
|
||||
"rand",
|
||||
"rustls",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
@@ -1679,6 +1680,29 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.19.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7"
|
||||
dependencies = [
|
||||
"base64 0.13.0",
|
||||
"log",
|
||||
"ring",
|
||||
"sct",
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-split"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9d63f11490b4d8d45a362e171d7fe4a9ef154770a339e696a05eb354bc36837"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.5"
|
||||
@@ -1716,6 +1740,16 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
|
||||
|
||||
[[package]]
|
||||
name = "sct"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b362b83898e0e69f38515b82ee15aa80636befe47c3b6d3d89a911e78fc228ce"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.3.1"
|
||||
@@ -2463,6 +2497,16 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki"
|
||||
version = "0.21.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "3.1.1"
|
||||
@@ -2595,10 +2639,13 @@ dependencies = [
|
||||
"postgres",
|
||||
"rand",
|
||||
"routerify",
|
||||
"rustls",
|
||||
"rustls-split",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"webpki",
|
||||
"workspace_hack",
|
||||
"zenith_metrics",
|
||||
]
|
||||
|
||||
@@ -176,7 +176,7 @@ fn page_service_conn_main(
|
||||
}
|
||||
|
||||
let mut conn_handler = PageServerHandler::new(conf, auth);
|
||||
let pgbackend = PostgresBackend::new(socket, auth_type)?;
|
||||
let pgbackend = PostgresBackend::new(socket, auth_type, None)?;
|
||||
pgbackend.run(&mut conn_handler)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,5 +17,6 @@ serde_json = "1"
|
||||
tokio = { version = "1.7.1", features = ["full"] }
|
||||
tokio-postgres = "0.7.2"
|
||||
clap = "2.33.0"
|
||||
rustls = "0.19.1"
|
||||
|
||||
zenith_utils = { path = "../zenith_utils" }
|
||||
|
||||
@@ -8,13 +8,15 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::{SocketAddr, TcpListener},
|
||||
sync::{mpsc, Mutex},
|
||||
sync::{mpsc, Arc, Mutex},
|
||||
thread,
|
||||
};
|
||||
|
||||
use clap::{App, Arg};
|
||||
use anyhow::{anyhow, bail, ensure, Context};
|
||||
use clap::{App, Arg, ArgMatches};
|
||||
|
||||
use cplane_api::DatabaseInfo;
|
||||
use rustls::{internal::pemfile, NoClientAuth, ProtocolVersion, ServerConfig};
|
||||
|
||||
mod cplane_api;
|
||||
mod mgmt;
|
||||
@@ -33,6 +35,8 @@ pub struct ProxyConf {
|
||||
|
||||
/// control plane address where we would check auth.
|
||||
pub cplane_address: SocketAddr,
|
||||
|
||||
pub ssl_config: Option<Arc<ServerConfig>>,
|
||||
}
|
||||
|
||||
pub struct ProxyState {
|
||||
@@ -40,6 +44,38 @@ pub struct ProxyState {
|
||||
pub waiters: Mutex<HashMap<String, mpsc::Sender<anyhow::Result<DatabaseInfo>>>>,
|
||||
}
|
||||
|
||||
fn configure_ssl(arg_matches: &ArgMatches) -> anyhow::Result<Option<Arc<ServerConfig>>> {
|
||||
let (key_path, cert_path) = match (
|
||||
arg_matches.value_of("ssl-key"),
|
||||
arg_matches.value_of("ssl-cert"),
|
||||
) {
|
||||
(Some(key_path), Some(cert_path)) => (key_path, cert_path),
|
||||
(None, None) => return Ok(None),
|
||||
_ => bail!("either both or neither ssl-key and ssl-cert must be specified"),
|
||||
};
|
||||
|
||||
let key = {
|
||||
let key_bytes = std::fs::read(key_path).context("SSL key file")?;
|
||||
let mut keys = pemfile::rsa_private_keys(&mut &key_bytes[..])
|
||||
.or_else(|_| pemfile::pkcs8_private_keys(&mut &key_bytes[..]))
|
||||
.map_err(|_| anyhow!("couldn't read TLS keys"))?;
|
||||
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
||||
keys.pop().unwrap()
|
||||
};
|
||||
|
||||
let cert_chain = {
|
||||
let cert_chain_bytes = std::fs::read(cert_path).context("SSL cert file")?;
|
||||
pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.map_err(|_| anyhow!("couldn't read TLS certificates"))?
|
||||
};
|
||||
|
||||
let mut config = ServerConfig::new(NoClientAuth::new());
|
||||
config.set_single_cert(cert_chain, key)?;
|
||||
config.versions = vec![ProtocolVersion::TLSv1_3];
|
||||
|
||||
Ok(Some(Arc::new(config)))
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let arg_matches = App::new("Zenith proxy/router")
|
||||
.arg(
|
||||
@@ -66,6 +102,20 @@ fn main() -> anyhow::Result<()> {
|
||||
.help("redirect unauthenticated users to given uri")
|
||||
.default_value("http://localhost:3000/psql_session/"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("ssl-key")
|
||||
.short("k")
|
||||
.long("ssl-key")
|
||||
.takes_value(true)
|
||||
.help("path to SSL key for client postgres connections"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("ssl-cert")
|
||||
.short("c")
|
||||
.long("ssl-cert")
|
||||
.takes_value(true)
|
||||
.help("path to SSL cert for client postgres connections"),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
let conf = ProxyConf {
|
||||
@@ -73,6 +123,7 @@ fn main() -> anyhow::Result<()> {
|
||||
mgmt_address: arg_matches.value_of("mgmt").unwrap().parse()?,
|
||||
redirect_uri: arg_matches.value_of("uri").unwrap().parse()?,
|
||||
cplane_address: "127.0.0.1:3000".parse()?,
|
||||
ssl_config: configure_ssl(&arg_matches)?,
|
||||
};
|
||||
let state = ProxyState {
|
||||
conf,
|
||||
|
||||
@@ -34,7 +34,7 @@ pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow:
|
||||
|
||||
pub fn mgmt_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> {
|
||||
let mut conn_handler = MgmtHandler { state };
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
|
||||
pgbackend.run(&mut conn_handler)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@ use anyhow::bail;
|
||||
use tokio_postgres::NoTls;
|
||||
|
||||
use rand::Rng;
|
||||
use std::sync::mpsc::channel;
|
||||
use std::thread;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use std::io::Write;
|
||||
use std::{io, sync::mpsc::channel, thread};
|
||||
use zenith_utils::postgres_backend::Stream;
|
||||
use zenith_utils::postgres_backend::{PostgresBackend, ProtoState};
|
||||
use zenith_utils::pq_proto::*;
|
||||
use zenith_utils::sock_split::{ReadStream, WriteStream};
|
||||
use zenith_utils::{postgres_backend, pq_proto::BeMessage};
|
||||
|
||||
///
|
||||
@@ -59,7 +60,11 @@ pub fn proxy_conn_main(
|
||||
cplane: CPlaneApi::new(&state.conf.cplane_address),
|
||||
user: "".into(),
|
||||
database: "".into(),
|
||||
pgb: PostgresBackend::new(socket, postgres_backend::AuthType::MD5)?,
|
||||
pgb: PostgresBackend::new(
|
||||
socket,
|
||||
postgres_backend::AuthType::MD5,
|
||||
state.conf.ssl_config.clone(),
|
||||
)?,
|
||||
md5_salt: [0u8; 4],
|
||||
psql_session_id: "".into(),
|
||||
};
|
||||
@@ -75,17 +80,7 @@ pub fn proxy_conn_main(
|
||||
conn.handle_new_user()?
|
||||
};
|
||||
|
||||
// ok, proxy pass user connection to database_uri
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let _ = runtime.block_on(proxy_pass(conn.pgb, db_info))?;
|
||||
|
||||
println!("proxy_conn_main done;");
|
||||
|
||||
Ok(())
|
||||
proxy_pass(conn.pgb, db_info)
|
||||
}
|
||||
|
||||
impl ProxyConnection {
|
||||
@@ -94,6 +89,7 @@ impl ProxyConnection {
|
||||
}
|
||||
|
||||
fn handle_startup(&mut self) -> anyhow::Result<()> {
|
||||
let mut encrypted = false;
|
||||
loop {
|
||||
let msg = self.pgb.read_message()?;
|
||||
println!("got message {:?}", msg);
|
||||
@@ -102,11 +98,29 @@ impl ProxyConnection {
|
||||
println!("got startup message {:?}", m);
|
||||
|
||||
match m.kind {
|
||||
StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => {
|
||||
StartupRequestCode::NegotiateGss => {
|
||||
self.pgb
|
||||
.write_message(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
StartupRequestCode::NegotiateSsl => {
|
||||
println!("SSL requested");
|
||||
self.pgb.write_message(&BeMessage::Negotiate)?;
|
||||
if self.pgb.tls_config.is_some() {
|
||||
self.pgb
|
||||
.write_message(&BeMessage::EncryptionResponse(true))?;
|
||||
self.pgb.start_tls()?;
|
||||
encrypted = true;
|
||||
} else {
|
||||
self.pgb
|
||||
.write_message(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
}
|
||||
StartupRequestCode::Normal => {
|
||||
if self.state.conf.ssl_config.is_some() && !encrypted {
|
||||
self.pgb.write_message(&BeMessage::ErrorResponse(
|
||||
"must connect with TLS".to_string(),
|
||||
))?;
|
||||
bail!("client did not connect with TLS");
|
||||
}
|
||||
self.user = m
|
||||
.params
|
||||
.get("user")
|
||||
@@ -226,31 +240,52 @@ databases without opening the browser.
|
||||
}
|
||||
}
|
||||
|
||||
async fn proxy_pass(pgb: PostgresBackend, db_info: DatabaseInfo) -> anyhow::Result<()> {
|
||||
/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message
|
||||
async fn connect_to_db(db_info: DatabaseInfo) -> anyhow::Result<tokio::net::TcpStream> {
|
||||
let mut socket = tokio::net::TcpStream::connect(db_info.socket_addr()).await?;
|
||||
let config = db_info.conn_string().parse::<tokio_postgres::Config>()?;
|
||||
let _ = config.connect_raw(&mut socket, NoTls).await?;
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
println!("Connected to pg, proxying");
|
||||
/// Concurrently proxy both directions of the client and server connections
|
||||
fn proxy(
|
||||
client_read: ReadStream,
|
||||
client_write: WriteStream,
|
||||
server_read: ReadStream,
|
||||
server_write: WriteStream,
|
||||
) -> anyhow::Result<()> {
|
||||
fn do_proxy(mut reader: ReadStream, mut writer: WriteStream) -> io::Result<()> {
|
||||
std::io::copy(&mut reader, &mut writer)?;
|
||||
writer.flush()?;
|
||||
writer.shutdown(std::net::Shutdown::Both)
|
||||
}
|
||||
|
||||
let incoming_std = pgb.into_stream();
|
||||
incoming_std.set_nonblocking(true)?;
|
||||
let mut incoming_conn = tokio::net::TcpStream::from_std(incoming_std)?;
|
||||
let client_to_server_jh = thread::spawn(move || do_proxy(client_read, server_write));
|
||||
|
||||
let (mut ri, mut wi) = incoming_conn.split();
|
||||
let (mut ro, mut wo) = socket.split();
|
||||
|
||||
let client_to_server = async {
|
||||
tokio::io::copy(&mut ri, &mut wo).await?;
|
||||
wo.shutdown().await
|
||||
};
|
||||
|
||||
let server_to_client = async {
|
||||
tokio::io::copy(&mut ro, &mut wi).await?;
|
||||
wi.shutdown().await
|
||||
};
|
||||
|
||||
tokio::try_join!(client_to_server, server_to_client)?;
|
||||
let res1 = do_proxy(server_read, client_write);
|
||||
let res2 = client_to_server_jh.join().unwrap();
|
||||
res1?;
|
||||
res2?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Proxy a client connection to a postgres database
|
||||
fn proxy_pass(pgb: PostgresBackend, db_info: DatabaseInfo) -> anyhow::Result<()> {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread().build()?;
|
||||
let db_stream = runtime.block_on(connect_to_db(db_info))?;
|
||||
let db_stream = db_stream.into_std()?;
|
||||
db_stream.set_nonblocking(false)?;
|
||||
|
||||
let db_stream = zenith_utils::sock_split::BidiStream::from_tcp(db_stream);
|
||||
let (db_read, db_write) = db_stream.split();
|
||||
|
||||
let stream = match pgb.into_stream() {
|
||||
Stream::Bidirectional(bidi_stream) => bidi_stream,
|
||||
_ => bail!("invalid stream"),
|
||||
};
|
||||
|
||||
let (client_read, client_write) = stream.split();
|
||||
proxy(client_read, client_write, db_read, db_write)
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ fn request_callback(conf: WalAcceptorConf, timelineid: ZTimelineId, tenantid: ZT
|
||||
|
||||
impl<'pg> ReceiveWalConn<'pg> {
|
||||
pub fn new(pg: &'pg mut PostgresBackend) -> Result<ReceiveWalConn<'pg>> {
|
||||
let peer_addr = pg.get_peer_addr()?;
|
||||
let peer_addr = pg.get_peer_addr().clone();
|
||||
Ok(ReceiveWalConn {
|
||||
pg_backend: pg,
|
||||
peer_addr,
|
||||
|
||||
@@ -11,8 +11,7 @@ use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::min;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, Read, Seek, SeekFrom};
|
||||
use std::net::TcpStream;
|
||||
use std::io::{Read, Seek, SeekFrom};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::thread::sleep;
|
||||
@@ -22,6 +21,7 @@ use zenith_utils::bin_ser::BeSer;
|
||||
use zenith_utils::lsn::Lsn;
|
||||
use zenith_utils::postgres_backend::PostgresBackend;
|
||||
use zenith_utils::pq_proto::{BeMessage, FeMessage, XLogDataBody};
|
||||
use zenith_utils::sock_split::ReadStream;
|
||||
|
||||
pub const END_REPLICATION_MARKER: Lsn = Lsn::MAX;
|
||||
|
||||
@@ -49,7 +49,7 @@ impl HotStandbyFeedback {
|
||||
pub struct ReplicationConn {
|
||||
/// This is an `Option` because we will spawn a background thread that will
|
||||
/// `take` it from us.
|
||||
stream_in: Option<BufReader<TcpStream>>,
|
||||
stream_in: Option<ReadStream>,
|
||||
}
|
||||
|
||||
// TODO: move this to crate::timeline when there's more users
|
||||
|
||||
@@ -41,7 +41,7 @@ fn handle_socket(socket: TcpStream, conf: WalAcceptorConf) -> Result<()> {
|
||||
socket.set_nodelay(true)?;
|
||||
|
||||
let mut conn_handler = SendWalHandler::new(conf);
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
|
||||
// libpq replication protocol between wal_acceptor and replicas/pagers
|
||||
pgbackend.run(&mut conn_handler)?;
|
||||
|
||||
|
||||
@@ -25,6 +25,10 @@ rand = "0.8.3"
|
||||
jsonwebtoken = "7"
|
||||
hex = { version = "0.4.3", features = ["serde"] }
|
||||
|
||||
rustls = "0.19.1"
|
||||
rustls-split = "0.2.1"
|
||||
|
||||
[dev-dependencies]
|
||||
hex-literal = "0.3"
|
||||
bytes = "1.0"
|
||||
webpki = "0.21"
|
||||
|
||||
@@ -23,3 +23,6 @@ pub mod auth;
|
||||
pub mod zid;
|
||||
// http endpoint utils
|
||||
pub mod http;
|
||||
|
||||
// socket splitting utils
|
||||
pub mod sock_split;
|
||||
|
||||
@@ -4,14 +4,16 @@
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use crate::pq_proto::{BeMessage, FeMessage, FeStartupMessage, StartupRequestCode};
|
||||
use anyhow::{bail, ensure, Result};
|
||||
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
|
||||
use anyhow::{anyhow, bail, ensure, Result};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use log::*;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::{self, BufReader, Write};
|
||||
use std::io::{self, Write};
|
||||
use std::net::{Shutdown, SocketAddr, TcpStream};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub trait Handler {
|
||||
/// Handle single query.
|
||||
@@ -45,6 +47,7 @@ pub trait Handler {
|
||||
#[derive(Clone, Copy, PartialEq, PartialOrd)]
|
||||
pub enum ProtoState {
|
||||
Initialization,
|
||||
Encrypted,
|
||||
Authentication,
|
||||
Established,
|
||||
}
|
||||
@@ -76,12 +79,40 @@ pub enum ProcessMsgResult {
|
||||
Break,
|
||||
}
|
||||
|
||||
/// Always-writeable sock_split stream.
|
||||
/// May not be readable. See [`PostgresBackend::take_stream_in`]
|
||||
pub enum Stream {
|
||||
Bidirectional(BidiStream),
|
||||
WriteOnly(WriteStream),
|
||||
}
|
||||
|
||||
impl Stream {
|
||||
fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how),
|
||||
Self::WriteOnly(write_stream) => write_stream.shutdown(how),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for Stream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Bidirectional(bidi_stream) => bidi_stream.write(buf),
|
||||
Self::WriteOnly(write_stream) => write_stream.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Bidirectional(bidi_stream) => bidi_stream.flush(),
|
||||
Self::WriteOnly(write_stream) => write_stream.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
// replication.rs wants to handle reading on its own in separate thread, so
|
||||
// wrap in Option to be able to take and transfer the BufReader. Ugly, but I
|
||||
// have no better ideas.
|
||||
stream_in: Option<BufReader<TcpStream>>,
|
||||
stream_out: TcpStream,
|
||||
stream: Option<Stream>,
|
||||
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
|
||||
buf_out: BytesMut,
|
||||
|
||||
@@ -89,6 +120,9 @@ pub struct PostgresBackend {
|
||||
|
||||
md5_salt: [u8; 4],
|
||||
auth_type: AuthType,
|
||||
|
||||
peer_addr: SocketAddr,
|
||||
pub tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
}
|
||||
|
||||
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
|
||||
@@ -102,47 +136,52 @@ pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
|
||||
}
|
||||
|
||||
impl PostgresBackend {
|
||||
pub fn new(socket: TcpStream, auth_type: AuthType) -> io::Result<Self> {
|
||||
let mut pb = PostgresBackend {
|
||||
stream_in: None,
|
||||
stream_out: socket,
|
||||
pub fn new(
|
||||
socket: TcpStream,
|
||||
auth_type: AuthType,
|
||||
tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
) -> io::Result<Self> {
|
||||
let peer_addr = socket.peer_addr()?;
|
||||
Ok(Self {
|
||||
stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))),
|
||||
buf_out: BytesMut::with_capacity(10 * 1024),
|
||||
state: ProtoState::Initialization,
|
||||
md5_salt: [0u8; 4],
|
||||
auth_type,
|
||||
};
|
||||
|
||||
// if socket cloning fails, report the error and bail out
|
||||
pb.stream_in = match pb.stream_out.try_clone() {
|
||||
Ok(read_sock) => Some(BufReader::new(read_sock)),
|
||||
Err(error) => {
|
||||
let errmsg = format!("{}", error);
|
||||
let _res = pb.write_message_noflush(&BeMessage::ErrorResponse(errmsg));
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(pb)
|
||||
tls_config,
|
||||
peer_addr,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_stream(self) -> TcpStream {
|
||||
self.stream_out
|
||||
pub fn into_stream(self) -> Stream {
|
||||
self.stream.unwrap()
|
||||
}
|
||||
|
||||
/// Get direct reference (into the Option) to the read stream.
|
||||
fn get_stream_in(&mut self) -> Result<&mut BufReader<TcpStream>> {
|
||||
match self.stream_in {
|
||||
Some(ref mut stream_in) => Ok(stream_in),
|
||||
None => bail!("stream_in was taken"),
|
||||
fn get_stream_in(&mut self) -> Result<&mut BidiStream> {
|
||||
match &mut self.stream {
|
||||
Some(Stream::Bidirectional(stream)) => Ok(stream),
|
||||
_ => Err(anyhow!("reader taken")),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_peer_addr(&self) -> Result<SocketAddr> {
|
||||
Ok(self.stream_out.peer_addr()?)
|
||||
pub fn get_peer_addr(&self) -> &SocketAddr {
|
||||
&self.peer_addr
|
||||
}
|
||||
|
||||
pub fn take_stream_in(&mut self) -> Option<BufReader<TcpStream>> {
|
||||
self.stream_in.take()
|
||||
pub fn take_stream_in(&mut self) -> Option<ReadStream> {
|
||||
let stream = self.stream.take();
|
||||
match stream {
|
||||
Some(Stream::Bidirectional(bidi_stream)) => {
|
||||
let (read, write) = bidi_stream.split();
|
||||
self.stream = Some(Stream::WriteOnly(write));
|
||||
Some(read)
|
||||
}
|
||||
stream => {
|
||||
self.stream = stream;
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read full message or return None if connection is closed.
|
||||
@@ -151,7 +190,7 @@ impl PostgresBackend {
|
||||
|
||||
use ProtoState::*;
|
||||
match state {
|
||||
Initialization => FeStartupMessage::read(stream),
|
||||
Initialization | Encrypted => FeStartupMessage::read(stream),
|
||||
Authentication | Established => FeMessage::read(stream),
|
||||
}
|
||||
}
|
||||
@@ -164,7 +203,8 @@ impl PostgresBackend {
|
||||
|
||||
/// Flush output buffer into the socket.
|
||||
pub fn flush(&mut self) -> io::Result<&mut Self> {
|
||||
self.stream_out.write_all(&self.buf_out)?;
|
||||
let stream = self.stream.as_mut().unwrap();
|
||||
stream.write_all(&self.buf_out)?;
|
||||
self.buf_out.clear();
|
||||
Ok(self)
|
||||
}
|
||||
@@ -178,13 +218,14 @@ impl PostgresBackend {
|
||||
// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
pub fn run(mut self, handler: &mut impl Handler) -> Result<()> {
|
||||
let ret = self.run_message_loop(handler);
|
||||
let _res = self.stream_out.shutdown(Shutdown::Both);
|
||||
if let Some(stream) = self.stream.as_mut() {
|
||||
let _ = stream.shutdown(Shutdown::Both);
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<()> {
|
||||
let peer_addr = self.stream_out.peer_addr()?;
|
||||
trace!("postgres backend to {:?} started", peer_addr);
|
||||
trace!("postgres backend to {:?} started", self.peer_addr);
|
||||
|
||||
let mut unnamed_query_string = Bytes::new();
|
||||
|
||||
@@ -197,10 +238,24 @@ impl PostgresBackend {
|
||||
}
|
||||
}
|
||||
|
||||
trace!("postgres backend to {:?} exited", peer_addr);
|
||||
trace!("postgres backend to {:?} exited", self.peer_addr);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn start_tls(&mut self) -> anyhow::Result<()> {
|
||||
match self.stream.take() {
|
||||
Some(Stream::Bidirectional(bidi_stream)) => {
|
||||
let session = rustls::ServerSession::new(&self.tls_config.clone().unwrap());
|
||||
self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(session)?));
|
||||
Ok(())
|
||||
}
|
||||
stream => {
|
||||
self.stream = stream;
|
||||
bail!("can't start TLs without bidi stream");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
@@ -224,11 +279,30 @@ impl PostgresBackend {
|
||||
trace!("got startup message {:?}", m);
|
||||
|
||||
match m.kind {
|
||||
StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => {
|
||||
StartupRequestCode::NegotiateSsl => {
|
||||
info!("SSL requested");
|
||||
self.write_message(&BeMessage::Negotiate)?;
|
||||
|
||||
if self.tls_config.is_some() {
|
||||
self.write_message(&BeMessage::EncryptionResponse(true))?;
|
||||
self.start_tls()?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
} else {
|
||||
self.write_message(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
}
|
||||
StartupRequestCode::NegotiateGss => {
|
||||
info!("GSS requested");
|
||||
self.write_message(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
StartupRequestCode::Normal => {
|
||||
if self.tls_config.is_some() && !matches!(self.state, ProtoState::Encrypted)
|
||||
{
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
"must connect with TLS".to_string(),
|
||||
))?;
|
||||
bail!("client did not connect with TLS");
|
||||
}
|
||||
|
||||
// NB: startup() may change self.auth_type -- we are using that in proxy code
|
||||
// to bypass auth for new users.
|
||||
handler.startup(self, &m)?;
|
||||
|
||||
@@ -348,8 +348,8 @@ pub enum BeMessage<'a> {
|
||||
// None means column is NULL
|
||||
DataRow(&'a [Option<&'a [u8]>]),
|
||||
ErrorResponse(String),
|
||||
// see https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.11
|
||||
Negotiate,
|
||||
// single byte - used in response to SSLRequest/GSSENCRequest
|
||||
EncryptionResponse(bool),
|
||||
NoData,
|
||||
ParameterDescription,
|
||||
ParameterStatus,
|
||||
@@ -657,8 +657,9 @@ impl<'a> BeMessage<'a> {
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
}
|
||||
|
||||
BeMessage::Negotiate => {
|
||||
buf.put_u8(b'N');
|
||||
BeMessage::EncryptionResponse(should_negotiate) => {
|
||||
let response = if *should_negotiate { b'Y' } else { b'N' };
|
||||
buf.put_u8(response);
|
||||
}
|
||||
|
||||
BeMessage::ParameterStatus => {
|
||||
|
||||
206
zenith_utils/src/sock_split.rs
Normal file
206
zenith_utils/src/sock_split.rs
Normal file
@@ -0,0 +1,206 @@
|
||||
use std::{
|
||||
io::{self, BufReader, Write},
|
||||
net::{Shutdown, TcpStream},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use rustls::Session;
|
||||
|
||||
/// Wrapper supporting reads of a shared TcpStream.
|
||||
pub struct ArcTcpRead(Arc<TcpStream>);
|
||||
|
||||
impl io::Read for ArcTcpRead {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
(&*self.0).read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for ArcTcpRead {
|
||||
type Target = TcpStream;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.deref()
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper around a TCP Stream supporting buffered reads.
|
||||
pub struct BufStream(BufReader<ArcTcpRead>);
|
||||
|
||||
impl io::Read for BufStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.0.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for BufStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.get_ref().write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.get_ref().flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl BufStream {
|
||||
/// Unwrap into the internal BufReader.
|
||||
fn into_reader(self) -> BufReader<ArcTcpRead> {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying TcpStream.
|
||||
fn get_ref(&self) -> &TcpStream {
|
||||
&*self.0.get_ref().0
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ReadStream {
|
||||
Tcp(BufReader<ArcTcpRead>),
|
||||
Tls(rustls_split::ReadHalf<rustls::ServerSession>),
|
||||
}
|
||||
|
||||
impl io::Read for ReadStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(reader) => reader.read(buf),
|
||||
Self::Tls(read_half) => read_half.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadStream {
|
||||
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.get_ref().shutdown(how),
|
||||
Self::Tls(write_half) => write_half.shutdown(how),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum WriteStream {
|
||||
Tcp(Arc<TcpStream>),
|
||||
Tls(rustls_split::WriteHalf<rustls::ServerSession>),
|
||||
}
|
||||
|
||||
impl WriteStream {
|
||||
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.shutdown(how),
|
||||
Self::Tls(write_half) => write_half.shutdown(how),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for WriteStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.as_ref().write(buf),
|
||||
Self::Tls(write_half) => write_half.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.as_ref().flush(),
|
||||
Self::Tls(write_half) => write_half.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum BidiStream {
|
||||
Tcp(BufStream),
|
||||
Tls {
|
||||
stream: BufStream,
|
||||
session: rustls::ServerSession,
|
||||
},
|
||||
}
|
||||
|
||||
impl BidiStream {
|
||||
pub fn from_tcp(stream: TcpStream) -> Self {
|
||||
Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream)))))
|
||||
}
|
||||
|
||||
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.get_ref().shutdown(how),
|
||||
Self::Tls {
|
||||
stream: reader,
|
||||
session,
|
||||
} => {
|
||||
if how == Shutdown::Read {
|
||||
reader.get_ref().shutdown(how)
|
||||
} else {
|
||||
session.send_close_notify();
|
||||
let mut stream = rustls::Stream::new(session, reader);
|
||||
let res = stream.flush();
|
||||
reader.get_ref().shutdown(how)?;
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Split the bi-directional stream into two owned read and write halves.
|
||||
pub fn split(self) -> (ReadStream, WriteStream) {
|
||||
match self {
|
||||
Self::Tcp(stream) => {
|
||||
let reader = stream.into_reader();
|
||||
let stream: Arc<TcpStream> = reader.get_ref().0.clone();
|
||||
|
||||
(ReadStream::Tcp(reader), WriteStream::Tcp(stream))
|
||||
}
|
||||
Self::Tls { stream, session } => {
|
||||
let reader = stream.into_reader();
|
||||
let buffer_data = reader.buffer().to_owned();
|
||||
let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192);
|
||||
let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192);
|
||||
|
||||
// TODO would be nice to avoid the Arc here
|
||||
let socket = Arc::try_unwrap(reader.into_inner().0).unwrap();
|
||||
|
||||
let (read_half, write_half) =
|
||||
rustls_split::split(socket, session, read_buf_cfg, write_buf_cfg);
|
||||
(ReadStream::Tls(read_half), WriteStream::Tls(write_half))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_tls(self, mut session: rustls::ServerSession) -> io::Result<Self> {
|
||||
match self {
|
||||
Self::Tcp(mut stream) => {
|
||||
session.complete_io(&mut stream)?;
|
||||
assert!(!session.is_handshaking());
|
||||
Ok(Self::Tls { stream, session })
|
||||
}
|
||||
Self::Tls { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"TLS is already started on this stream",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Read for BidiStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.read(buf),
|
||||
Self::Tls { stream, session } => rustls::Stream::new(session, stream).read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for BidiStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.write(buf),
|
||||
Self::Tls { stream, session } => rustls::Stream::new(session, stream).write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.flush(),
|
||||
Self::Tls { stream, session } => rustls::Stream::new(session, stream).flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
21
zenith_utils/tests/cert.pem
Normal file
21
zenith_utils/tests/cert.pem
Normal file
@@ -0,0 +1,21 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDbjCCAlagAwIBAgIUGHJukXa1bQathgBHC40+A18BsnYwDQYJKoZIhvcNAQEL
|
||||
BQAwYzELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcM
|
||||
DVNhbiBGcmFuY2lzY28xEzARBgNVBAoMCk15IENvbXBhbnkxEjAQBgNVBAMMCWxv
|
||||
Y2FsaG9zdDAgFw0yMTA4MTMxODQyMjBaGA8yMTIxMDcyMDE4NDIyMFowYzELMAkG
|
||||
A1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDVNhbiBGcmFu
|
||||
Y2lzY28xEzARBgNVBAoMCk15IENvbXBhbnkxEjAQBgNVBAMMCWxvY2FsaG9zdDCC
|
||||
ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOI9S+nh8ABMp5jpb7WWfAYr
|
||||
tGJ4C7gi9IPTVIRxSSrt5KglEysrOiKlhan1Ut2e8CCudztdXtCvT8/goJWlmxpF
|
||||
IQkErlCsOdGHeEJ0EZxoU1fMkBAQVf6Rb1JE9ladG2+D1e7yvxmMqfPVuU8lj+kN
|
||||
nESP+I3ESNCtuqgtfcErxu3TuhSzV2slSi5lrYQCwERgCevl6LUNd2mEaYdS4mmJ
|
||||
4RZqc2C4y7JO5wSDjga8GIBHJVo70HRVsvX7eE8r6tMP2HyGyonBitBKAc2QEQIv
|
||||
cLCuMOTtTBlYcMvTmJEOHFKwIJXm0XmQfAWeKFfyK7493fB4Gu+8Dc1xC+IHaTEC
|
||||
AwEAAaMYMBYwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IB
|
||||
AQBjY+g3eF8m8lEWz+QgKp88MhTdtJTsEsSz0GAi58SnEkuyxVOHjKEyjGKJWTtT
|
||||
ICgmEzC85uaS7VBdftoYNmsbvNewGiisDGQRWCjOGM7lTaA4FQPADguexMvXh/nO
|
||||
9PQoTxtp7qwvGWO2mED6LWU6bjT3cL+XgrOwT9sticRTl6/BXV8wAmyxT0DkQ3nJ
|
||||
zbRuTP/G2kE0bRK++67kK0ovopRkX6Dl6di1EFlkAnPBC2d8tdcNTXYhkxZk4O0q
|
||||
GUolwiuWz/dtD3tZ2bx3vqzT7uIFHS4XP6Q3SRNWFTGhuvAc7DPvCZBqxy6odeyQ
|
||||
VxBgJtq+pNjYYkeaSQVQ+UMU
|
||||
-----END CERTIFICATE-----
|
||||
27
zenith_utils/tests/key.pem
Normal file
27
zenith_utils/tests/key.pem
Normal file
@@ -0,0 +1,27 @@
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEA4j1L6eHwAEynmOlvtZZ8Biu0YngLuCL0g9NUhHFJKu3kqCUT
|
||||
Kys6IqWFqfVS3Z7wIK53O11e0K9Pz+CglaWbGkUhCQSuUKw50Yd4QnQRnGhTV8yQ
|
||||
EBBV/pFvUkT2Vp0bb4PV7vK/GYyp89W5TyWP6Q2cRI/4jcRI0K26qC19wSvG7dO6
|
||||
FLNXayVKLmWthALARGAJ6+XotQ13aYRph1LiaYnhFmpzYLjLsk7nBIOOBrwYgEcl
|
||||
WjvQdFWy9ft4Tyvq0w/YfIbKicGK0EoBzZARAi9wsK4w5O1MGVhwy9OYkQ4cUrAg
|
||||
lebReZB8BZ4oV/Irvj3d8Hga77wNzXEL4gdpMQIDAQABAoIBAQClKycO+zpinZQG
|
||||
GPbLVa/6OVIaSZYUusBUtaaQgrxuMPusnlSeQZLR1JH/APGchvq8gWLe3k3ogPT9
|
||||
yPq0BhF0Xl+928L/dp1HkWWE7oQk8i1Wfiv27lY54iepoltN5KkxAsjfCC3oEz/I
|
||||
mpINbFjiRmN90rYdmd2nLA6H1Z5ntZQm5AcTo3OJZlTVN9eH9TV8f0AQRQgUJsL9
|
||||
75agSmj7euqZOqvvwfpsYzaZEhzMSG2QIcS3WglInbHy8c6ikZSm36J36wgsatMz
|
||||
CBZ6pMNtonRSKvAECQhBGEA73evtnGbLH0EY9KouN4KSHEHob89dGVeeXozksf9x
|
||||
QUE1/yOhAoGBAP818f7vIH6Z3QwWgTMwQsPBW+wNOIbTZrbZaihnz2K9XMu39TV6
|
||||
DWQHMsOlvg2QURZGwqB3jFn4wqZHmt7XYwk553E60kIw4hDvgpkkqmXVwK3kZASQ
|
||||
RRUax3hZ1gCWxpXlRZ1SvHNXjN9KEFwqQbR33XcxzC3TpSp0KYghT9jFAoGBAOLw
|
||||
agejqSF+f/5W1QhEKlM+tSlluo2sn5kKVkM4nNezFukb3pu5oScFjoQQGsoaz5aU
|
||||
kLlxW5h/aSxquhgcuo6I4Ux5dcgNm4QeonCCp+Qycn7tzyoJFL4odT9vYPQa5O9E
|
||||
hD9aSqhBBD1IIOS2T3vcW6VxibKZx1CRMDdRz119AoGALflr1L8DHYteNLVBJRWG
|
||||
kXkdtBJVooQmtr3Hz+uTgngWZWSIOc/45ZIeZPxQlmTvFpI8sWeX0wVrG0U+8vHe
|
||||
F2Vk+hLcmavwrZhX8HqYb6vn/+tq0R+kMj8Wu+mDEawXrh0VQ1gKNsUIzZisBc5e
|
||||
88G8FaLU41SDJniymqFVnvkCgYEA1ou/UfWRwg6b5tIkmKoI8aZJExgPpDzcrYyu
|
||||
POLatLmlIUCt1b9K8V85evTWvtdWBd/yar8WfzeFMO69fGo8nOAfT3NMvJLQwblM
|
||||
jN2Y6A4hXIpq3iyzpYsOPaiImn6KjQHTnSk5h5Pf9CeqoU8SGeEb629JZMYpPqvk
|
||||
T4hSaOkCgYBPaf51oSAstqdj0vxrsFS3EN3D8Fk0xQWt9Ss3ZGFAlTaEq5xoIk4k
|
||||
YfKVDv1S6/vlzbheIIzQ2lzVvG4AW+drQLsmEx5iMKvbNtFAur9kwUFU202Q2dki
|
||||
ZQJ/JvjnPYFKxy+SVlLJ1h9RD9E3dgL/Ai7OUfbmX771vN0IQF7Z6Q==
|
||||
-----END RSA PRIVATE KEY-----
|
||||
224
zenith_utils/tests/ssl_test.rs
Normal file
224
zenith_utils/tests/ssl_test.rs
Normal file
@@ -0,0 +1,224 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io::{Cursor, Read, Write},
|
||||
net::{TcpListener, TcpStream},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use lazy_static::lazy_static;
|
||||
use rustls::Session;
|
||||
|
||||
use zenith_utils::postgres_backend::{AuthType, Handler, PostgresBackend};
|
||||
|
||||
fn make_tcp_pair() -> (TcpStream, TcpStream) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let client_stream = TcpStream::connect(addr).unwrap();
|
||||
let (server_stream, _) = listener.accept().unwrap();
|
||||
(server_stream, client_stream)
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref KEY: rustls::PrivateKey = {
|
||||
let mut cursor = Cursor::new(include_bytes!("key.pem"));
|
||||
rustls::internal::pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone()
|
||||
};
|
||||
static ref CERT: rustls::Certificate = {
|
||||
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
|
||||
rustls::internal::pemfile::certs(&mut cursor).unwrap()[0].clone()
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssl() {
|
||||
let (mut client_sock, server_sock) = make_tcp_pair();
|
||||
|
||||
const QUERY: &[u8] = b"hello world";
|
||||
|
||||
let client_jh = std::thread::spawn(move || {
|
||||
// SSLRequest
|
||||
client_sock.write_u32::<BigEndian>(8).unwrap();
|
||||
client_sock.write_u32::<BigEndian>(80877103).unwrap();
|
||||
|
||||
let ssl_response = client_sock.read_u8().unwrap();
|
||||
assert_eq!(b'Y', ssl_response);
|
||||
|
||||
let mut cfg = rustls::ClientConfig::new();
|
||||
cfg.root_store.add(&CERT).unwrap();
|
||||
let client_config = Arc::new(cfg);
|
||||
|
||||
let dns_name = webpki::DNSNameRef::try_from_ascii_str("localhost").unwrap();
|
||||
let mut session = rustls::ClientSession::new(&client_config, dns_name);
|
||||
|
||||
session.complete_io(&mut client_sock).unwrap();
|
||||
assert!(!session.is_handshaking());
|
||||
|
||||
let mut stream = rustls::Stream::new(&mut session, &mut client_sock);
|
||||
|
||||
// StartupMessage
|
||||
stream.write_u32::<BigEndian>(9).unwrap();
|
||||
stream.write_u32::<BigEndian>(196608).unwrap();
|
||||
stream.write_u8(0).unwrap();
|
||||
stream.flush().unwrap();
|
||||
|
||||
// wait for ReadyForQuery
|
||||
let mut msg_buf = Vec::new();
|
||||
loop {
|
||||
let msg = stream.read_u8().unwrap();
|
||||
let size = stream.read_u32::<BigEndian>().unwrap() - 4;
|
||||
msg_buf.resize(size as usize, 0);
|
||||
stream.read_exact(&mut msg_buf).unwrap();
|
||||
|
||||
if msg == b'Z' {
|
||||
// ReadyForQuery
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Query
|
||||
stream.write_u8(b'Q').unwrap();
|
||||
stream
|
||||
.write_u32::<BigEndian>(4u32 + QUERY.len() as u32)
|
||||
.unwrap();
|
||||
stream.write_all(QUERY).unwrap();
|
||||
stream.flush().unwrap();
|
||||
|
||||
// ReadyForQuery
|
||||
let msg = stream.read_u8().unwrap();
|
||||
assert_eq!(msg, b'Z');
|
||||
});
|
||||
|
||||
struct TestHandler {
|
||||
got_query: bool,
|
||||
}
|
||||
impl Handler for TestHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
query_string: bytes::Bytes,
|
||||
) -> anyhow::Result<()> {
|
||||
self.got_query = query_string.as_ref() == QUERY;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
let mut handler = TestHandler { got_query: false };
|
||||
|
||||
let mut cfg = rustls::ServerConfig::new(rustls::NoClientAuth::new());
|
||||
cfg.set_single_cert(vec![CERT.clone()], KEY.clone())
|
||||
.unwrap();
|
||||
let tls_config = Some(Arc::new(cfg));
|
||||
|
||||
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config).unwrap();
|
||||
pgb.run(&mut handler).unwrap();
|
||||
assert!(handler.got_query);
|
||||
|
||||
client_jh.join().unwrap();
|
||||
|
||||
// TODO consider shutdown behavior
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_ssl() {
|
||||
let (mut client_sock, server_sock) = make_tcp_pair();
|
||||
|
||||
let client_jh = std::thread::spawn(move || {
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
// SSLRequest
|
||||
buf.put_u32(8);
|
||||
buf.put_u32(80877103);
|
||||
client_sock.write_all(&buf).unwrap();
|
||||
buf.clear();
|
||||
|
||||
let ssl_response = client_sock.read_u8().unwrap();
|
||||
assert_eq!(b'N', ssl_response);
|
||||
});
|
||||
|
||||
struct TestHandler;
|
||||
|
||||
impl Handler for TestHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_query_string: bytes::Bytes,
|
||||
) -> anyhow::Result<()> {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
|
||||
let mut handler = TestHandler;
|
||||
|
||||
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None).unwrap();
|
||||
pgb.run(&mut handler).unwrap();
|
||||
|
||||
client_jh.join().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_forces_ssl() {
|
||||
let (mut client_sock, server_sock) = make_tcp_pair();
|
||||
|
||||
let client_jh = std::thread::spawn(move || {
|
||||
// StartupMessage
|
||||
client_sock.write_u32::<BigEndian>(9).unwrap();
|
||||
client_sock.write_u32::<BigEndian>(196608).unwrap();
|
||||
client_sock.write_u8(0).unwrap();
|
||||
client_sock.flush().unwrap();
|
||||
|
||||
// ErrorResponse
|
||||
assert_eq!(client_sock.read_u8().unwrap(), b'E');
|
||||
let len = client_sock.read_u32::<BigEndian>().unwrap() - 4;
|
||||
|
||||
let mut body = vec![0; len as usize];
|
||||
client_sock.read_exact(&mut body).unwrap();
|
||||
let mut body = Bytes::from(body);
|
||||
|
||||
let mut errors = HashMap::new();
|
||||
loop {
|
||||
let field_type = body.get_u8();
|
||||
if field_type == 0u8 {
|
||||
break;
|
||||
}
|
||||
|
||||
let end_idx = body.iter().position(|&b| b == 0u8).unwrap();
|
||||
let mut value = body.split_to(end_idx + 1);
|
||||
assert_eq!(value[end_idx], 0u8);
|
||||
value.truncate(end_idx);
|
||||
let old = errors.insert(field_type, value);
|
||||
assert!(old.is_none());
|
||||
}
|
||||
|
||||
assert!(!body.has_remaining());
|
||||
|
||||
assert_eq!("must connect with TLS", errors.get(&b'M').unwrap());
|
||||
|
||||
// TODO read failure
|
||||
});
|
||||
|
||||
struct TestHandler;
|
||||
impl Handler for TestHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_query_string: bytes::Bytes,
|
||||
) -> anyhow::Result<()> {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
let mut handler = TestHandler;
|
||||
|
||||
let mut cfg = rustls::ServerConfig::new(rustls::NoClientAuth::new());
|
||||
cfg.set_single_cert(vec![CERT.clone()], KEY.clone())
|
||||
.unwrap();
|
||||
let tls_config = Some(Arc::new(cfg));
|
||||
|
||||
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config).unwrap();
|
||||
let res = pgb.run(&mut handler).unwrap_err();
|
||||
assert_eq!("client did not connect with TLS", format!("{}", res));
|
||||
|
||||
client_jh.join().unwrap();
|
||||
|
||||
// TODO consider shutdown behavior
|
||||
}
|
||||
Reference in New Issue
Block a user