mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 17:02:56 +00:00
md5 auth for postgres_backend.rs
This commit is contained in:
11
Cargo.lock
generated
11
Cargo.lock
generated
@@ -1317,6 +1317,16 @@ dependencies = [
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proxy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
"md5",
|
||||
"zenith_utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-xml"
|
||||
version = "0.20.0"
|
||||
@@ -2395,6 +2405,7 @@ dependencies = [
|
||||
"hex-literal",
|
||||
"log",
|
||||
"postgres",
|
||||
"rand",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"workspace_hack",
|
||||
|
||||
@@ -7,6 +7,7 @@ members = [
|
||||
"postgres_ffi",
|
||||
"zenith_utils",
|
||||
"workspace_hack",
|
||||
"proxy"
|
||||
]
|
||||
|
||||
[profile.release]
|
||||
|
||||
@@ -19,8 +19,8 @@ use std::net::TcpListener;
|
||||
use std::str::FromStr;
|
||||
use std::thread;
|
||||
use std::{io, net::TcpStream};
|
||||
use zenith_utils::postgres_backend;
|
||||
use zenith_utils::postgres_backend::PostgresBackend;
|
||||
use zenith_utils::postgres_backend::{self, AuthType};
|
||||
use zenith_utils::pq_proto::{
|
||||
BeMessage, FeMessage, RowDescriptor, HELLO_WORLD_ROW, SINGLE_COL_ROWDESC,
|
||||
};
|
||||
@@ -154,7 +154,7 @@ pub fn thread_main(conf: &'static PageServerConf, listener: TcpListener) -> anyh
|
||||
|
||||
fn page_service_conn_main(conf: &'static PageServerConf, socket: TcpStream) -> anyhow::Result<()> {
|
||||
let mut conn_handler = PageServerHandler::new(conf);
|
||||
let mut pgbackend = PostgresBackend::new(socket)?;
|
||||
let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
|
||||
pgbackend.run(&mut conn_handler)
|
||||
}
|
||||
|
||||
|
||||
14
proxy/Cargo.toml
Normal file
14
proxy/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "proxy"
|
||||
version = "0.1.0"
|
||||
authors = ["Stas Kelvich <stas.kelvich@gmail.com>"]
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
bytes = { version = "1.0.1", features = ['serde'] }
|
||||
md5 = "0.7.0"
|
||||
|
||||
zenith_utils = { path = "../zenith_utils" }
|
||||
47
proxy/src/main.rs
Normal file
47
proxy/src/main.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
///
|
||||
/// Postgres protocol proxy/router.
|
||||
///
|
||||
/// This service listens psql port and can check auth via external service
|
||||
/// (control plane API in our case) and can create new databases and accounts
|
||||
/// in somewhat transparent manner (again via communication with control plane API).
|
||||
///
|
||||
use std::{
|
||||
net::{SocketAddr, TcpListener},
|
||||
thread,
|
||||
};
|
||||
|
||||
mod cplane_api;
|
||||
mod proxy;
|
||||
|
||||
pub struct ProxyConf {
|
||||
/// main entrypoint for users to connect to
|
||||
pub proxy_address: SocketAddr,
|
||||
|
||||
/// http management endpoint. Upon user account creation control plane
|
||||
/// will notify us here, so that we can 'unfreeze' user session.
|
||||
pub mgmt_address: SocketAddr,
|
||||
|
||||
/// control plane address where we check auth and create clusters.
|
||||
pub cplane_address: SocketAddr,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let conf = ProxyConf {
|
||||
proxy_address: "0.0.0.0:4000".parse()?,
|
||||
mgmt_address: "0.0.0.0:8080".parse()?,
|
||||
cplane_address: "127.0.0.1:3000".parse()?,
|
||||
};
|
||||
let conf: &'static ProxyConf = Box::leak(Box::new(conf));
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
println!("Starting proxy on {}", conf.proxy_address);
|
||||
let pageserver_listener = TcpListener::bind(conf.proxy_address)?;
|
||||
|
||||
// Spawn a thread to listen for connections. It will spawn further threads
|
||||
// for each connection.
|
||||
let proxy_listener_thread = thread::Builder::new()
|
||||
.name("Proxy thread".into())
|
||||
.spawn(move || proxy::thread_main(&conf, pageserver_listener))?;
|
||||
|
||||
proxy_listener_thread.join().unwrap()
|
||||
}
|
||||
98
proxy/src/proxy.rs
Normal file
98
proxy/src/proxy.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
use crate::ProxyConf;
|
||||
use anyhow::bail;
|
||||
use bytes::Bytes;
|
||||
use std::{
|
||||
net::{TcpListener, TcpStream},
|
||||
thread,
|
||||
};
|
||||
use zenith_utils::postgres_backend::PostgresBackend;
|
||||
use zenith_utils::{
|
||||
postgres_backend,
|
||||
pq_proto::{BeMessage, HELLO_WORLD_ROW, SINGLE_COL_ROWDESC},
|
||||
};
|
||||
|
||||
///
|
||||
/// Main proxy listener loop.
|
||||
///
|
||||
/// Listens for connections, and launches a new handler thread for each.
|
||||
///
|
||||
pub fn thread_main(conf: &'static ProxyConf, listener: TcpListener) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept()?;
|
||||
println!("accepted connection from {}", peer_addr);
|
||||
socket.set_nodelay(true).unwrap();
|
||||
|
||||
thread::spawn(move || {
|
||||
if let Err(err) = proxy_conn_main(conf, socket) {
|
||||
println!("error: {}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn proxy_conn_main(conf: &'static ProxyConf, socket: TcpStream) -> anyhow::Result<()> {
|
||||
let mut conn_handler = ProxyHandler { conf };
|
||||
let mut pgbackend = PostgresBackend::new(socket, postgres_backend::AuthType::MD5)?;
|
||||
pgbackend.run(&mut conn_handler)
|
||||
}
|
||||
|
||||
struct ProxyHandler {
|
||||
conf: &'static ProxyConf,
|
||||
}
|
||||
|
||||
// impl ProxyHandler {
|
||||
// }
|
||||
|
||||
impl postgres_backend::Handler for ProxyHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: Bytes,
|
||||
) -> anyhow::Result<()> {
|
||||
println!("Got query: {:?}", query_string);
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&HELLO_WORLD_ROW)?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
pgb.flush()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
sm: &zenith_utils::pq_proto::FeStartupMessage,
|
||||
) -> anyhow::Result<()> {
|
||||
println!("Got startup: {:?}", sm);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_auth_md5(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
md5_response: &[u8],
|
||||
) -> anyhow::Result<()> {
|
||||
let user = "stask";
|
||||
let pass = "mypassword";
|
||||
let stored_hash = format!(
|
||||
"{:x}",
|
||||
md5::compute([pass.as_bytes(), user.as_bytes()].concat())
|
||||
);
|
||||
let salted_stored_hash = format!(
|
||||
"md5{:x}",
|
||||
md5::compute([stored_hash.as_bytes(), &pgb.md5_salt].concat())
|
||||
);
|
||||
|
||||
let received_hash = std::str::from_utf8(&md5_response)?;
|
||||
|
||||
println!(
|
||||
"check_auth_md5: {:?} vs {}, salt {:?}",
|
||||
received_hash, salted_stored_hash, &pgb.md5_salt
|
||||
);
|
||||
|
||||
if received_hash == salted_stored_hash {
|
||||
Ok(())
|
||||
} else {
|
||||
bail!("Auth failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,7 @@ use std::thread;
|
||||
use crate::receive_wal::ReceiveWalConn;
|
||||
use crate::send_wal::SendWalHandler;
|
||||
use crate::WalAcceptorConf;
|
||||
use zenith_utils::postgres_backend::PostgresBackend;
|
||||
use zenith_utils::postgres_backend::{AuthType, PostgresBackend};
|
||||
|
||||
/// Accept incoming TCP connections and spawn them into a background thread.
|
||||
pub fn thread_main(conf: WalAcceptorConf) -> Result<()> {
|
||||
@@ -50,7 +50,7 @@ fn handle_socket(mut socket: TcpStream, conf: WalAcceptorConf) -> Result<()> {
|
||||
ReceiveWalConn::new(socket, conf)?.run()?; // internal protocol between wal_proposer and wal_acceptor
|
||||
} else {
|
||||
let mut conn_handler = SendWalHandler::new(conf);
|
||||
let mut pgbackend = PostgresBackend::new(socket)?;
|
||||
let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
|
||||
// libpq replication protocol between wal_acceptor and replicas/pagers
|
||||
pgbackend.run(&mut conn_handler)?;
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ bincode = "1.3"
|
||||
thiserror = "1.0"
|
||||
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
|
||||
workspace_hack = { path = "../workspace_hack" }
|
||||
rand = "0.8.3"
|
||||
|
||||
[dev-dependencies]
|
||||
hex-literal = "0.3"
|
||||
|
||||
@@ -8,6 +8,7 @@ use anyhow::bail;
|
||||
use anyhow::Result;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use log::*;
|
||||
use rand::Rng;
|
||||
use std::io;
|
||||
use std::io::{BufReader, Write};
|
||||
use std::net::{Shutdown, TcpStream};
|
||||
@@ -18,10 +19,29 @@ pub trait Handler {
|
||||
/// might be not what we want after CopyData streaming, but currently we don't
|
||||
/// care).
|
||||
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: Bytes) -> Result<()>;
|
||||
|
||||
/// Called on startup packet receival, allows to process params.
|
||||
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupMessage) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check auth
|
||||
fn check_auth_md5(&mut self, _pgb: &mut PostgresBackend, _md5_response: &[u8]) -> Result<()> {
|
||||
bail!("Auth failed")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum ProtoState {
|
||||
Initialization,
|
||||
Authentication,
|
||||
Established,
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum AuthType {
|
||||
Trust,
|
||||
MD5,
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
@@ -32,7 +52,10 @@ pub struct PostgresBackend {
|
||||
stream_out: TcpStream,
|
||||
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
|
||||
buf_out: BytesMut,
|
||||
init_done: bool,
|
||||
|
||||
state: ProtoState,
|
||||
pub md5_salt: [u8; 4],
|
||||
auth_type: AuthType,
|
||||
}
|
||||
|
||||
// In replication.rs a separate thread is reading keepalives from the
|
||||
@@ -45,12 +68,14 @@ impl Drop for PostgresBackend {
|
||||
}
|
||||
|
||||
impl PostgresBackend {
|
||||
pub fn new(socket: TcpStream) -> Result<Self, std::io::Error> {
|
||||
pub fn new(socket: TcpStream, auth_type: AuthType) -> Result<Self, std::io::Error> {
|
||||
let mut pb = PostgresBackend {
|
||||
stream_in: None,
|
||||
stream_out: socket,
|
||||
buf_out: BytesMut::with_capacity(10 * 1024),
|
||||
init_done: false,
|
||||
state: ProtoState::Initialization,
|
||||
md5_salt: [0u8; 4],
|
||||
auth_type,
|
||||
};
|
||||
// if socket cloning fails, report the error and bail out
|
||||
pb.stream_in = match pb.stream_out.try_clone() {
|
||||
@@ -78,10 +103,11 @@ impl PostgresBackend {
|
||||
|
||||
/// Read full message or return None if connection is closed.
|
||||
pub fn read_message(&mut self) -> Result<Option<FeMessage>> {
|
||||
if !self.init_done {
|
||||
FeStartupMessage::read(self.get_stream_in()?)
|
||||
} else {
|
||||
FeMessage::read(self.get_stream_in()?)
|
||||
match self.state {
|
||||
ProtoState::Initialization => FeStartupMessage::read(self.get_stream_in()?),
|
||||
ProtoState::Authentication | ProtoState::Established => {
|
||||
FeMessage::read(self.get_stream_in()?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,6 +138,21 @@ impl PostgresBackend {
|
||||
loop {
|
||||
let msg = self.read_message()?;
|
||||
trace!("got message {:?}", msg);
|
||||
|
||||
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
|
||||
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
|
||||
if self.state == ProtoState::Authentication || self.state == ProtoState::Initialization
|
||||
{
|
||||
match msg {
|
||||
Some(FeMessage::PasswordMessage(ref _m)) => {}
|
||||
Some(FeMessage::StartupMessage(ref _m)) => {}
|
||||
Some(_) => {
|
||||
bail!("protocol violation");
|
||||
}
|
||||
None => {}
|
||||
};
|
||||
}
|
||||
|
||||
match msg {
|
||||
Some(FeMessage::StartupMessage(m)) => {
|
||||
trace!("got startup message {:?}", m);
|
||||
@@ -124,16 +165,52 @@ impl PostgresBackend {
|
||||
self.write_message(&BeMessage::Negotiate)?;
|
||||
}
|
||||
StartupRequestCode::Normal => {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
|
||||
// psycopg2 will not connect if client_encoding is not
|
||||
// specified by the server
|
||||
self.write_message_noflush(&BeMessage::ParameterStatus)?;
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.init_done = true;
|
||||
if self.auth_type == AuthType::Trust {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
|
||||
// psycopg2 will not connect if client_encoding is not
|
||||
// specified by the server
|
||||
self.write_message_noflush(&BeMessage::ParameterStatus)?;
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
} else {
|
||||
rand::thread_rng().fill(&mut self.md5_salt);
|
||||
let md5_salt = self.md5_salt.clone();
|
||||
self.write_message(&BeMessage::AuthenticationMD5Password(
|
||||
&md5_salt,
|
||||
))?;
|
||||
self.state = ProtoState::Authentication;
|
||||
}
|
||||
}
|
||||
StartupRequestCode::Cancel => break,
|
||||
}
|
||||
}
|
||||
|
||||
Some(FeMessage::PasswordMessage(m)) => {
|
||||
trace!("got password message '{:?}'", m);
|
||||
|
||||
assert!(self.state == ProtoState::Authentication);
|
||||
|
||||
let (trailing_null, md5_response) = m.split_last().unwrap();
|
||||
|
||||
if *trailing_null != 0 {
|
||||
let errmsg = "protocol violation";
|
||||
self.write_message(&BeMessage::ErrorResponse(format!("{}", errmsg)))?;
|
||||
bail!("auth failed: {}", errmsg);
|
||||
}
|
||||
|
||||
if let Err(e) = handler.check_auth_md5(self, md5_response) {
|
||||
self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;
|
||||
bail!("auth failed: {}", e);
|
||||
} else {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
|
||||
// psycopg2 will not connect if client_encoding is not
|
||||
// specified by the server
|
||||
self.write_message_noflush(&BeMessage::ParameterStatus)?;
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
}
|
||||
|
||||
Some(FeMessage::Query(m)) => {
|
||||
trace!("got query {:?}", m.body);
|
||||
// xxx distinguish fatal and recoverable errors?
|
||||
|
||||
@@ -28,6 +28,7 @@ pub enum FeMessage {
|
||||
Terminate,
|
||||
CopyData(Bytes),
|
||||
CopyDone,
|
||||
PasswordMessage(Bytes),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -115,6 +116,7 @@ impl FeMessage {
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
|
||||
tag => Err(anyhow!("unknown message tag: {},'{:?}'", tag, body)),
|
||||
}
|
||||
}
|
||||
@@ -307,6 +309,7 @@ fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result<Bytes> {
|
||||
#[derive(Debug)]
|
||||
pub enum BeMessage<'a> {
|
||||
AuthenticationOk,
|
||||
AuthenticationMD5Password(&'a [u8; 4]),
|
||||
BindComplete,
|
||||
CommandComplete(&'a [u8]),
|
||||
ControlFile,
|
||||
@@ -442,7 +445,17 @@ impl<'a> BeMessage<'a> {
|
||||
BeMessage::AuthenticationOk => {
|
||||
buf.put_u8(b'R');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_i32(0);
|
||||
buf.put_i32(0); // Specifies that the authentication was successful.
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap(); // write into BytesMut can't fail
|
||||
}
|
||||
|
||||
BeMessage::AuthenticationMD5Password(salt) => {
|
||||
buf.put_u8(b'R');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
|
||||
buf.put_slice(&salt[..]);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap(); // write into BytesMut can't fail
|
||||
|
||||
Reference in New Issue
Block a user