md5 auth for postgres_backend.rs

This commit is contained in:
Stas Kelvich
2021-06-23 10:43:24 +03:00
parent d55095ab21
commit bf45bef284
10 changed files with 280 additions and 18 deletions

11
Cargo.lock generated
View File

@@ -1317,6 +1317,16 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "proxy"
version = "0.1.0"
dependencies = [
"anyhow",
"bytes",
"md5",
"zenith_utils",
]
[[package]]
name = "quick-xml"
version = "0.20.0"
@@ -2395,6 +2405,7 @@ dependencies = [
"hex-literal",
"log",
"postgres",
"rand",
"serde",
"thiserror",
"workspace_hack",

View File

@@ -7,6 +7,7 @@ members = [
"postgres_ffi",
"zenith_utils",
"workspace_hack",
"proxy"
]
[profile.release]

View File

@@ -19,8 +19,8 @@ use std::net::TcpListener;
use std::str::FromStr;
use std::thread;
use std::{io, net::TcpStream};
use zenith_utils::postgres_backend;
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::postgres_backend::{self, AuthType};
use zenith_utils::pq_proto::{
BeMessage, FeMessage, RowDescriptor, HELLO_WORLD_ROW, SINGLE_COL_ROWDESC,
};
@@ -154,7 +154,7 @@ pub fn thread_main(conf: &'static PageServerConf, listener: TcpListener) -> anyh
fn page_service_conn_main(conf: &'static PageServerConf, socket: TcpStream) -> anyhow::Result<()> {
let mut conn_handler = PageServerHandler::new(conf);
let mut pgbackend = PostgresBackend::new(socket)?;
let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
pgbackend.run(&mut conn_handler)
}

14
proxy/Cargo.toml Normal file
View File

@@ -0,0 +1,14 @@
[package]
name = "proxy"
version = "0.1.0"
authors = ["Stas Kelvich <stas.kelvich@gmail.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0"
bytes = { version = "1.0.1", features = ['serde'] }
md5 = "0.7.0"
zenith_utils = { path = "../zenith_utils" }

47
proxy/src/main.rs Normal file
View File

@@ -0,0 +1,47 @@
///
/// Postgres protocol proxy/router.
///
/// This service listens psql port and can check auth via external service
/// (control plane API in our case) and can create new databases and accounts
/// in somewhat transparent manner (again via communication with control plane API).
///
use std::{
net::{SocketAddr, TcpListener},
thread,
};
mod cplane_api;
mod proxy;
pub struct ProxyConf {
/// main entrypoint for users to connect to
pub proxy_address: SocketAddr,
/// http management endpoint. Upon user account creation control plane
/// will notify us here, so that we can 'unfreeze' user session.
pub mgmt_address: SocketAddr,
/// control plane address where we check auth and create clusters.
pub cplane_address: SocketAddr,
}
fn main() -> anyhow::Result<()> {
let conf = ProxyConf {
proxy_address: "0.0.0.0:4000".parse()?,
mgmt_address: "0.0.0.0:8080".parse()?,
cplane_address: "127.0.0.1:3000".parse()?,
};
let conf: &'static ProxyConf = Box::leak(Box::new(conf));
// Check that we can bind to address before further initialization
println!("Starting proxy on {}", conf.proxy_address);
let pageserver_listener = TcpListener::bind(conf.proxy_address)?;
// Spawn a thread to listen for connections. It will spawn further threads
// for each connection.
let proxy_listener_thread = thread::Builder::new()
.name("Proxy thread".into())
.spawn(move || proxy::thread_main(&conf, pageserver_listener))?;
proxy_listener_thread.join().unwrap()
}

98
proxy/src/proxy.rs Normal file
View File

@@ -0,0 +1,98 @@
use crate::ProxyConf;
use anyhow::bail;
use bytes::Bytes;
use std::{
net::{TcpListener, TcpStream},
thread,
};
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::{
postgres_backend,
pq_proto::{BeMessage, HELLO_WORLD_ROW, SINGLE_COL_ROWDESC},
};
///
/// Main proxy listener loop.
///
/// Listens for connections, and launches a new handler thread for each.
///
pub fn thread_main(conf: &'static ProxyConf, listener: TcpListener) -> anyhow::Result<()> {
loop {
let (socket, peer_addr) = listener.accept()?;
println!("accepted connection from {}", peer_addr);
socket.set_nodelay(true).unwrap();
thread::spawn(move || {
if let Err(err) = proxy_conn_main(conf, socket) {
println!("error: {}", err);
}
});
}
}
pub fn proxy_conn_main(conf: &'static ProxyConf, socket: TcpStream) -> anyhow::Result<()> {
let mut conn_handler = ProxyHandler { conf };
let mut pgbackend = PostgresBackend::new(socket, postgres_backend::AuthType::MD5)?;
pgbackend.run(&mut conn_handler)
}
struct ProxyHandler {
conf: &'static ProxyConf,
}
// impl ProxyHandler {
// }
impl postgres_backend::Handler for ProxyHandler {
fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: Bytes,
) -> anyhow::Result<()> {
println!("Got query: {:?}", query_string);
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&HELLO_WORLD_ROW)?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.flush()?;
Ok(())
}
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
sm: &zenith_utils::pq_proto::FeStartupMessage,
) -> anyhow::Result<()> {
println!("Got startup: {:?}", sm);
Ok(())
}
fn check_auth_md5(
&mut self,
pgb: &mut PostgresBackend,
md5_response: &[u8],
) -> anyhow::Result<()> {
let user = "stask";
let pass = "mypassword";
let stored_hash = format!(
"{:x}",
md5::compute([pass.as_bytes(), user.as_bytes()].concat())
);
let salted_stored_hash = format!(
"md5{:x}",
md5::compute([stored_hash.as_bytes(), &pgb.md5_salt].concat())
);
let received_hash = std::str::from_utf8(&md5_response)?;
println!(
"check_auth_md5: {:?} vs {}, salt {:?}",
received_hash, salted_stored_hash, &pgb.md5_salt
);
if received_hash == salted_stored_hash {
Ok(())
} else {
bail!("Auth failed")
}
}
}

View File

@@ -11,7 +11,7 @@ use std::thread;
use crate::receive_wal::ReceiveWalConn;
use crate::send_wal::SendWalHandler;
use crate::WalAcceptorConf;
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::postgres_backend::{AuthType, PostgresBackend};
/// Accept incoming TCP connections and spawn them into a background thread.
pub fn thread_main(conf: WalAcceptorConf) -> Result<()> {
@@ -50,7 +50,7 @@ fn handle_socket(mut socket: TcpStream, conf: WalAcceptorConf) -> Result<()> {
ReceiveWalConn::new(socket, conf)?.run()?; // internal protocol between wal_proposer and wal_acceptor
} else {
let mut conn_handler = SendWalHandler::new(conf);
let mut pgbackend = PostgresBackend::new(socket)?;
let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
// libpq replication protocol between wal_acceptor and replicas/pagers
pgbackend.run(&mut conn_handler)?;
}

View File

@@ -14,6 +14,7 @@ bincode = "1.3"
thiserror = "1.0"
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
workspace_hack = { path = "../workspace_hack" }
rand = "0.8.3"
[dev-dependencies]
hex-literal = "0.3"

View File

@@ -8,6 +8,7 @@ use anyhow::bail;
use anyhow::Result;
use bytes::{Bytes, BytesMut};
use log::*;
use rand::Rng;
use std::io;
use std::io::{BufReader, Write};
use std::net::{Shutdown, TcpStream};
@@ -18,10 +19,29 @@ pub trait Handler {
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: Bytes) -> Result<()>;
/// Called on startup packet receival, allows to process params.
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupMessage) -> Result<()> {
Ok(())
}
/// Check auth
fn check_auth_md5(&mut self, _pgb: &mut PostgresBackend, _md5_response: &[u8]) -> Result<()> {
bail!("Auth failed")
}
}
#[derive(PartialEq)]
enum ProtoState {
Initialization,
Authentication,
Established,
}
#[derive(PartialEq)]
pub enum AuthType {
Trust,
MD5,
}
pub struct PostgresBackend {
@@ -32,7 +52,10 @@ pub struct PostgresBackend {
stream_out: TcpStream,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
buf_out: BytesMut,
init_done: bool,
state: ProtoState,
pub md5_salt: [u8; 4],
auth_type: AuthType,
}
// In replication.rs a separate thread is reading keepalives from the
@@ -45,12 +68,14 @@ impl Drop for PostgresBackend {
}
impl PostgresBackend {
pub fn new(socket: TcpStream) -> Result<Self, std::io::Error> {
pub fn new(socket: TcpStream, auth_type: AuthType) -> Result<Self, std::io::Error> {
let mut pb = PostgresBackend {
stream_in: None,
stream_out: socket,
buf_out: BytesMut::with_capacity(10 * 1024),
init_done: false,
state: ProtoState::Initialization,
md5_salt: [0u8; 4],
auth_type,
};
// if socket cloning fails, report the error and bail out
pb.stream_in = match pb.stream_out.try_clone() {
@@ -78,10 +103,11 @@ impl PostgresBackend {
/// Read full message or return None if connection is closed.
pub fn read_message(&mut self) -> Result<Option<FeMessage>> {
if !self.init_done {
FeStartupMessage::read(self.get_stream_in()?)
} else {
FeMessage::read(self.get_stream_in()?)
match self.state {
ProtoState::Initialization => FeStartupMessage::read(self.get_stream_in()?),
ProtoState::Authentication | ProtoState::Established => {
FeMessage::read(self.get_stream_in()?)
}
}
}
@@ -112,6 +138,21 @@ impl PostgresBackend {
loop {
let msg = self.read_message()?;
trace!("got message {:?}", msg);
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
if self.state == ProtoState::Authentication || self.state == ProtoState::Initialization
{
match msg {
Some(FeMessage::PasswordMessage(ref _m)) => {}
Some(FeMessage::StartupMessage(ref _m)) => {}
Some(_) => {
bail!("protocol violation");
}
None => {}
};
}
match msg {
Some(FeMessage::StartupMessage(m)) => {
trace!("got startup message {:?}", m);
@@ -124,16 +165,52 @@ impl PostgresBackend {
self.write_message(&BeMessage::Negotiate)?;
}
StartupRequestCode::Normal => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
// psycopg2 will not connect if client_encoding is not
// specified by the server
self.write_message_noflush(&BeMessage::ParameterStatus)?;
self.write_message(&BeMessage::ReadyForQuery)?;
self.init_done = true;
if self.auth_type == AuthType::Trust {
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
// psycopg2 will not connect if client_encoding is not
// specified by the server
self.write_message_noflush(&BeMessage::ParameterStatus)?;
self.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
} else {
rand::thread_rng().fill(&mut self.md5_salt);
let md5_salt = self.md5_salt.clone();
self.write_message(&BeMessage::AuthenticationMD5Password(
&md5_salt,
))?;
self.state = ProtoState::Authentication;
}
}
StartupRequestCode::Cancel => break,
}
}
Some(FeMessage::PasswordMessage(m)) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
let (trailing_null, md5_response) = m.split_last().unwrap();
if *trailing_null != 0 {
let errmsg = "protocol violation";
self.write_message(&BeMessage::ErrorResponse(format!("{}", errmsg)))?;
bail!("auth failed: {}", errmsg);
}
if let Err(e) = handler.check_auth_md5(self, md5_response) {
self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;
bail!("auth failed: {}", e);
} else {
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
// psycopg2 will not connect if client_encoding is not
// specified by the server
self.write_message_noflush(&BeMessage::ParameterStatus)?;
self.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
}
Some(FeMessage::Query(m)) => {
trace!("got query {:?}", m.body);
// xxx distinguish fatal and recoverable errors?

View File

@@ -28,6 +28,7 @@ pub enum FeMessage {
Terminate,
CopyData(Bytes),
CopyDone,
PasswordMessage(Bytes),
}
#[derive(Debug)]
@@ -115,6 +116,7 @@ impl FeMessage {
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
tag => Err(anyhow!("unknown message tag: {},'{:?}'", tag, body)),
}
}
@@ -307,6 +309,7 @@ fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result<Bytes> {
#[derive(Debug)]
pub enum BeMessage<'a> {
AuthenticationOk,
AuthenticationMD5Password(&'a [u8; 4]),
BindComplete,
CommandComplete(&'a [u8]),
ControlFile,
@@ -442,7 +445,17 @@ impl<'a> BeMessage<'a> {
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
write_body(buf, |buf| {
buf.put_i32(0);
buf.put_i32(0); // Specifies that the authentication was successful.
Ok::<_, io::Error>(())
})
.unwrap(); // write into BytesMut can't fail
}
BeMessage::AuthenticationMD5Password(salt) => {
buf.put_u8(b'R');
write_body(buf, |buf| {
buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
buf.put_slice(&salt[..]);
Ok::<_, io::Error>(())
})
.unwrap(); // write into BytesMut can't fail