From 61a5b59224920a0aa6b7c90bef6d0957732a9702 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Mon, 22 Nov 2021 14:34:46 +0300 Subject: [PATCH] [WIP] [proxy] Add SCRAM auth --- Cargo.lock | 4 + proxy/Cargo.toml | 16 +- proxy/src/auth.rs | 139 ++++++++++++++++ proxy/src/main.rs | 4 + proxy/src/parse.rs | 18 +++ proxy/src/proxy.rs | 51 +++--- proxy/src/sasl.rs | 158 +++++++++++++++++++ proxy/src/scram.rs | 131 +++++++++++++++ proxy/src/scram/channel_binding.rs | 77 +++++++++ proxy/src/scram/key.rs | 40 +++++ proxy/src/scram/messages.rs | 228 +++++++++++++++++++++++++++ proxy/src/scram/secret.rs | 65 ++++++++ proxy/src/scram/signature.rs | 70 ++++++++ zenith_utils/src/postgres_backend.rs | 3 +- zenith_utils/src/pq_proto.rs | 36 ++++- 15 files changed, 1003 insertions(+), 37 deletions(-) create mode 100644 proxy/src/auth.rs create mode 100644 proxy/src/parse.rs create mode 100644 proxy/src/sasl.rs create mode 100644 proxy/src/scram.rs create mode 100644 proxy/src/scram/channel_binding.rs create mode 100644 proxy/src/scram/key.rs create mode 100644 proxy/src/scram/messages.rs create mode 100644 proxy/src/scram/secret.rs create mode 100644 proxy/src/scram/signature.rs diff --git a/Cargo.lock b/Cargo.lock index 73f037995c..117e102a35 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1421,9 +1421,11 @@ name = "proxy" version = "0.1.0" dependencies = [ "anyhow", + "base64 0.13.0", "bytes", "clap", "hex", + "hmac 0.10.1", "lazy_static", "md5", "parking_lot", @@ -1432,6 +1434,8 @@ dependencies = [ "rustls 0.19.1", "serde", "serde_json", + "sha2", + "thiserror", "tokio", "tokio-postgres", "zenith_utils", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 42287b04bb..7845d714bf 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -8,18 +8,22 @@ edition = "2018" [dependencies] anyhow = "1.0" +base64 = "0.13.0" bytes = { version = "1.0.1", features = ['serde'] } +clap = "2.33.0" +hex = "0.4.3" +hmac = "0.10.1" lazy_static = "1.4.0" md5 = "0.7.0" -rand = "0.8.3" -hex = "0.4.3" parking_lot = "0.11.2" +rand = "0.8.3" +reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] } +rustls = "0.19.1" serde = "1" serde_json = "1" -tokio = { version = "1.11", features = ["macros"] } +sha2 = "0.9.8" +thiserror = "1.0.30" +tokio = { version = "1.11", features = ['macros'] } tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" } -clap = "2.33.0" -rustls = "0.19.1" -reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] } zenith_utils = { path = "../zenith_utils" } diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs new file mode 100644 index 0000000000..f5e7d8a6d1 --- /dev/null +++ b/proxy/src/auth.rs @@ -0,0 +1,139 @@ +//! Authentication machinery. + +use crate::sasl::{SaslFirstMessage, SaslMechanism, SaslMessage, SaslStream}; +use crate::scram::{ScramExchangeServer, ScramSecret}; +use anyhow::{bail, Context}; +use zenith_utils::postgres_backend::{PostgresBackend, ProtoState}; +use zenith_utils::pq_proto::{ + BeAuthenticationSaslMessage as BeSaslMessage, BeMessage as Be, FeMessage as Fe, *, +}; + +// TODO: add SCRAM-SHA-256-PLUS +/// A list of supported SCRAM methods. +const SCRAM_METHODS: &[&str] = &["SCRAM-SHA-256"]; + +/// Initial state of [`AuthStream`]. +pub struct Begin; + +/// Use [SCRAM](crate::scram)-based auth in [`AuthStream`]. +pub struct Scram<'a>(pub &'a ScramSecret); + +/// Use password-based auth in [`AuthStream`]. +pub struct Md5( + /// Salt for client. + pub [u8; 4], +); + +/// Every authentication selector is supposed to implement this trait. +pub trait AuthMethod { + /// Any authentication selector should provide initial backend message + /// containing auth method name and parameters, e.g. md5 salt. + fn first_message(&self) -> BeMessage<'_>; +} + +impl AuthMethod for Scram<'_> { + #[inline(always)] + fn first_message(&self) -> BeMessage<'_> { + Be::AuthenticationSasl(BeSaslMessage::Methods(SCRAM_METHODS)) + } +} + +impl AuthMethod for Md5 { + #[inline(always)] + fn first_message(&self) -> BeMessage<'_> { + Be::AuthenticationMD5Password(self.0) + } +} + +/// This wrapper for [`PostgresBackend`] performs client authentication. +#[must_use] +pub struct AuthStream<'a, State> { + /// The underlying stream which implements libpq's protocol. + pgb: &'a mut PostgresBackend, + /// State might contain ancillary data (see [`AuthStream::begin`]). + state: State, +} + +/// Initial state of the stream wrapper. +impl<'a> AuthStream<'a, Begin> { + /// Create a new wrapper for client authentication. + pub fn new(pgb: &'a mut PostgresBackend) -> Self { + Self { pgb, state: Begin } + } + + /// Move to the next step by sending auth method's name & params to client. + pub fn begin(self, method: M) -> anyhow::Result> { + self.pgb.write_message(&method.first_message())?; + self.pgb.state = ProtoState::Authentication; + + Ok(AuthStream { + pgb: self.pgb, + state: method, + }) + } +} + +/// Stream wrapper for handling simple MD5 password auth. +impl AuthStream<'_, Md5> { + /// Perform user authentication; Raise an error in case authentication failed. + pub fn authenticate(mut self) -> anyhow::Result<()> { + let msg = self.read_password_message()?; + let (_trailing_null, _md5_response) = msg.split_last().context("bad password message")?; + + Ok(()) + } +} + +/// Stream wrapper for handling [SCRAM](crate::scram) auth. +impl AuthStream<'_, Scram<'_>> { + /// Perform user authentication; Raise an error in case authentication failed. + pub fn authenticate(mut self) -> anyhow::Result<()> { + // Initial client message contains the chosen auth method's name + let msg = self.read_password_message()?; + let sasl = SaslFirstMessage::parse(&msg).context("bad SASL message")?; + + // Currently, the only supported SASL method is SCRAM + if !SCRAM_METHODS.contains(&sasl.method) { + bail!("unsupported SASL method: {}", sasl.method); + } + + let secret = self.state.0; + let stream = (Some(msg.slice_ref(sasl.message)), &mut self); + ScramExchangeServer::new(secret).authenticate(stream)?; + + Ok(()) + } +} + +/// Only [`AuthMethod`] states should receive password messages. +impl AuthStream<'_, M> { + /// Receive a new [`PasswordMessage`](FeMessage::PasswordMessage) and extract its payload. + fn read_password_message(&mut self) -> anyhow::Result { + match self.pgb.read_message()? { + Some(Fe::PasswordMessage(msg)) => Ok(msg), + None => bail!("connection is lost"), + bad => bail!("unexpected message type: {:?}", bad), + } + } +} + +/// Abstract away all intricacies of [`PostgresBackend`], +/// since [SASL](crate::sasl) protocols are text-based. +impl SaslStream for AuthStream<'_, Scram<'_>> { + type In = bytes::Bytes; + + fn recv(&mut self) -> anyhow::Result { + self.read_password_message() + } + + fn send(&mut self, data: &SaslMessage>) -> anyhow::Result<()> { + let reply = match data { + SaslMessage::Continue(reply) => BeSaslMessage::Continue(reply.as_ref()), + SaslMessage::Final(reply) => BeSaslMessage::Final(reply.as_ref()), + }; + + self.pgb.write_message(&Be::AuthenticationSasl(reply))?; + + Ok(()) + } +} diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 8b397c4444..6f26a82fbc 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -11,9 +11,13 @@ use state::{ProxyConfig, ProxyState}; use std::thread; use zenith_utils::{tcp_listener, GIT_VERSION}; +mod auth; mod cplane_api; mod mgmt; +mod parse; mod proxy; +mod sasl; +mod scram; mod state; mod waiters; diff --git a/proxy/src/parse.rs b/proxy/src/parse.rs new file mode 100644 index 0000000000..8a05ff9c82 --- /dev/null +++ b/proxy/src/parse.rs @@ -0,0 +1,18 @@ +//! Small parsing helpers. + +use std::convert::TryInto; +use std::ffi::CStr; + +pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> { + let pos = bytes.iter().position(|&x| x == 0)?; + let (cstr, other) = bytes.split_at(pos + 1); + // SAFETY: we've already checked that there's a terminator + Some((unsafe { CStr::from_bytes_with_nul_unchecked(cstr) }, other)) +} + +pub fn split_at_const(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> { + (bytes.len() >= N).then(|| { + let (head, tail) = bytes.split_at(N); + (head.try_into().unwrap(), tail) + }) +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 26936159d0..f02f16eeed 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,3 +1,4 @@ +use crate::auth::{self, AuthStream}; use crate::cplane_api::{CPlaneApi, DatabaseInfo}; use crate::ProxyState; use anyhow::{anyhow, bail}; @@ -10,7 +11,7 @@ use std::collections::HashMap; use std::net::{SocketAddr, TcpStream}; use std::{io, thread}; use tokio_postgres::NoTls; -use zenith_utils::postgres_backend::{self, PostgresBackend, ProtoState, Stream}; +use zenith_utils::postgres_backend::{self, PostgresBackend, Stream}; use zenith_utils::pq_proto::{BeMessage as Be, FeMessage as Fe, *}; use zenith_utils::sock_split::{ReadStream, WriteStream}; @@ -211,40 +212,34 @@ impl ProxyConnection { } } - fn handle_existing_user(&mut self, user: &str, db: &str) -> anyhow::Result { - let md5_salt = rand::random::<[u8; 4]>(); + fn handle_existing_user(&mut self, _user: &str, _db: &str) -> anyhow::Result { + let _cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters); - // Ask password - self.pgb - .write_message(&Be::AuthenticationMD5Password(&md5_salt))?; - self.pgb.state = ProtoState::Authentication; // XXX + // TODO: fetch secret from console + // user='user' password='password' + let secret = crate::scram::ScramSecret::parse( + &[ + "SCRAM-SHA-256", + "4096:XiWzgkfGNyY3ipsz08PY+A==", + &[ + "YMmirZHYtTB6erVDCxL4Zjn66Kn7RCfS+aV3qROV4o8=", + "aCSKHnugk1l9Ut6VhO5VeeWsB8xhVdPk/NyEgjOJ3nk=", + ] + .join(":"), + ] + .join("$"), + ) + .unwrap(); - // Check password - let msg = match self.pgb.read_message()? { - Some(Fe::PasswordMessage(msg)) => msg, - None => bail!("connection is lost"), - bad => bail!("unexpected message type: {:?}", bad), - }; - println!("got message: {:?}", msg); - - let (_trailing_null, md5_response) = msg - .split_last() - .ok_or_else(|| anyhow!("unexpected password message"))?; - - let cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters); - let db_info = cplane.authenticate_proxy_request( - user, - db, - md5_response, - &md5_salt, - &self.psql_session_id, - )?; + AuthStream::new(&mut self.pgb) + .begin(auth::Scram(&secret))? + .authenticate()?; self.pgb .write_message_noflush(&Be::AuthenticationOk)? .write_message_noflush(&BeParameterStatusMessage::encoding())?; - Ok(db_info) + todo!() } fn handle_new_user(&mut self) -> anyhow::Result { diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs new file mode 100644 index 0000000000..7726942c6e --- /dev/null +++ b/proxy/src/sasl.rs @@ -0,0 +1,158 @@ +//! Simple Authentication and Security Layer. +//! +//! RFC: . +//! +//! Reference implementation: +//! * +//! * + +use crate::parse::{split_at_const, split_cstr}; +use anyhow::Context; +use thiserror::Error; + +/// SASL-specific payload of [`PasswordMessage`](zenith_utils::pq_proto::FeMessage::PasswordMessage). +#[derive(Debug)] +pub struct SaslFirstMessage<'a> { + /// Authentication method, e.g. `"SCRAM-SHA-256"`. + pub method: &'a str, + /// Initial client message. + pub message: &'a [u8], +} + +impl<'a> SaslFirstMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(bytes: &'a [u8]) -> Option { + let (method_cstr, tail) = split_cstr(bytes)?; + let method = method_cstr.to_str().ok()?; + + let (len_bytes, message) = split_at_const(tail)?; + let len = u32::from_be_bytes(*len_bytes) as usize; + if len != message.len() { + return None; + } + + Some(Self { method, message }) + } +} + +/// A single SASL message. +/// This struct is deliberately decoupled from lower-level +/// [`BeAuthenticationSaslMessage`](zenith_utils::pq_proto::BeAuthenticationSaslMessage). +#[derive(Debug)] +pub enum SaslMessage { + /// We expect to see more steps. + Continue(T), + /// This is the final step. + Final(T), +} + +/// This specialized trait provides capabilities akin to +/// [`std::io::Read`]+[`std::io::Write`] in oder to +/// abstract away underlying stream implementations. +pub trait SaslStream { + /// We'd like to use `AsRef<[str]>` here, but afaik there's + /// no cheap way to make [`String`] out of [`bytes::Bytes`]; + /// On the other hand, byte slices are a decent middle ground. + type In: AsRef<[u8]>; + + /// Receive a [SASL](crate::sasl) message from a client. + fn recv(&mut self) -> anyhow::Result; + + /// Send a [SASL](crate::sasl) message to a client. + fn send(&mut self, data: &SaslMessage>) -> anyhow::Result<()>; +} + +impl SaslStream for &mut S { + type In = S::In; + + #[inline(always)] + fn recv(&mut self) -> anyhow::Result { + S::recv(self) + } + + #[inline(always)] + fn send(&mut self, data: &SaslMessage>) -> anyhow::Result<()> { + S::send(self, data) + } +} + +/// Sometimes it's necessary to mix in a message we got from somewhere else. +impl<'a, V: AsRef<[u8]>, S: SaslStream> SaslStream for (Option, S) { + type In = S::In; + + #[inline(always)] + fn recv(&mut self) -> anyhow::Result { + // Try returning a stashed message first + match self.0.take() { + Some(value) => Ok(value), + None => self.1.recv(), + } + } + + #[inline(always)] + fn send(&mut self, data: &SaslMessage>) -> anyhow::Result<()> { + self.1.send(data) + } +} + +/// Fine-grained auth errors help in writing tests. +#[derive(Error, Debug)] +pub enum SaslError { + #[error("failed to authenticate client: {0}")] + AuthenticationFailed(&'static str), + #[error("bad client message")] + BadClientMessage, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +/// A convenient result type for SASL exchange. +pub type Result = std::result::Result; + +/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait. +pub trait SaslMechanism: Sized { + /// Produce a server challenge to be sent to the client. + /// This is how this method is called in PostgreSQL (libpq/sasl.h). + fn exchange(self, input: &str) -> Result<(Option, String)>; + + /// Perform SASL message exchange according to the underlying algorithm + /// until user is either authenticated or denied access. + fn authenticate(mut self, mut stream: impl SaslStream) -> Result<()> { + loop { + let msg = stream.recv()?; + let input = std::str::from_utf8(msg.as_ref()).context("bad encoding")?; + + let (this, reply) = self.exchange(input)?; + match this { + Some(this) => { + stream.send(&SaslMessage::Continue(reply))?; + self = this; + } + None => { + stream.send(&SaslMessage::Final(reply))?; + break; + } + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::ffi::CStr; + + #[test] + fn parse_sasl_first_message() { + let proto = CStr::from_bytes_with_nul(b"SCRAM-SHA-256\0").unwrap(); + let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4".as_bytes(); + let sasl_len = (sasl.len() as u32).to_be_bytes(); + let bytes = [proto.to_bytes_with_nul(), sasl_len.as_ref(), sasl].concat(); + + let password = SaslFirstMessage::parse(&bytes).unwrap(); + assert_eq!(password.method, proto.to_str().unwrap()); + assert_eq!(password.message, sasl); + } +} diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs new file mode 100644 index 0000000000..44519c453f --- /dev/null +++ b/proxy/src/scram.rs @@ -0,0 +1,131 @@ +//! Salted Challenge Response Authentication Mechanism. +//! +//! RFC: . +//! +//! Reference implementation: +//! * +//! * + +mod channel_binding; +mod key; +mod messages; +mod secret; +mod signature; + +pub use channel_binding::*; +pub use secret::*; + +use crate::sasl::{self, SaslError, SaslMechanism}; +use messages::{ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage}; +use signature::SignatureBuilder; + +pub use self::secret::ScramSecret; + +/// Decode base64 into array without any heap allocations +fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N]> { + let mut bytes = [0u8; N]; + + let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?; + if size != N { + return None; + } + + Some(bytes) +} + +#[derive(Debug)] +enum ScramExchangeServerState { + /// Waiting for [`ClientFirstMessage`]. + Initial, + /// Waiting for [`ClientFinalMessage`]. + SaltSent { + cbind_flag: ChannelBinding, + client_first_message_bare: String, + server_first_message: OwnedServerFirstMessage, + }, +} + +/// Server's side of SCRAM auth algorithm. +#[derive(Debug)] +pub struct ScramExchangeServer<'a> { + state: ScramExchangeServerState, + secret: &'a ScramSecret, +} + +impl<'a> ScramExchangeServer<'a> { + pub fn new(secret: &'a ScramSecret) -> Self { + Self { + state: ScramExchangeServerState::Initial, + secret, + } + } +} + +impl SaslMechanism for ScramExchangeServer<'_> { + fn exchange(mut self, input: &str) -> sasl::Result<(Option, String)> { + use ScramExchangeServerState::*; + match &self.state { + Initial => { + let client_first_message = + ClientFirstMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + + let server_first_message = client_first_message.build_server_first_message( + // TODO: use secure random + &rand::random(), + &self.secret.salt_base64, + self.secret.iterations, + ); + let msg = server_first_message.as_str().to_owned(); + + self.state = SaltSent { + cbind_flag: client_first_message.cbind_flag.map(str::to_owned), + client_first_message_bare: client_first_message.bare.to_owned(), + server_first_message, + }; + + Ok((Some(self), msg)) + } + SaltSent { + cbind_flag, + client_first_message_bare, + server_first_message, + } => { + let client_final_message = + ClientFinalMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + + let channel_binding = cbind_flag.encode(|_| { + // TODO: make global design decision regarding the certificate + todo!("fetch TLS certificate data") + }); + + // This might've been caused by a MITM attack + if client_final_message.channel_binding != channel_binding { + return Err(SaslError::AuthenticationFailed("channel binding failed")); + } + + if client_final_message.nonce != server_first_message.nonce() { + return Err(SaslError::AuthenticationFailed("bad nonce")); + } + + let signature_builder = SignatureBuilder { + client_first_message_bare, + server_first_message: server_first_message.as_str(), + client_final_message_without_proof: client_final_message.without_proof, + }; + + let client_key = signature_builder + .build(&self.secret.stored_key) + .derive_client_key(&client_final_message.proof); + + if client_key.sha256() != self.secret.stored_key { + return Err(SaslError::AuthenticationFailed("keys don't match")); + } + + let msg = client_final_message + .build_server_final_message(signature_builder, &self.secret.server_key); + + Ok((None, msg)) + } + } + } +} diff --git a/proxy/src/scram/channel_binding.rs b/proxy/src/scram/channel_binding.rs new file mode 100644 index 0000000000..0ac5009102 --- /dev/null +++ b/proxy/src/scram/channel_binding.rs @@ -0,0 +1,77 @@ +//! Definition and parser for channel binding flag (a part of GS2 header). + +/// Channel binding flag (possibly with params). +#[derive(Debug, PartialEq, Eq)] +pub enum ChannelBinding { + /// Client doesn't support channel binding. + NotSupportedClient, + /// Client thinks server doesn't support channel binding. + NotSupportedServer, + /// Client wants to use this type of channel binding. + Required(T), +} + +impl ChannelBinding { + pub fn map(self, f: impl FnOnce(T) -> R) -> ChannelBinding { + use ChannelBinding::*; + match self { + NotSupportedClient => NotSupportedClient, + NotSupportedServer => NotSupportedServer, + Required(x) => Required(f(x)), + } + } +} + +impl<'a> ChannelBinding<&'a str> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + use ChannelBinding::*; + Some(match input { + "n" => NotSupportedClient, + "y" => NotSupportedServer, + other => Required(other.strip_prefix("p=")?), + }) + } +} + +impl> ChannelBinding { + /// Encode channel binding data as base64 for subsequent checks. + pub fn encode(&self, get_cbind_data: impl FnOnce(&str) -> String) -> String { + use ChannelBinding::*; + match self { + NotSupportedClient => { + // base64::encode("n,,") + "biws".into() + } + NotSupportedServer => { + // base64::encode("y,,") + "eSws".into() + } + Required(s) => { + let s = s.as_ref(); + let msg = format!("p={mode},,{data}", mode = s, data = get_cbind_data(s)); + base64::encode(msg) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn channel_binding_encode() { + use ChannelBinding::*; + + let cases = [ + (NotSupportedClient, base64::encode("n,,")), + (NotSupportedServer, base64::encode("y,,")), + (Required("foo"), base64::encode("p=foo,,bar")), + ]; + + for (cb, input) in cases { + assert_eq!(cb.encode(|_| "bar".to_owned()), input); + } + } +} diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs new file mode 100644 index 0000000000..cae31e3919 --- /dev/null +++ b/proxy/src/scram/key.rs @@ -0,0 +1,40 @@ +//! Tools for client/server/stored keys management. + +use sha2::{Digest, Sha256}; + +/// Faithfully taken from PostgreSQL. +pub const SCRAM_KEY_LEN: usize = 32; + +/// Thin wrapper for byte array. +#[derive(Debug, PartialEq, Eq)] +#[repr(transparent)] +pub struct ScramKey { + bytes: [u8; SCRAM_KEY_LEN], +} + +impl ScramKey { + pub fn sha256(&self) -> ScramKey { + let mut bytes = [0u8; SCRAM_KEY_LEN]; + bytes.copy_from_slice({ + let mut hash = Sha256::new(); + hash.update(&self.bytes); + hash.finalize().as_slice() + }); + + bytes.into() + } +} + +impl From<[u8; SCRAM_KEY_LEN]> for ScramKey { + #[inline(always)] + fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self { + Self { bytes } + } +} + +impl AsRef<[u8]> for ScramKey { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + &self.bytes + } +} diff --git a/proxy/src/scram/messages.rs b/proxy/src/scram/messages.rs new file mode 100644 index 0000000000..1b8c337e61 --- /dev/null +++ b/proxy/src/scram/messages.rs @@ -0,0 +1,228 @@ +//! Definitions for SCRAM messages. + +use super::base64_decode_array; +use super::channel_binding::ChannelBinding; +use super::key::{ScramKey, SCRAM_KEY_LEN}; +use super::signature::SignatureBuilder; +use std::fmt; +use std::ops::Range; + +/// Faithfully taken from PostgreSQL. +const SCRAM_RAW_NONCE_LEN: usize = 18; + +/// Although we ignore all extensions, we still have to validate the message. +fn validate_sasl_extensions<'a>(parts: impl Iterator) -> Option<()> { + for mut chars in parts.map(|s| s.chars()) { + let attr = chars.next()?; + if !('a'..'z').contains(&attr) && !('A'..'Z').contains(&attr) { + return None; + } + let eq = chars.next()?; + if eq != '=' { + return None; + } + } + + Some(()) +} + +#[derive(Debug)] +pub struct ClientFirstMessage<'a> { + /// `client-first-message-bare`. + pub bare: &'a str, + /// Channel binding mode. + pub cbind_flag: ChannelBinding<&'a str>, + /// (Client username)[]. + pub username: &'a str, + /// Client nonce. + pub nonce: &'a str, +} + +impl<'a> ClientFirstMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + let mut parts = input.split(','); + + let cbind_flag = ChannelBinding::parse(parts.next()?)?; + + // PG doesn't support authorization identity, + // so we don't bother defining GS2 header type + let authzid = parts.next()?; + if !authzid.is_empty() { + return None; + } + + // Unfortunately, `parts.as_str()` is unstable + let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1; + let (_, bare) = input.split_at(pos); + + // In theory, these might be preceded by "reserved-mext" (i.e. "m=") + let username = parts.next()?.strip_prefix("n=")?; + let nonce = parts.next()?.strip_prefix("r=")?; + + // Validate but ignore auth extensions + validate_sasl_extensions(parts)?; + + Some(Self { + bare, + cbind_flag, + username, + nonce, + }) + } + + pub fn build_server_first_message( + &self, + nonce: &[u8; SCRAM_RAW_NONCE_LEN], + salt_base64: &str, + iterations: u32, + ) -> OwnedServerFirstMessage { + use std::fmt::Write; + + let mut message = String::new(); + write!(&mut message, "r={}", self.nonce).unwrap(); + base64::encode_config_buf(nonce, base64::STANDARD, &mut message); + let combined_nonce = 2..message.len(); + write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap(); + + // This design guarantees that it's impossible to create a + // server-first-message without receiving a client-first-message + OwnedServerFirstMessage { + message, + nonce: combined_nonce, + } + } +} + +#[derive(Debug)] +pub struct ClientFinalMessage<'a> { + /// `client-final-message-without-proof`. + pub without_proof: &'a str, + /// Channel binding data (base64). + pub channel_binding: &'a str, + /// Combined client & server nonce. + pub nonce: &'a str, + /// Client auth proof. + pub proof: [u8; SCRAM_KEY_LEN], +} + +impl<'a> ClientFinalMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + let (without_proof, proof) = input.rsplit_once(',')?; + + let mut parts = without_proof.split(','); + let channel_binding = parts.next()?.strip_prefix("c=")?; + let nonce = parts.next()?.strip_prefix("r=")?; + + // Validate but ignore auth extensions + validate_sasl_extensions(parts)?; + + let proof = base64_decode_array(proof.strip_prefix("p=")?)?; + + Some(Self { + without_proof, + channel_binding, + nonce, + proof, + }) + } + + pub fn build_server_final_message( + &self, + signature_builder: SignatureBuilder, + server_key: &ScramKey, + ) -> String { + let mut buf = String::from("v="); + base64::encode_config_buf( + signature_builder.build(server_key), + base64::STANDARD, + &mut buf, + ); + + buf + } +} + +pub struct OwnedServerFirstMessage { + /// Owned `server-first-message`. + message: String, + /// Slice into `message`. + nonce: Range, +} + +impl OwnedServerFirstMessage { + /// Extract combined nonce from the message. + #[inline(always)] + pub fn nonce(&self) -> &str { + &self.message[self.nonce.clone()] + } + + /// Get reference to a text representation of the message. + #[inline(always)] + pub fn as_str(&self) -> &str { + &self.message + } +} + +impl fmt::Debug for OwnedServerFirstMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServerFirstMessage") + .field("message", &self.as_str()) + .field("nonce", &self.nonce()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_client_first_message() { + use ChannelBinding::*; + + // (Almost) real strings captured during debug sessions + let cases = [ + (NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"), + (NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"), + ( + Required("tls-server-end-point"), + "p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju", + ), + ]; + + for (cb, input) in cases { + let msg = ClientFirstMessage::parse(input).unwrap(); + + assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju"); + assert_eq!(msg.username, "pepe"); + assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju"); + assert_eq!(msg.cbind_flag, cb); + } + } + + #[test] + fn parse_client_final_message() { + let input = [ + "c=eSws", + "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU", + "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=", + ] + .join(","); + + let msg = ClientFinalMessage::parse(&input).unwrap(); + assert_eq!( + msg.without_proof, + "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU" + ); + assert_eq!( + msg.nonce, + "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU" + ); + assert_eq!( + base64::encode(msg.proof), + "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=" + ); + } +} diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs new file mode 100644 index 0000000000..e1e0faf2c2 --- /dev/null +++ b/proxy/src/scram/secret.rs @@ -0,0 +1,65 @@ +//! Tools for SCRAM server secret management. + +use super::base64_decode_array; +use super::key::ScramKey; + +#[derive(Debug)] +pub struct ScramSecret { + pub iterations: u32, + pub salt_base64: String, + pub stored_key: ScramKey, + pub server_key: ScramKey, +} + +impl ScramSecret { + pub fn parse(input: &str) -> Option { + // SCRAM-SHA-256$:$: + let s = input.strip_prefix("SCRAM-SHA-256$")?; + let (params, keys) = s.split_once('$')?; + + let ((iterations, salt), (stored_key, server_key)) = + params.split_once(':').zip(keys.split_once(':'))?; + + let secret = ScramSecret { + iterations: iterations.parse().ok()?, + salt_base64: salt.to_owned(), + stored_key: base64_decode_array(stored_key)?.into(), + server_key: base64_decode_array(server_key)?.into(), + }; + + Some(secret) + } + + pub fn mock() -> Self { + todo!("see auth-scram.c : mock_scram_secret") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_scram_secret() { + let iterations = 4096; + let salt = "+/tQQax7twvwTj64mjBsxQ=="; + let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns="; + let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI="; + + let secret = format!( + "SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}", + iterations = iterations, + salt = salt, + stored_key = stored_key, + server_key = server_key, + ); + + let parsed = ScramSecret::parse(&secret).unwrap(); + assert_eq!(parsed.iterations, iterations); + assert_eq!(parsed.salt_base64, salt); + + // TODO: derive from 'password' + assert_eq!(base64::encode(parsed.stored_key), stored_key); + assert_eq!(base64::encode(parsed.server_key), server_key); + } +} diff --git a/proxy/src/scram/signature.rs b/proxy/src/scram/signature.rs new file mode 100644 index 0000000000..6be5d8d118 --- /dev/null +++ b/proxy/src/scram/signature.rs @@ -0,0 +1,70 @@ +//! Tools for client/server signature management. + +use super::key::{ScramKey, SCRAM_KEY_LEN}; +use hmac::{Hmac, Mac, NewMac}; +use sha2::Sha256; + +#[derive(Debug)] +pub struct SignatureBuilder<'a> { + pub client_first_message_bare: &'a str, + pub server_first_message: &'a str, + pub client_final_message_without_proof: &'a str, +} + +impl SignatureBuilder<'_> { + pub fn build(&self, key: &ScramKey) -> Signature { + let mut mac = Hmac::::new_varkey(key.as_ref()).expect("bad key size"); + + mac.update(self.client_first_message_bare.as_bytes()); + mac.update(b","); + mac.update(self.server_first_message.as_bytes()); + mac.update(b","); + mac.update(self.client_final_message_without_proof.as_bytes()); + + // TODO: maybe newer `hmac` et al already migrated to regular arrays? + let mut signature = [0u8; SCRAM_KEY_LEN]; + signature.copy_from_slice(mac.finalize().into_bytes().as_slice()); + signature.into() + } +} + +#[derive(Debug)] +#[repr(transparent)] +pub struct Signature { + bytes: [u8; SCRAM_KEY_LEN], +} + +impl Signature { + /// Derive ClientKey from client's signature and proof + pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey { + let signature = self.as_ref().iter(); + + // This is how the proof is calculated: + // + // 1. sha256(ClientKey) -> StoredKey + // 2. hmac_sha256(StoredKey, [messages...]) -> ClientSignature + // 3. ClientKey ^ ClientSignature -> ClientProof + // + // Step 3 implies that we can restore ClientKey from the proof + // by xoring the latter with the ClientSignature again. Afterwards + // we can check that the presumed ClientKey meets our expectations. + let mut bytes = [0u8; SCRAM_KEY_LEN]; + for (i, value) in signature.zip(proof).map(|(x, y)| x ^ y).enumerate() { + bytes[i] = value; + } + + bytes.into() + } +} + +impl From<[u8; SCRAM_KEY_LEN]> for Signature { + fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self { + Self { bytes } + } +} + +impl AsRef<[u8]> for Signature { + fn as_ref(&self) -> &[u8] { + &self.bytes + } +} diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index d55ead93bc..6ee8fbedfb 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -373,9 +373,8 @@ impl PostgresBackend { } AuthType::MD5 => { rand::thread_rng().fill(&mut self.md5_salt); - let md5_salt = self.md5_salt; self.write_message(&BeMessage::AuthenticationMD5Password( - &md5_salt, + self.md5_salt, ))?; self.state = ProtoState::Authentication; } diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index 3ad4f41ee2..685e7df855 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -353,7 +353,8 @@ fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result { #[derive(Debug)] pub enum BeMessage<'a> { AuthenticationOk, - AuthenticationMD5Password(&'a [u8; 4]), + AuthenticationMD5Password([u8; 4]), + AuthenticationSasl(BeAuthenticationSaslMessage<'a>), AuthenticationCleartextPassword, BackendKeyData(CancelKeyData), BindComplete, @@ -381,6 +382,13 @@ pub enum BeMessage<'a> { KeepAlive(WalSndKeepAlive), } +#[derive(Debug)] +pub enum BeAuthenticationSaslMessage<'a> { + Methods(&'a [&'a str]), + Continue(&'a [u8]), + Final(&'a [u8]), +} + #[derive(Debug)] pub enum BeParameterStatusMessage<'a> { Encoding(&'a str), @@ -552,6 +560,32 @@ impl<'a> BeMessage<'a> { .unwrap(); // write into BytesMut can't fail } + BeMessage::AuthenticationSasl(msg) => { + buf.put_u8(b'R'); + write_body(buf, |buf| { + use BeAuthenticationSaslMessage::*; + match msg { + Methods(methods) => { + buf.put_i32(10); // Specifies that SASL auth method is used. + for method in methods.iter() { + write_cstr(method.as_bytes(), buf)?; + } + buf.put_u8(0); // zero terminator for the list + } + Continue(extra) => { + buf.put_i32(11); // Continue SASL auth. + buf.put_slice(extra); + } + Final(extra) => { + buf.put_i32(12); // Send final SASL message. + buf.put_slice(extra); + } + } + Ok::<_, io::Error>(()) + }) + .unwrap() + } + BeMessage::BackendKeyData(key_data) => { buf.put_u8(b'K'); write_body(buf, |buf| {