From bf45bef2846ecc53d14b1f634039eca88de2c9e7 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Wed, 23 Jun 2021 10:43:24 +0300 Subject: [PATCH] md5 auth for postgres_backend.rs --- Cargo.lock | 11 +++ Cargo.toml | 1 + pageserver/src/page_service.rs | 4 +- proxy/Cargo.toml | 14 ++++ proxy/src/main.rs | 47 ++++++++++++ proxy/src/proxy.rs | 98 +++++++++++++++++++++++++ walkeeper/src/wal_service.rs | 4 +- zenith_utils/Cargo.toml | 1 + zenith_utils/src/postgres_backend.rs | 103 +++++++++++++++++++++++---- zenith_utils/src/pq_proto.rs | 15 +++- 10 files changed, 280 insertions(+), 18 deletions(-) create mode 100644 proxy/Cargo.toml create mode 100644 proxy/src/main.rs create mode 100644 proxy/src/proxy.rs diff --git a/Cargo.lock b/Cargo.lock index efd5e0c23c..c2248c227b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1317,6 +1317,16 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "proxy" +version = "0.1.0" +dependencies = [ + "anyhow", + "bytes", + "md5", + "zenith_utils", +] + [[package]] name = "quick-xml" version = "0.20.0" @@ -2395,6 +2405,7 @@ dependencies = [ "hex-literal", "log", "postgres", + "rand", "serde", "thiserror", "workspace_hack", diff --git a/Cargo.toml b/Cargo.toml index a795cd5abd..d317e9a27d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "postgres_ffi", "zenith_utils", "workspace_hack", + "proxy" ] [profile.release] diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index dcea1c3de8..cffc875300 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -19,8 +19,8 @@ use std::net::TcpListener; use std::str::FromStr; use std::thread; use std::{io, net::TcpStream}; -use zenith_utils::postgres_backend; use zenith_utils::postgres_backend::PostgresBackend; +use zenith_utils::postgres_backend::{self, AuthType}; use zenith_utils::pq_proto::{ BeMessage, FeMessage, RowDescriptor, HELLO_WORLD_ROW, SINGLE_COL_ROWDESC, }; @@ -154,7 +154,7 @@ pub fn thread_main(conf: &'static PageServerConf, listener: TcpListener) -> anyh fn page_service_conn_main(conf: &'static PageServerConf, socket: TcpStream) -> anyhow::Result<()> { let mut conn_handler = PageServerHandler::new(conf); - let mut pgbackend = PostgresBackend::new(socket)?; + let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?; pgbackend.run(&mut conn_handler) } diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml new file mode 100644 index 0000000000..cf9a69a895 --- /dev/null +++ b/proxy/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "proxy" +version = "0.1.0" +authors = ["Stas Kelvich "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "1.0" +bytes = { version = "1.0.1", features = ['serde'] } +md5 = "0.7.0" + +zenith_utils = { path = "../zenith_utils" } diff --git a/proxy/src/main.rs b/proxy/src/main.rs new file mode 100644 index 0000000000..a511a7c8bd --- /dev/null +++ b/proxy/src/main.rs @@ -0,0 +1,47 @@ +/// +/// Postgres protocol proxy/router. +/// +/// This service listens psql port and can check auth via external service +/// (control plane API in our case) and can create new databases and accounts +/// in somewhat transparent manner (again via communication with control plane API). +/// +use std::{ + net::{SocketAddr, TcpListener}, + thread, +}; + +mod cplane_api; +mod proxy; + +pub struct ProxyConf { + /// main entrypoint for users to connect to + pub proxy_address: SocketAddr, + + /// http management endpoint. Upon user account creation control plane + /// will notify us here, so that we can 'unfreeze' user session. + pub mgmt_address: SocketAddr, + + /// control plane address where we check auth and create clusters. + pub cplane_address: SocketAddr, +} + +fn main() -> anyhow::Result<()> { + let conf = ProxyConf { + proxy_address: "0.0.0.0:4000".parse()?, + mgmt_address: "0.0.0.0:8080".parse()?, + cplane_address: "127.0.0.1:3000".parse()?, + }; + let conf: &'static ProxyConf = Box::leak(Box::new(conf)); + + // Check that we can bind to address before further initialization + println!("Starting proxy on {}", conf.proxy_address); + let pageserver_listener = TcpListener::bind(conf.proxy_address)?; + + // Spawn a thread to listen for connections. It will spawn further threads + // for each connection. + let proxy_listener_thread = thread::Builder::new() + .name("Proxy thread".into()) + .spawn(move || proxy::thread_main(&conf, pageserver_listener))?; + + proxy_listener_thread.join().unwrap() +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs new file mode 100644 index 0000000000..169503b987 --- /dev/null +++ b/proxy/src/proxy.rs @@ -0,0 +1,98 @@ +use crate::ProxyConf; +use anyhow::bail; +use bytes::Bytes; +use std::{ + net::{TcpListener, TcpStream}, + thread, +}; +use zenith_utils::postgres_backend::PostgresBackend; +use zenith_utils::{ + postgres_backend, + pq_proto::{BeMessage, HELLO_WORLD_ROW, SINGLE_COL_ROWDESC}, +}; + +/// +/// Main proxy listener loop. +/// +/// Listens for connections, and launches a new handler thread for each. +/// +pub fn thread_main(conf: &'static ProxyConf, listener: TcpListener) -> anyhow::Result<()> { + loop { + let (socket, peer_addr) = listener.accept()?; + println!("accepted connection from {}", peer_addr); + socket.set_nodelay(true).unwrap(); + + thread::spawn(move || { + if let Err(err) = proxy_conn_main(conf, socket) { + println!("error: {}", err); + } + }); + } +} + +pub fn proxy_conn_main(conf: &'static ProxyConf, socket: TcpStream) -> anyhow::Result<()> { + let mut conn_handler = ProxyHandler { conf }; + let mut pgbackend = PostgresBackend::new(socket, postgres_backend::AuthType::MD5)?; + pgbackend.run(&mut conn_handler) +} + +struct ProxyHandler { + conf: &'static ProxyConf, +} + +// impl ProxyHandler { +// } + +impl postgres_backend::Handler for ProxyHandler { + fn process_query( + &mut self, + pgb: &mut PostgresBackend, + query_string: Bytes, + ) -> anyhow::Result<()> { + println!("Got query: {:?}", query_string); + pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? + .write_message_noflush(&HELLO_WORLD_ROW)? + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + pgb.flush()?; + Ok(()) + } + + fn startup( + &mut self, + _pgb: &mut PostgresBackend, + sm: &zenith_utils::pq_proto::FeStartupMessage, + ) -> anyhow::Result<()> { + println!("Got startup: {:?}", sm); + Ok(()) + } + + fn check_auth_md5( + &mut self, + pgb: &mut PostgresBackend, + md5_response: &[u8], + ) -> anyhow::Result<()> { + let user = "stask"; + let pass = "mypassword"; + let stored_hash = format!( + "{:x}", + md5::compute([pass.as_bytes(), user.as_bytes()].concat()) + ); + let salted_stored_hash = format!( + "md5{:x}", + md5::compute([stored_hash.as_bytes(), &pgb.md5_salt].concat()) + ); + + let received_hash = std::str::from_utf8(&md5_response)?; + + println!( + "check_auth_md5: {:?} vs {}, salt {:?}", + received_hash, salted_stored_hash, &pgb.md5_salt + ); + + if received_hash == salted_stored_hash { + Ok(()) + } else { + bail!("Auth failed") + } + } +} diff --git a/walkeeper/src/wal_service.rs b/walkeeper/src/wal_service.rs index 82796e1298..124a273159 100644 --- a/walkeeper/src/wal_service.rs +++ b/walkeeper/src/wal_service.rs @@ -11,7 +11,7 @@ use std::thread; use crate::receive_wal::ReceiveWalConn; use crate::send_wal::SendWalHandler; use crate::WalAcceptorConf; -use zenith_utils::postgres_backend::PostgresBackend; +use zenith_utils::postgres_backend::{AuthType, PostgresBackend}; /// Accept incoming TCP connections and spawn them into a background thread. pub fn thread_main(conf: WalAcceptorConf) -> Result<()> { @@ -50,7 +50,7 @@ fn handle_socket(mut socket: TcpStream, conf: WalAcceptorConf) -> Result<()> { ReceiveWalConn::new(socket, conf)?.run()?; // internal protocol between wal_proposer and wal_acceptor } else { let mut conn_handler = SendWalHandler::new(conf); - let mut pgbackend = PostgresBackend::new(socket)?; + let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?; // libpq replication protocol between wal_acceptor and replicas/pagers pgbackend.run(&mut conn_handler)?; } diff --git a/zenith_utils/Cargo.toml b/zenith_utils/Cargo.toml index bd42109bf2..3fe2e8b8df 100644 --- a/zenith_utils/Cargo.toml +++ b/zenith_utils/Cargo.toml @@ -14,6 +14,7 @@ bincode = "1.3" thiserror = "1.0" postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" } workspace_hack = { path = "../workspace_hack" } +rand = "0.8.3" [dev-dependencies] hex-literal = "0.3" diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index 78993af761..7565ee75ba 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -8,6 +8,7 @@ use anyhow::bail; use anyhow::Result; use bytes::{Bytes, BytesMut}; use log::*; +use rand::Rng; use std::io; use std::io::{BufReader, Write}; use std::net::{Shutdown, TcpStream}; @@ -18,10 +19,29 @@ pub trait Handler { /// might be not what we want after CopyData streaming, but currently we don't /// care). fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: Bytes) -> Result<()>; + /// Called on startup packet receival, allows to process params. fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupMessage) -> Result<()> { Ok(()) } + + /// Check auth + fn check_auth_md5(&mut self, _pgb: &mut PostgresBackend, _md5_response: &[u8]) -> Result<()> { + bail!("Auth failed") + } +} + +#[derive(PartialEq)] +enum ProtoState { + Initialization, + Authentication, + Established, +} + +#[derive(PartialEq)] +pub enum AuthType { + Trust, + MD5, } pub struct PostgresBackend { @@ -32,7 +52,10 @@ pub struct PostgresBackend { stream_out: TcpStream, // Output buffer. c.f. BeMessage::write why we are using BytesMut here. buf_out: BytesMut, - init_done: bool, + + state: ProtoState, + pub md5_salt: [u8; 4], + auth_type: AuthType, } // In replication.rs a separate thread is reading keepalives from the @@ -45,12 +68,14 @@ impl Drop for PostgresBackend { } impl PostgresBackend { - pub fn new(socket: TcpStream) -> Result { + pub fn new(socket: TcpStream, auth_type: AuthType) -> Result { let mut pb = PostgresBackend { stream_in: None, stream_out: socket, buf_out: BytesMut::with_capacity(10 * 1024), - init_done: false, + state: ProtoState::Initialization, + md5_salt: [0u8; 4], + auth_type, }; // if socket cloning fails, report the error and bail out pb.stream_in = match pb.stream_out.try_clone() { @@ -78,10 +103,11 @@ impl PostgresBackend { /// Read full message or return None if connection is closed. pub fn read_message(&mut self) -> Result> { - if !self.init_done { - FeStartupMessage::read(self.get_stream_in()?) - } else { - FeMessage::read(self.get_stream_in()?) + match self.state { + ProtoState::Initialization => FeStartupMessage::read(self.get_stream_in()?), + ProtoState::Authentication | ProtoState::Established => { + FeMessage::read(self.get_stream_in()?) + } } } @@ -112,6 +138,21 @@ impl PostgresBackend { loop { let msg = self.read_message()?; trace!("got message {:?}", msg); + + // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth + // TODO: change that to proper top-level match of protocol state with separate message handling for each state + if self.state == ProtoState::Authentication || self.state == ProtoState::Initialization + { + match msg { + Some(FeMessage::PasswordMessage(ref _m)) => {} + Some(FeMessage::StartupMessage(ref _m)) => {} + Some(_) => { + bail!("protocol violation"); + } + None => {} + }; + } + match msg { Some(FeMessage::StartupMessage(m)) => { trace!("got startup message {:?}", m); @@ -124,16 +165,52 @@ impl PostgresBackend { self.write_message(&BeMessage::Negotiate)?; } StartupRequestCode::Normal => { - self.write_message_noflush(&BeMessage::AuthenticationOk)?; - // psycopg2 will not connect if client_encoding is not - // specified by the server - self.write_message_noflush(&BeMessage::ParameterStatus)?; - self.write_message(&BeMessage::ReadyForQuery)?; - self.init_done = true; + if self.auth_type == AuthType::Trust { + self.write_message_noflush(&BeMessage::AuthenticationOk)?; + // psycopg2 will not connect if client_encoding is not + // specified by the server + self.write_message_noflush(&BeMessage::ParameterStatus)?; + self.write_message(&BeMessage::ReadyForQuery)?; + self.state = ProtoState::Established; + } else { + rand::thread_rng().fill(&mut self.md5_salt); + let md5_salt = self.md5_salt.clone(); + self.write_message(&BeMessage::AuthenticationMD5Password( + &md5_salt, + ))?; + self.state = ProtoState::Authentication; + } } StartupRequestCode::Cancel => break, } } + + Some(FeMessage::PasswordMessage(m)) => { + trace!("got password message '{:?}'", m); + + assert!(self.state == ProtoState::Authentication); + + let (trailing_null, md5_response) = m.split_last().unwrap(); + + if *trailing_null != 0 { + let errmsg = "protocol violation"; + self.write_message(&BeMessage::ErrorResponse(format!("{}", errmsg)))?; + bail!("auth failed: {}", errmsg); + } + + if let Err(e) = handler.check_auth_md5(self, md5_response) { + self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?; + bail!("auth failed: {}", e); + } else { + self.write_message_noflush(&BeMessage::AuthenticationOk)?; + // psycopg2 will not connect if client_encoding is not + // specified by the server + self.write_message_noflush(&BeMessage::ParameterStatus)?; + self.write_message(&BeMessage::ReadyForQuery)?; + self.state = ProtoState::Established; + } + } + Some(FeMessage::Query(m)) => { trace!("got query {:?}", m.body); // xxx distinguish fatal and recoverable errors? diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index 5570e1b223..ac6e3bb971 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -28,6 +28,7 @@ pub enum FeMessage { Terminate, CopyData(Bytes), CopyDone, + PasswordMessage(Bytes), } #[derive(Debug)] @@ -115,6 +116,7 @@ impl FeMessage { b'X' => Ok(Some(FeMessage::Terminate)), b'd' => Ok(Some(FeMessage::CopyData(body))), b'c' => Ok(Some(FeMessage::CopyDone)), + b'p' => Ok(Some(FeMessage::PasswordMessage(body))), tag => Err(anyhow!("unknown message tag: {},'{:?}'", tag, body)), } } @@ -307,6 +309,7 @@ fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result { #[derive(Debug)] pub enum BeMessage<'a> { AuthenticationOk, + AuthenticationMD5Password(&'a [u8; 4]), BindComplete, CommandComplete(&'a [u8]), ControlFile, @@ -442,7 +445,17 @@ impl<'a> BeMessage<'a> { BeMessage::AuthenticationOk => { buf.put_u8(b'R'); write_body(buf, |buf| { - buf.put_i32(0); + buf.put_i32(0); // Specifies that the authentication was successful. + Ok::<_, io::Error>(()) + }) + .unwrap(); // write into BytesMut can't fail + } + + BeMessage::AuthenticationMD5Password(salt) => { + buf.put_u8(b'R'); + write_body(buf, |buf| { + buf.put_i32(5); // Specifies that an MD5-encrypted password is required. + buf.put_slice(&salt[..]); Ok::<_, io::Error>(()) }) .unwrap(); // write into BytesMut can't fail