diff --git a/Cargo.lock b/Cargo.lock index 1a9e261281..7df1c4ab7a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1907,12 +1907,15 @@ name = "proxy" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", + "base64 0.13.0", "bytes", "clap 3.0.14", "fail", "futures", "hashbrown", "hex", + "hmac 0.10.1", "hyper", "lazy_static", "md5", @@ -1921,16 +1924,20 @@ dependencies = [ "rand", "rcgen", "reqwest", + "routerify 2.2.0", + "rstest", "rustls 0.19.1", "scopeguard", "serde", "serde_json", + "sha2", "socket2", "thiserror", "tokio", "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "tokio-postgres-rustls", "tokio-rustls 0.22.0", + "tokio-stream", "workspace_hack", "zenith_metrics", "zenith_utils", @@ -2130,6 +2137,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "routerify" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c6bb49594c791cadb5ccfa5f36d41b498d40482595c199d10cd318800280bd9" +dependencies = [ + "http", + "hyper", + "lazy_static", + "percent-encoding", + "regex", +] + [[package]] name = "routerify" version = "3.0.0" @@ -2143,6 +2163,19 @@ dependencies = [ "regex", ] +[[package]] +name = "rstest" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d912f35156a3f99a66ee3e11ac2e0b3f34ac85a07e05263d05a7e2c8810d616f" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "rustc_version", + "syn", +] + [[package]] name = "rusoto_core" version = "0.47.0" @@ -3450,7 +3483,7 @@ dependencies = [ "postgres 0.19.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "rand", - "routerify", + "routerify 3.0.0", "rustls 0.19.1", "rustls-split", "serde", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index dc20695884..56b6dd7e20 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -5,12 +5,14 @@ edition = "2021" [dependencies] anyhow = "1.0" +base64 = "0.13.0" bytes = { version = "1.0.1", features = ['serde'] } clap = "3.0" fail = "0.5.0" futures = "0.3.13" hashbrown = "0.11.2" hex = "0.4.3" +hmac = "0.10.1" hyper = "0.14" lazy_static = "1.4.0" md5 = "0.7.0" @@ -18,20 +20,25 @@ parking_lot = "0.11.2" pin-project-lite = "0.2.7" rand = "0.8.3" reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] } +routerify = "2" rustls = "0.19.1" scopeguard = "1.1.0" serde = "1" serde_json = "1" +sha2 = "0.9.8" socket2 = "0.4.4" -thiserror = "1.0" +thiserror = "1.0.30" tokio = { version = "1.17", features = ["macros"] } tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" } tokio-rustls = "0.22.0" +tokio-stream = "0.1.8" zenith_utils = { path = "../zenith_utils" } zenith_metrics = { path = "../zenith_metrics" } workspace_hack = { version = "0.1", path = "../workspace_hack" } [dev-dependencies] -tokio-postgres-rustls = "0.8.0" +async-trait = "0.1" rcgen = "0.8.14" +rstest = "0.12" +tokio-postgres-rustls = "0.8.0" diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index e8fe65c081..bda14d67a1 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,14 +1,24 @@ +mod credentials; + +#[cfg(test)] +mod flow; + use crate::compute::DatabaseInfo; use crate::config::ProxyConfig; use crate::cplane_api::{self, CPlaneApi}; use crate::error::UserFacingError; use crate::stream::PqStream; use crate::waiters; -use std::collections::HashMap; +use std::io; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; +pub use credentials::ClientCredentials; + +#[cfg(test)] +pub use flow::*; + /// Common authentication error. #[derive(Debug, Error)] pub enum AuthErrorImpl { @@ -16,13 +26,17 @@ pub enum AuthErrorImpl { #[error(transparent)] Console(#[from] cplane_api::AuthError), + #[cfg(test)] + #[error(transparent)] + Sasl(#[from] crate::sasl::Error), + /// For passwords that couldn't be processed by [`parse_password`]. #[error("Malformed password message")] MalformedPassword, /// Errors produced by [`PqStream`]. #[error(transparent)] - Io(#[from] std::io::Error), + Io(#[from] io::Error), } impl AuthErrorImpl { @@ -67,70 +81,6 @@ impl UserFacingError for AuthError { } } -#[derive(Debug, Error)] -pub enum ClientCredsParseError { - #[error("Parameter `{0}` is missing in startup packet")] - MissingKey(&'static str), -} - -impl UserFacingError for ClientCredsParseError {} - -/// Various client credentials which we use for authentication. -#[derive(Debug, PartialEq, Eq)] -pub struct ClientCredentials { - pub user: String, - pub dbname: String, -} - -impl TryFrom> for ClientCredentials { - type Error = ClientCredsParseError; - - fn try_from(mut value: HashMap) -> Result { - let mut get_param = |key| { - value - .remove(key) - .ok_or(ClientCredsParseError::MissingKey(key)) - }; - - let user = get_param("user")?; - let db = get_param("database")?; - - Ok(Self { user, dbname: db }) - } -} - -impl ClientCredentials { - /// Use credentials to authenticate the user. - pub async fn authenticate( - self, - config: &ProxyConfig, - client: &mut PqStream, - ) -> Result { - fail::fail_point!("proxy-authenticate", |_| { - Err(AuthError::auth_failed("failpoint triggered")) - }); - - use crate::config::ClientAuthMethod::*; - use crate::config::RouterConfig::*; - match &config.router_config { - Static { host, port } => handle_static(host.clone(), *port, client, self).await, - Dynamic(Mixed) => { - if self.user.ends_with("@zenith") { - handle_existing_user(config, client, self).await - } else { - handle_new_user(config, client).await - } - } - Dynamic(Password) => handle_existing_user(config, client, self).await, - Dynamic(Link) => handle_new_user(config, client).await, - } - } -} - -fn new_psql_session_id() -> String { - hex::encode(rand::random::<[u8; 8]>()) -} - async fn handle_static( host: String, port: u16, @@ -169,7 +119,7 @@ async fn handle_existing_user( let md5_salt = rand::random(); client - .write_message(&Be::AuthenticationMD5Password(&md5_salt)) + .write_message(&Be::AuthenticationMD5Password(md5_salt)) .await?; // Read client's password hash @@ -213,6 +163,10 @@ async fn handle_new_user( Ok(db_info) } +fn new_psql_session_id() -> String { + hex::encode(rand::random::<[u8; 8]>()) +} + fn parse_password(bytes: &[u8]) -> Option<&str> { std::str::from_utf8(bytes).ok()?.strip_suffix('\0') } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs new file mode 100644 index 0000000000..7c8ba28622 --- /dev/null +++ b/proxy/src/auth/credentials.rs @@ -0,0 +1,70 @@ +//! User credentials used in authentication. + +use super::AuthError; +use crate::compute::DatabaseInfo; +use crate::config::ProxyConfig; +use crate::error::UserFacingError; +use crate::stream::PqStream; +use std::collections::HashMap; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite}; + +#[derive(Debug, Error)] +pub enum ClientCredsParseError { + #[error("Parameter `{0}` is missing in startup packet")] + MissingKey(&'static str), +} + +impl UserFacingError for ClientCredsParseError {} + +/// Various client credentials which we use for authentication. +#[derive(Debug, PartialEq, Eq)] +pub struct ClientCredentials { + pub user: String, + pub dbname: String, +} + +impl TryFrom> for ClientCredentials { + type Error = ClientCredsParseError; + + fn try_from(mut value: HashMap) -> Result { + let mut get_param = |key| { + value + .remove(key) + .ok_or(ClientCredsParseError::MissingKey(key)) + }; + + let user = get_param("user")?; + let db = get_param("database")?; + + Ok(Self { user, dbname: db }) + } +} + +impl ClientCredentials { + /// Use credentials to authenticate the user. + pub async fn authenticate( + self, + config: &ProxyConfig, + client: &mut PqStream, + ) -> Result { + fail::fail_point!("proxy-authenticate", |_| { + Err(AuthError::auth_failed("failpoint triggered")) + }); + + use crate::config::ClientAuthMethod::*; + use crate::config::RouterConfig::*; + match &config.router_config { + Static { host, port } => super::handle_static(host.clone(), *port, client, self).await, + Dynamic(Mixed) => { + if self.user.ends_with("@zenith") { + super::handle_existing_user(config, client, self).await + } else { + super::handle_new_user(config, client).await + } + } + Dynamic(Password) => super::handle_existing_user(config, client, self).await, + Dynamic(Link) => super::handle_new_user(config, client).await, + } + } +} diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs new file mode 100644 index 0000000000..0fafaa2f47 --- /dev/null +++ b/proxy/src/auth/flow.rs @@ -0,0 +1,102 @@ +//! Main authentication flow. + +use super::{AuthError, AuthErrorImpl}; +use crate::stream::PqStream; +use crate::{sasl, scram}; +use std::io; +use tokio::io::{AsyncRead, AsyncWrite}; +use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; + +/// Every authentication selector is supposed to implement this trait. +pub trait AuthMethod { + /// Any authentication selector should provide initial backend message + /// containing auth method name and parameters, e.g. md5 salt. + fn first_message(&self) -> BeMessage<'_>; +} + +/// Initial state of [`AuthFlow`]. +pub struct Begin; + +/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. +pub struct Scram<'a>(pub &'a scram::ServerSecret); + +impl AuthMethod for Scram<'_> { + #[inline(always)] + fn first_message(&self) -> BeMessage<'_> { + Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + } +} + +/// Use password-based auth in [`AuthFlow`]. +pub struct Md5( + /// Salt for client. + pub [u8; 4], +); + +impl AuthMethod for Md5 { + #[inline(always)] + fn first_message(&self) -> BeMessage<'_> { + Be::AuthenticationMD5Password(self.0) + } +} + +/// This wrapper for [`PqStream`] performs client authentication. +#[must_use] +pub struct AuthFlow<'a, Stream, State> { + /// The underlying stream which implements libpq's protocol. + stream: &'a mut PqStream, + /// State might contain ancillary data (see [`AuthFlow::begin`]). + state: State, +} + +/// Initial state of the stream wrapper. +impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { + /// Create a new wrapper for client authentication. + pub fn new(stream: &'a mut PqStream) -> Self { + Self { + stream, + state: Begin, + } + } + + /// Move to the next step by sending auth method's name & params to client. + pub async fn begin(self, method: M) -> io::Result> { + self.stream.write_message(&method.first_message()).await?; + + Ok(AuthFlow { + stream: self.stream, + state: method, + }) + } +} + +/// Stream wrapper for handling simple MD5 password auth. +impl AuthFlow<'_, S, Md5> { + /// Perform user authentication. Raise an error in case authentication failed. + #[allow(unused)] + pub async fn authenticate(self) -> Result<(), AuthError> { + unimplemented!("MD5 auth flow is yet to be implemented"); + } +} + +/// Stream wrapper for handling [SCRAM](crate::scram) auth. +impl AuthFlow<'_, S, Scram<'_>> { + /// Perform user authentication. Raise an error in case authentication failed. + pub async fn authenticate(self) -> Result<(), AuthError> { + // Initial client message contains the chosen auth method's name. + let msg = self.stream.read_password_message().await?; + let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?; + + // Currently, the only supported SASL method is SCRAM. + if !scram::METHODS.contains(&sasl.method) { + return Err(AuthErrorImpl::auth_failed("method not supported").into()); + } + + let secret = self.state.0; + sasl::SaslStream::new(self.stream, sasl.message) + .authenticate(scram::Exchange::new(secret, rand::random, None)) + .await?; + + Ok(()) + } +} diff --git a/proxy/src/main.rs b/proxy/src/main.rs index bd99d0a639..862152bb7b 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -1,19 +1,8 @@ -/// -/// Postgres protocol proxy/router. -/// -/// This service listens psql port and can check auth via external service -/// (control plane API in our case) and can create new databases and accounts -/// in somewhat transparent manner (again via communication with control plane API). -/// -use anyhow::{bail, Context}; -use clap::{App, Arg}; -use config::ProxyConfig; -use futures::FutureExt; -use std::future::Future; -use tokio::{net::TcpListener, task::JoinError}; -use zenith_utils::GIT_VERSION; - -use crate::config::{ClientAuthMethod, RouterConfig}; +//! Postgres protocol proxy/router. +//! +//! This service listens psql port and can check auth via external service +//! (control plane API in our case) and can create new databases and accounts +//! in somewhat transparent manner (again via communication with control plane API). mod auth; mod cancellation; @@ -27,6 +16,24 @@ mod proxy; mod stream; mod waiters; +// Currently SCRAM is only used in tests +#[cfg(test)] +mod parse; +#[cfg(test)] +mod sasl; +#[cfg(test)] +mod scram; + +use anyhow::{bail, Context}; +use clap::{App, Arg}; +use config::ProxyConfig; +use futures::FutureExt; +use std::future::Future; +use tokio::{net::TcpListener, task::JoinError}; +use zenith_utils::GIT_VERSION; + +use crate::config::{ClientAuthMethod, RouterConfig}; + /// Flattens `Result>` into `Result`. async fn flatten_err( f: impl Future, JoinError>>, diff --git a/proxy/src/parse.rs b/proxy/src/parse.rs new file mode 100644 index 0000000000..8a05ff9c82 --- /dev/null +++ b/proxy/src/parse.rs @@ -0,0 +1,18 @@ +//! Small parsing helpers. + +use std::convert::TryInto; +use std::ffi::CStr; + +pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> { + let pos = bytes.iter().position(|&x| x == 0)?; + let (cstr, other) = bytes.split_at(pos + 1); + // SAFETY: we've already checked that there's a terminator + Some((unsafe { CStr::from_bytes_with_nul_unchecked(cstr) }, other)) +} + +pub fn split_at_const(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> { + (bytes.len() >= N).then(|| { + let (head, tail) = bytes.split_at(N); + (head.try_into().unwrap(), tail) + }) +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 81581b5cf1..5b662f4c69 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -119,7 +119,6 @@ async fn handshake( // We can't perform TLS handshake without a config let enc = tls.is_some(); stream.write_message(&Be::EncryptionResponse(enc)).await?; - if let Some(tls) = tls.take() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. @@ -219,32 +218,14 @@ impl Client { #[cfg(test)] mod tests { use super::*; - - use tokio::io::DuplexStream; + use crate::{auth, scram}; + use async_trait::async_trait; + use rstest::rstest; use tokio_postgres::config::SslMode; use tokio_postgres::tls::{MakeTlsConnect, NoTls}; use tokio_postgres_rustls::MakeRustlsConnect; - async fn dummy_proxy( - client: impl AsyncRead + AsyncWrite + Unpin, - tls: Option, - ) -> anyhow::Result<()> { - let cancel_map = CancelMap::default(); - - // TODO: add some infra + tests for credentials - let (mut stream, _creds) = handshake(client, tls, &cancel_map) - .await? - .context("no stream")?; - - stream - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())? - .write_message(&BeMessage::ReadyForQuery) - .await?; - - Ok(()) - } - + /// Generate a set of TLS certificates: CA + server. fn generate_certs( hostname: &str, ) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> { @@ -262,19 +243,115 @@ mod tests { )) } + struct ClientConfig<'a> { + config: rustls::ClientConfig, + hostname: &'a str, + } + + impl ClientConfig<'_> { + fn make_tls_connect( + self, + ) -> anyhow::Result> { + let mut mk = MakeRustlsConnect::new(self.config); + let tls = MakeTlsConnect::::make_tls_connect(&mut mk, self.hostname)?; + Ok(tls) + } + } + + /// Generate TLS certificates and build rustls configs for client and server. + fn generate_tls_config( + hostname: &str, + ) -> anyhow::Result<(ClientConfig<'_>, Arc)> { + let (ca, cert, key) = generate_certs(hostname)?; + + let server_config = { + let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(vec![cert], key)?; + config.into() + }; + + let client_config = { + let mut config = rustls::ClientConfig::new(); + config.root_store.add(&ca)?; + ClientConfig { config, hostname } + }; + + Ok((client_config, server_config)) + } + + #[async_trait] + trait TestAuth: Sized { + async fn authenticate( + self, + _stream: &mut PqStream>, + ) -> anyhow::Result<()> { + Ok(()) + } + } + + struct NoAuth; + impl TestAuth for NoAuth {} + + struct Scram(scram::ServerSecret); + + impl Scram { + fn new(password: &str) -> anyhow::Result { + let salt = rand::random::<[u8; 16]>(); + let secret = scram::ServerSecret::build(password, &salt, 256) + .context("failed to generate scram secret")?; + Ok(Scram(secret)) + } + + fn mock(user: &str) -> Self { + let salt = rand::random::<[u8; 32]>(); + Scram(scram::ServerSecret::mock(user, &salt)) + } + } + + #[async_trait] + impl TestAuth for Scram { + async fn authenticate( + self, + stream: &mut PqStream>, + ) -> anyhow::Result<()> { + auth::AuthFlow::new(stream) + .begin(auth::Scram(&self.0)) + .await? + .authenticate() + .await?; + + Ok(()) + } + } + + /// A dummy proxy impl which performs a handshake and reports auth success. + async fn dummy_proxy( + client: impl AsyncRead + AsyncWrite + Unpin + Send, + tls: Option, + auth: impl TestAuth + Send, + ) -> anyhow::Result<()> { + let cancel_map = CancelMap::default(); + let (mut stream, _creds) = handshake(client, tls, &cancel_map) + .await? + .context("handshake failed")?; + + auth.authenticate(&mut stream).await?; + + stream + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())? + .write_message(&BeMessage::ReadyForQuery) + .await?; + + Ok(()) + } + #[tokio::test] async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let server_config = { - let (_ca, cert, key) = generate_certs("localhost")?; - - let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(vec![cert], key)?; - config - }; - - let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into()))); + let (_, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); let client_err = tokio_postgres::Config::new() .user("john_doe") @@ -301,30 +378,14 @@ mod tests { async fn handshake_tls() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let (ca, cert, key) = generate_certs("localhost")?; - - let server_config = { - let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(vec![cert], key)?; - config - }; - - let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into()))); - - let client_config = { - let mut config = rustls::ClientConfig::new(); - config.root_store.add(&ca)?; - config - }; - - let mut mk = MakeRustlsConnect::new(client_config); - let tls = MakeTlsConnect::::make_tls_connect(&mut mk, "localhost")?; + let (client_config, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); let (_client, _conn) = tokio_postgres::Config::new() .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Require) - .connect_raw(server, tls) + .connect_raw(server, client_config.make_tls_connect()?) .await?; proxy.await? @@ -334,7 +395,7 @@ mod tests { async fn handshake_raw() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let proxy = tokio::spawn(dummy_proxy(client, None)); + let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth)); let (_client, _conn) = tokio_postgres::Config::new() .user("john_doe") @@ -350,7 +411,7 @@ mod tests { async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let proxy = tokio::spawn(dummy_proxy(client, None)); + let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth)); let client_err = tokio_postgres::Config::new() .ssl_mode(SslMode::Disable) @@ -391,4 +452,66 @@ mod tests { Ok(()) } + + #[rstest] + #[case("password_foo")] + #[case("pwd-bar")] + #[case("")] + #[tokio::test] + async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { + let (client, server) = tokio::io::duplex(1024); + + let (client_config, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy( + client, + Some(server_config), + Scram::new(password)?, + )); + + let (_client, _conn) = tokio_postgres::Config::new() + .user("user") + .dbname("db") + .password(password) + .ssl_mode(SslMode::Require) + .connect_raw(server, client_config.make_tls_connect()?) + .await?; + + proxy.await? + } + + #[tokio::test] + async fn scram_auth_mock() -> anyhow::Result<()> { + let (client, server) = tokio::io::duplex(1024); + + let (client_config, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy( + client, + Some(server_config), + Scram::mock("user"), + )); + + use rand::{distributions::Alphanumeric, Rng}; + let password: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(rand::random::() as usize) + .map(char::from) + .collect(); + + let _client_err = tokio_postgres::Config::new() + .user("user") + .dbname("db") + .password(&password) // no password will match the mocked secret + .ssl_mode(SslMode::Require) + .connect_raw(server, client_config.make_tls_connect()?) + .await + .err() // -> Option + .context("client shouldn't be able to connect")?; + + let _server_err = proxy + .await? + .err() // -> Option + .context("server shouldn't accept client")?; + + Ok(()) + } } diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs new file mode 100644 index 0000000000..70a4d9946a --- /dev/null +++ b/proxy/src/sasl.rs @@ -0,0 +1,47 @@ +//! Simple Authentication and Security Layer. +//! +//! RFC: . +//! +//! Reference implementation: +//! * +//! * + +mod channel_binding; +mod messages; +mod stream; + +use std::io; +use thiserror::Error; + +pub use channel_binding::ChannelBinding; +pub use messages::FirstMessage; +pub use stream::SaslStream; + +/// Fine-grained auth errors help in writing tests. +#[derive(Error, Debug)] +pub enum Error { + #[error("Failed to authenticate client: {0}")] + AuthenticationFailed(&'static str), + + #[error("Channel binding failed: {0}")] + ChannelBindingFailed(&'static str), + + #[error("Unsupported channel binding method: {0}")] + ChannelBindingBadMethod(Box), + + #[error("Bad client message")] + BadClientMessage, + + #[error(transparent)] + Io(#[from] io::Error), +} + +/// A convenient result type for SASL exchange. +pub type Result = std::result::Result; + +/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait. +pub trait Mechanism: Sized { + /// Produce a server challenge to be sent to the client. + /// This is how this method is called in PostgreSQL (`libpq/sasl.h`). + fn exchange(self, input: &str) -> Result<(Option, String)>; +} diff --git a/proxy/src/sasl/channel_binding.rs b/proxy/src/sasl/channel_binding.rs new file mode 100644 index 0000000000..776adabe55 --- /dev/null +++ b/proxy/src/sasl/channel_binding.rs @@ -0,0 +1,85 @@ +//! Definition and parser for channel binding flag (a part of the `GS2` header). + +/// Channel binding flag (possibly with params). +#[derive(Debug, PartialEq, Eq)] +pub enum ChannelBinding { + /// Client doesn't support channel binding. + NotSupportedClient, + /// Client thinks server doesn't support channel binding. + NotSupportedServer, + /// Client wants to use this type of channel binding. + Required(T), +} + +impl ChannelBinding { + pub fn and_then(self, f: impl FnOnce(T) -> Result) -> Result, E> { + use ChannelBinding::*; + Ok(match self { + NotSupportedClient => NotSupportedClient, + NotSupportedServer => NotSupportedServer, + Required(x) => Required(f(x)?), + }) + } +} + +impl<'a> ChannelBinding<&'a str> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + use ChannelBinding::*; + Some(match input { + "n" => NotSupportedClient, + "y" => NotSupportedServer, + other => Required(other.strip_prefix("p=")?), + }) + } +} + +impl ChannelBinding { + /// Encode channel binding data as base64 for subsequent checks. + pub fn encode( + &self, + get_cbind_data: impl FnOnce(&T) -> Result, + ) -> Result, E> { + use ChannelBinding::*; + Ok(match self { + NotSupportedClient => { + // base64::encode("n,,") + "biws".into() + } + NotSupportedServer => { + // base64::encode("y,,") + "eSws".into() + } + Required(mode) => { + let msg = format!( + "p={mode},,{data}", + mode = mode, + data = get_cbind_data(mode)? + ); + base64::encode(msg).into() + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn channel_binding_encode() -> anyhow::Result<()> { + use ChannelBinding::*; + + let cases = [ + (NotSupportedClient, base64::encode("n,,")), + (NotSupportedServer, base64::encode("y,,")), + (Required("foo"), base64::encode("p=foo,,bar")), + ]; + + for (cb, input) in cases { + assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input); + } + + Ok(()) + } +} diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs new file mode 100644 index 0000000000..b1ae8cc426 --- /dev/null +++ b/proxy/src/sasl/messages.rs @@ -0,0 +1,67 @@ +//! Definitions for SASL messages. + +use crate::parse::{split_at_const, split_cstr}; +use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage}; + +/// SASL-specific payload of [`PasswordMessage`](zenith_utils::pq_proto::FeMessage::PasswordMessage). +#[derive(Debug)] +pub struct FirstMessage<'a> { + /// Authentication method, e.g. `"SCRAM-SHA-256"`. + pub method: &'a str, + /// Initial client message. + pub message: &'a str, +} + +impl<'a> FirstMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(bytes: &'a [u8]) -> Option { + let (method_cstr, tail) = split_cstr(bytes)?; + let method = method_cstr.to_str().ok()?; + + let (len_bytes, bytes) = split_at_const(tail)?; + let len = u32::from_be_bytes(*len_bytes) as usize; + if len != bytes.len() { + return None; + } + + let message = std::str::from_utf8(bytes).ok()?; + Some(Self { method, message }) + } +} + +/// A single SASL message. +/// This struct is deliberately decoupled from lower-level +/// [`BeAuthenticationSaslMessage`](zenith_utils::pq_proto::BeAuthenticationSaslMessage). +#[derive(Debug)] +pub(super) enum ServerMessage { + /// We expect to see more steps. + Continue(T), + /// This is the final step. + Final(T), +} + +impl<'a> ServerMessage<&'a str> { + pub(super) fn to_reply(&self) -> BeMessage<'a> { + use BeAuthenticationSaslMessage::*; + BeMessage::AuthenticationSasl(match self { + ServerMessage::Continue(s) => Continue(s.as_bytes()), + ServerMessage::Final(s) => Final(s.as_bytes()), + }) + } +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_sasl_first_message() { + let proto = "SCRAM-SHA-256"; + let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4"; + let sasl_len = (sasl.len() as u32).to_be_bytes(); + let bytes = [proto.as_bytes(), &[0], sasl_len.as_ref(), sasl.as_bytes()].concat(); + + let password = FirstMessage::parse(&bytes).unwrap(); + assert_eq!(password.method, proto); + assert_eq!(password.message, sasl); + } +} diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs new file mode 100644 index 0000000000..03649b8d11 --- /dev/null +++ b/proxy/src/sasl/stream.rs @@ -0,0 +1,70 @@ +//! Abstraction for the string-oriented SASL protocols. + +use super::{messages::ServerMessage, Mechanism}; +use crate::stream::PqStream; +use std::io; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// Abstracts away all peculiarities of the libpq's protocol. +pub struct SaslStream<'a, S> { + /// The underlying stream. + stream: &'a mut PqStream, + /// Current password message we received from client. + current: bytes::Bytes, + /// First SASL message produced by client. + first: Option<&'a str>, +} + +impl<'a, S> SaslStream<'a, S> { + pub fn new(stream: &'a mut PqStream, first: &'a str) -> Self { + Self { + stream, + current: bytes::Bytes::new(), + first: Some(first), + } + } +} + +impl SaslStream<'_, S> { + // Receive a new SASL message from the client. + async fn recv(&mut self) -> io::Result<&str> { + if let Some(first) = self.first.take() { + return Ok(first); + } + + self.current = self.stream.read_password_message().await?; + let s = std::str::from_utf8(&self.current) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + + Ok(s) + } +} + +impl SaslStream<'_, S> { + // Send a SASL message to the client. + async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { + self.stream.write_message(&msg.to_reply()).await?; + Ok(()) + } +} + +impl SaslStream<'_, S> { + /// Perform SASL message exchange according to the underlying algorithm + /// until user is either authenticated or denied access. + pub async fn authenticate(mut self, mut mechanism: impl Mechanism) -> super::Result<()> { + loop { + let input = self.recv().await?; + let (moved, reply) = mechanism.exchange(input)?; + match moved { + Some(moved) => { + self.send(&ServerMessage::Continue(&reply)).await?; + mechanism = moved; + } + None => { + self.send(&ServerMessage::Final(&reply)).await?; + return Ok(()); + } + } + } + } +} diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs new file mode 100644 index 0000000000..f007f3e0b6 --- /dev/null +++ b/proxy/src/scram.rs @@ -0,0 +1,59 @@ +//! Salted Challenge Response Authentication Mechanism. +//! +//! RFC: . +//! +//! Reference implementation: +//! * +//! * + +mod exchange; +mod key; +mod messages; +mod password; +mod secret; +mod signature; + +pub use secret::*; + +pub use exchange::Exchange; +pub use secret::ServerSecret; + +use hmac::{Hmac, Mac, NewMac}; +use sha2::{Digest, Sha256}; + +// TODO: add SCRAM-SHA-256-PLUS +/// A list of supported SCRAM methods. +pub const METHODS: &[&str] = &["SCRAM-SHA-256"]; + +/// Decode base64 into array without any heap allocations +fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N]> { + let mut bytes = [0u8; N]; + + let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?; + if size != N { + return None; + } + + Some(bytes) +} + +/// This function essentially is `Hmac(sha256, key, input)`. +/// Further reading: . +fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator) -> [u8; 32] { + let mut mac = Hmac::::new_varkey(key).expect("bad key size"); + parts.into_iter().for_each(|s| mac.update(s)); + + // TODO: maybe newer `hmac` et al already migrated to regular arrays? + let mut result = [0u8; 32]; + result.copy_from_slice(mac.finalize().into_bytes().as_slice()); + result +} + +fn sha256<'a>(parts: impl IntoIterator) -> [u8; 32] { + let mut hasher = Sha256::new(); + parts.into_iter().for_each(|s| hasher.update(s)); + + let mut result = [0u8; 32]; + result.copy_from_slice(hasher.finalize().as_slice()); + result +} diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs new file mode 100644 index 0000000000..5a986b965a --- /dev/null +++ b/proxy/src/scram/exchange.rs @@ -0,0 +1,134 @@ +//! Implementation of the SCRAM authentication algorithm. + +use super::messages::{ + ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN, +}; +use super::secret::ServerSecret; +use super::signature::SignatureBuilder; +use crate::sasl::{self, ChannelBinding, Error as SaslError}; + +/// The only channel binding mode we currently support. +#[derive(Debug)] +struct TlsServerEndPoint; + +impl std::fmt::Display for TlsServerEndPoint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "tls-server-end-point") + } +} + +impl std::str::FromStr for TlsServerEndPoint { + type Err = sasl::Error; + + fn from_str(s: &str) -> Result { + match s { + "tls-server-end-point" => Ok(TlsServerEndPoint), + _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())), + } + } +} + +#[derive(Debug)] +enum ExchangeState { + /// Waiting for [`ClientFirstMessage`]. + Initial, + /// Waiting for [`ClientFinalMessage`]. + SaltSent { + cbind_flag: ChannelBinding, + client_first_message_bare: String, + server_first_message: OwnedServerFirstMessage, + }, +} + +/// Server's side of SCRAM auth algorithm. +#[derive(Debug)] +pub struct Exchange<'a> { + state: ExchangeState, + secret: &'a ServerSecret, + nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], + cert_digest: Option<&'a [u8]>, +} + +impl<'a> Exchange<'a> { + pub fn new( + secret: &'a ServerSecret, + nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], + cert_digest: Option<&'a [u8]>, + ) -> Self { + Self { + state: ExchangeState::Initial, + secret, + nonce, + cert_digest, + } + } +} + +impl sasl::Mechanism for Exchange<'_> { + fn exchange(mut self, input: &str) -> sasl::Result<(Option, String)> { + use ExchangeState::*; + match &self.state { + Initial => { + let client_first_message = + ClientFirstMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + + let server_first_message = client_first_message.build_server_first_message( + &(self.nonce)(), + &self.secret.salt_base64, + self.secret.iterations, + ); + let msg = server_first_message.as_str().to_owned(); + + self.state = SaltSent { + cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?, + client_first_message_bare: client_first_message.bare.to_owned(), + server_first_message, + }; + + Ok((Some(self), msg)) + } + SaltSent { + cbind_flag, + client_first_message_bare, + server_first_message, + } => { + let client_final_message = + ClientFinalMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + + let channel_binding = cbind_flag.encode(|_| { + self.cert_digest + .map(base64::encode) + .ok_or(SaslError::ChannelBindingFailed("no cert digest provided")) + })?; + + // This might've been caused by a MITM attack + if client_final_message.channel_binding != channel_binding { + return Err(SaslError::ChannelBindingFailed("data mismatch")); + } + + if client_final_message.nonce != server_first_message.nonce() { + return Err(SaslError::AuthenticationFailed("bad nonce")); + } + + let signature_builder = SignatureBuilder { + client_first_message_bare, + server_first_message: server_first_message.as_str(), + client_final_message_without_proof: client_final_message.without_proof, + }; + + let client_key = signature_builder + .build(&self.secret.stored_key) + .derive_client_key(&client_final_message.proof); + + if client_key.sha256() != self.secret.stored_key { + return Err(SaslError::AuthenticationFailed("keys don't match")); + } + + let msg = client_final_message + .build_server_final_message(signature_builder, &self.secret.server_key); + + Ok((None, msg)) + } + } + } +} diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs new file mode 100644 index 0000000000..1c13471bc3 --- /dev/null +++ b/proxy/src/scram/key.rs @@ -0,0 +1,33 @@ +//! Tools for client/server/stored key management. + +/// Faithfully taken from PostgreSQL. +pub const SCRAM_KEY_LEN: usize = 32; + +/// One of the keys derived from the [password](super::password::SaltedPassword). +/// We use the same structure for all keys, i.e. +/// `ClientKey`, `StoredKey`, and `ServerKey`. +#[derive(Default, Debug, PartialEq, Eq)] +#[repr(transparent)] +pub struct ScramKey { + bytes: [u8; SCRAM_KEY_LEN], +} + +impl ScramKey { + pub fn sha256(&self) -> Self { + super::sha256([self.as_ref()]).into() + } +} + +impl From<[u8; SCRAM_KEY_LEN]> for ScramKey { + #[inline(always)] + fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self { + Self { bytes } + } +} + +impl AsRef<[u8]> for ScramKey { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + &self.bytes + } +} diff --git a/proxy/src/scram/messages.rs b/proxy/src/scram/messages.rs new file mode 100644 index 0000000000..f6e6133adf --- /dev/null +++ b/proxy/src/scram/messages.rs @@ -0,0 +1,232 @@ +//! Definitions for SCRAM messages. + +use super::base64_decode_array; +use super::key::{ScramKey, SCRAM_KEY_LEN}; +use super::signature::SignatureBuilder; +use crate::sasl::ChannelBinding; +use std::fmt; +use std::ops::Range; + +/// Faithfully taken from PostgreSQL. +pub const SCRAM_RAW_NONCE_LEN: usize = 18; + +/// Although we ignore all extensions, we still have to validate the message. +fn validate_sasl_extensions<'a>(parts: impl Iterator) -> Option<()> { + for mut chars in parts.map(|s| s.chars()) { + let attr = chars.next()?; + if !('a'..'z').contains(&attr) && !('A'..'Z').contains(&attr) { + return None; + } + let eq = chars.next()?; + if eq != '=' { + return None; + } + } + + Some(()) +} + +#[derive(Debug)] +pub struct ClientFirstMessage<'a> { + /// `client-first-message-bare`. + pub bare: &'a str, + /// Channel binding mode. + pub cbind_flag: ChannelBinding<&'a str>, + /// (Client username)[]. + pub username: &'a str, + /// Client nonce. + pub nonce: &'a str, +} + +impl<'a> ClientFirstMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + let mut parts = input.split(','); + + let cbind_flag = ChannelBinding::parse(parts.next()?)?; + + // PG doesn't support authorization identity, + // so we don't bother defining GS2 header type + let authzid = parts.next()?; + if !authzid.is_empty() { + return None; + } + + // Unfortunately, `parts.as_str()` is unstable + let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1; + let (_, bare) = input.split_at(pos); + + // In theory, these might be preceded by "reserved-mext" (i.e. "m=") + let username = parts.next()?.strip_prefix("n=")?; + let nonce = parts.next()?.strip_prefix("r=")?; + + // Validate but ignore auth extensions + validate_sasl_extensions(parts)?; + + Some(Self { + bare, + cbind_flag, + username, + nonce, + }) + } + + /// Build a response to [`ClientFirstMessage`]. + pub fn build_server_first_message( + &self, + nonce: &[u8; SCRAM_RAW_NONCE_LEN], + salt_base64: &str, + iterations: u32, + ) -> OwnedServerFirstMessage { + use std::fmt::Write; + + let mut message = String::new(); + write!(&mut message, "r={}", self.nonce).unwrap(); + base64::encode_config_buf(nonce, base64::STANDARD, &mut message); + let combined_nonce = 2..message.len(); + write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap(); + + // This design guarantees that it's impossible to create a + // server-first-message without receiving a client-first-message + OwnedServerFirstMessage { + message, + nonce: combined_nonce, + } + } +} + +#[derive(Debug)] +pub struct ClientFinalMessage<'a> { + /// `client-final-message-without-proof`. + pub without_proof: &'a str, + /// Channel binding data (base64). + pub channel_binding: &'a str, + /// Combined client & server nonce. + pub nonce: &'a str, + /// Client auth proof. + pub proof: [u8; SCRAM_KEY_LEN], +} + +impl<'a> ClientFinalMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + let (without_proof, proof) = input.rsplit_once(',')?; + + let mut parts = without_proof.split(','); + let channel_binding = parts.next()?.strip_prefix("c=")?; + let nonce = parts.next()?.strip_prefix("r=")?; + + // Validate but ignore auth extensions + validate_sasl_extensions(parts)?; + + let proof = base64_decode_array(proof.strip_prefix("p=")?)?; + + Some(Self { + without_proof, + channel_binding, + nonce, + proof, + }) + } + + /// Build a response to [`ClientFinalMessage`]. + pub fn build_server_final_message( + &self, + signature_builder: SignatureBuilder, + server_key: &ScramKey, + ) -> String { + let mut buf = String::from("v="); + base64::encode_config_buf( + signature_builder.build(server_key), + base64::STANDARD, + &mut buf, + ); + + buf + } +} + +/// We need to keep a convenient representation of this +/// message for the next authentication step. +pub struct OwnedServerFirstMessage { + /// Owned `server-first-message`. + message: String, + /// Slice into `message`. + nonce: Range, +} + +impl OwnedServerFirstMessage { + /// Extract combined nonce from the message. + #[inline(always)] + pub fn nonce(&self) -> &str { + &self.message[self.nonce.clone()] + } + + /// Get reference to a text representation of the message. + #[inline(always)] + pub fn as_str(&self) -> &str { + &self.message + } +} + +impl fmt::Debug for OwnedServerFirstMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServerFirstMessage") + .field("message", &self.as_str()) + .field("nonce", &self.nonce()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_client_first_message() { + use ChannelBinding::*; + + // (Almost) real strings captured during debug sessions + let cases = [ + (NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"), + (NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"), + ( + Required("tls-server-end-point"), + "p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju", + ), + ]; + + for (cb, input) in cases { + let msg = ClientFirstMessage::parse(input).unwrap(); + + assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju"); + assert_eq!(msg.username, "pepe"); + assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju"); + assert_eq!(msg.cbind_flag, cb); + } + } + + #[test] + fn parse_client_final_message() { + let input = [ + "c=eSws", + "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU", + "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=", + ] + .join(","); + + let msg = ClientFinalMessage::parse(&input).unwrap(); + assert_eq!( + msg.without_proof, + "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU" + ); + assert_eq!( + msg.nonce, + "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU" + ); + assert_eq!( + base64::encode(msg.proof), + "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=" + ); + } +} diff --git a/proxy/src/scram/password.rs b/proxy/src/scram/password.rs new file mode 100644 index 0000000000..656780d853 --- /dev/null +++ b/proxy/src/scram/password.rs @@ -0,0 +1,48 @@ +//! Password hashing routines. + +use super::key::ScramKey; + +pub const SALTED_PASSWORD_LEN: usize = 32; + +/// Salted hashed password is essential for [key](super::key) derivation. +#[repr(transparent)] +pub struct SaltedPassword { + bytes: [u8; SALTED_PASSWORD_LEN], +} + +impl SaltedPassword { + /// See `scram-common.c : scram_SaltedPassword` for details. + /// Further reading: (see `PBKDF2`). + pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword { + let one = 1_u32.to_be_bytes(); // magic + + let mut current = super::hmac_sha256(password, [salt, &one]); + let mut result = current; + for _ in 1..iterations { + current = super::hmac_sha256(password, [current.as_ref()]); + // TODO: result = current.zip(result).map(|(x, y)| x ^ y), issue #80094 + for (i, x) in current.iter().enumerate() { + result[i] ^= x; + } + } + + result.into() + } + + /// Derive `ClientKey` from a salted hashed password. + pub fn client_key(&self) -> ScramKey { + super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into() + } + + /// Derive `ServerKey` from a salted hashed password. + pub fn server_key(&self) -> ScramKey { + super::hmac_sha256(&self.bytes, [b"Server Key".as_ref()]).into() + } +} + +impl From<[u8; SALTED_PASSWORD_LEN]> for SaltedPassword { + #[inline(always)] + fn from(bytes: [u8; SALTED_PASSWORD_LEN]) -> Self { + Self { bytes } + } +} diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs new file mode 100644 index 0000000000..e8d180bcdd --- /dev/null +++ b/proxy/src/scram/secret.rs @@ -0,0 +1,116 @@ +//! Tools for SCRAM server secret management. + +use super::base64_decode_array; +use super::key::ScramKey; + +/// Server secret is produced from [password](super::password::SaltedPassword) +/// and is used throughout the authentication process. +#[derive(Debug)] +pub struct ServerSecret { + /// Number of iterations for `PBKDF2` function. + pub iterations: u32, + /// Salt used to hash user's password. + pub salt_base64: String, + /// Hashed `ClientKey`. + pub stored_key: ScramKey, + /// Used by client to verify server's signature. + pub server_key: ScramKey, +} + +impl ServerSecret { + pub fn parse(input: &str) -> Option { + // SCRAM-SHA-256$:$: + let s = input.strip_prefix("SCRAM-SHA-256$")?; + let (params, keys) = s.split_once('$')?; + + let ((iterations, salt), (stored_key, server_key)) = + params.split_once(':').zip(keys.split_once(':'))?; + + let secret = ServerSecret { + iterations: iterations.parse().ok()?, + salt_base64: salt.to_owned(), + stored_key: base64_decode_array(stored_key)?.into(), + server_key: base64_decode_array(server_key)?.into(), + }; + + Some(secret) + } + + /// To avoid revealing information to an attacker, we use a + /// mocked server secret even if the user doesn't exist. + /// See `auth-scram.c : mock_scram_secret` for details. + pub fn mock(user: &str, nonce: &[u8; 32]) -> Self { + // Refer to `auth-scram.c : scram_mock_salt`. + let mocked_salt = super::sha256([user.as_bytes(), nonce]); + + Self { + iterations: 4096, + salt_base64: base64::encode(&mocked_salt), + stored_key: ScramKey::default(), + server_key: ScramKey::default(), + } + } + + /// Build a new server secret from the prerequisites. + /// XXX: We only use this function in tests. + #[cfg(test)] + pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option { + // TODO: implement proper password normalization required by the RFC + if !password.is_ascii() { + return None; + } + + let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations); + + Some(Self { + iterations, + salt_base64: base64::encode(&salt), + stored_key: password.client_key().sha256(), + server_key: password.server_key(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_scram_secret() { + let iterations = 4096; + let salt = "+/tQQax7twvwTj64mjBsxQ=="; + let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns="; + let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI="; + + let secret = format!( + "SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}", + iterations = iterations, + salt = salt, + stored_key = stored_key, + server_key = server_key, + ); + + let parsed = ServerSecret::parse(&secret).unwrap(); + assert_eq!(parsed.iterations, iterations); + assert_eq!(parsed.salt_base64, salt); + + assert_eq!(base64::encode(parsed.stored_key), stored_key); + assert_eq!(base64::encode(parsed.server_key), server_key); + } + + #[test] + fn build_scram_secret() { + let salt = b"salt"; + let secret = ServerSecret::build("password", salt, 4096).unwrap(); + assert_eq!(secret.iterations, 4096); + assert_eq!(secret.salt_base64, base64::encode(salt)); + assert_eq!( + base64::encode(secret.stored_key.as_ref()), + "lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ=" + ); + assert_eq!( + base64::encode(secret.server_key.as_ref()), + "ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw=" + ); + } +} diff --git a/proxy/src/scram/signature.rs b/proxy/src/scram/signature.rs new file mode 100644 index 0000000000..1c2811d757 --- /dev/null +++ b/proxy/src/scram/signature.rs @@ -0,0 +1,66 @@ +//! Tools for client/server signature management. + +use super::key::{ScramKey, SCRAM_KEY_LEN}; + +/// A collection of message parts needed to derive the client's signature. +#[derive(Debug)] +pub struct SignatureBuilder<'a> { + pub client_first_message_bare: &'a str, + pub server_first_message: &'a str, + pub client_final_message_without_proof: &'a str, +} + +impl SignatureBuilder<'_> { + pub fn build(&self, key: &ScramKey) -> Signature { + let parts = [ + self.client_first_message_bare.as_bytes(), + b",", + self.server_first_message.as_bytes(), + b",", + self.client_final_message_without_proof.as_bytes(), + ]; + + super::hmac_sha256(key.as_ref(), parts).into() + } +} + +/// A computed value which, when xored with `ClientProof`, +/// produces `ClientKey` that we need for authentication. +#[derive(Debug)] +#[repr(transparent)] +pub struct Signature { + bytes: [u8; SCRAM_KEY_LEN], +} + +impl Signature { + /// Derive `ClientKey` from client's signature and proof. + pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey { + // This is how the proof is calculated: + // + // 1. sha256(ClientKey) -> StoredKey + // 2. hmac_sha256(StoredKey, [messages...]) -> ClientSignature + // 3. ClientKey ^ ClientSignature -> ClientProof + // + // Step 3 implies that we can restore ClientKey from the proof + // by xoring the latter with the ClientSignature. Afterwards we + // can check that the presumed ClientKey meets our expectations. + let mut signature = self.bytes; + for (i, x) in proof.iter().enumerate() { + signature[i] ^= x; + } + + signature.into() + } +} + +impl From<[u8; SCRAM_KEY_LEN]> for Signature { + fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self { + Self { bytes } + } +} + +impl AsRef<[u8]> for Signature { + fn as_ref(&self) -> &[u8] { + &self.bytes + } +} diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index 83792f2aca..f984fb4417 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -375,9 +375,8 @@ impl PostgresBackend { } AuthType::MD5 => { rand::thread_rng().fill(&mut self.md5_salt); - let md5_salt = self.md5_salt; self.write_message(&BeMessage::AuthenticationMD5Password( - &md5_salt, + self.md5_salt, ))?; self.state = ProtoState::Authentication; } diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index cb69418c07..403e176b14 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -401,7 +401,8 @@ fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result { #[derive(Debug)] pub enum BeMessage<'a> { AuthenticationOk, - AuthenticationMD5Password(&'a [u8; 4]), + AuthenticationMD5Password([u8; 4]), + AuthenticationSasl(BeAuthenticationSaslMessage<'a>), AuthenticationCleartextPassword, BackendKeyData(CancelKeyData), BindComplete, @@ -429,6 +430,13 @@ pub enum BeMessage<'a> { KeepAlive(WalSndKeepAlive), } +#[derive(Debug)] +pub enum BeAuthenticationSaslMessage<'a> { + Methods(&'a [&'a str]), + Continue(&'a [u8]), + Final(&'a [u8]), +} + #[derive(Debug)] pub enum BeParameterStatusMessage<'a> { Encoding(&'a str), @@ -611,6 +619,32 @@ impl<'a> BeMessage<'a> { .unwrap(); // write into BytesMut can't fail } + BeMessage::AuthenticationSasl(msg) => { + buf.put_u8(b'R'); + write_body(buf, |buf| { + use BeAuthenticationSaslMessage::*; + match msg { + Methods(methods) => { + buf.put_i32(10); // Specifies that SASL auth method is used. + for method in methods.iter() { + write_cstr(method.as_bytes(), buf)?; + } + buf.put_u8(0); // zero terminator for the list + } + Continue(extra) => { + buf.put_i32(11); // Continue SASL auth. + buf.put_slice(extra); + } + Final(extra) => { + buf.put_i32(12); // Send final SASL message. + buf.put_slice(extra); + } + } + Ok::<_, io::Error>(()) + }) + .unwrap() + } + BeMessage::BackendKeyData(key_data) => { buf.put_u8(b'K'); write_body(buf, |buf| {