From 607c0facfc26734e6a13def384c8a1167ea378e0 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Thu, 3 Nov 2022 18:07:16 +0300 Subject: [PATCH] [proxy] Propagate more console API errors to the user This patch aims to fix some of the inconsistencies in error reporting, for example "Internal error" or "Console request failed" instead of "password authentication failed for user ''". --- proxy/src/auth.rs | 10 +- proxy/src/auth/backend/console.rs | 308 ++++++++++++++++++-------- proxy/src/auth/backend/postgres.rs | 108 +++++---- proxy/src/auth/flow.rs | 6 +- proxy/src/error.rs | 20 +- proxy/src/http.rs | 10 - proxy/src/proxy.rs | 17 +- proxy/src/proxy/tests.rs | 13 +- proxy/src/sasl.rs | 20 +- proxy/src/sasl/stream.rs | 31 ++- proxy/src/scram/exchange.rs | 23 +- proxy/src/scram/secret.rs | 11 +- proxy/src/stream.rs | 4 +- test_runner/fixtures/neon_fixtures.py | 99 +++++---- test_runner/regress/test_proxy.py | 125 ++++++----- 15 files changed, 504 insertions(+), 301 deletions(-) diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index f272f9adc1..5355946beb 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -49,6 +49,9 @@ pub enum AuthErrorImpl { )] MissingProjectName, + #[error("password authentication failed for user '{0}'")] + AuthFailed(Box), + /// Errors produced by e.g. [`crate::stream::PqStream`]. #[error(transparent)] Io(#[from] io::Error), @@ -62,6 +65,10 @@ impl AuthError { pub fn bad_auth_method(name: impl Into>) -> Self { AuthErrorImpl::BadAuthMethod(name.into()).into() } + + pub fn auth_failed(user: impl Into>) -> Self { + AuthErrorImpl::AuthFailed(user.into()).into() + } } impl> From for AuthError { @@ -78,10 +85,11 @@ impl UserFacingError for AuthError { GetAuthInfo(e) => e.to_string_client(), WakeCompute(e) => e.to_string_client(), Sasl(e) => e.to_string_client(), + AuthFailed(_) => self.to_string(), BadAuthMethod(_) => self.to_string(), MalformedPassword(_) => self.to_string(), MissingProjectName => self.to_string(), - _ => "Internal error".to_string(), + Io(_) => "Internal error".to_string(), } } } diff --git a/proxy/src/auth/backend/console.rs b/proxy/src/auth/backend/console.rs index 929dfb33f7..040870fc8e 100644 --- a/proxy/src/auth/backend/console.rs +++ b/proxy/src/auth/backend/console.rs @@ -5,26 +5,74 @@ use crate::{ auth::{self, AuthFlow, ClientCredentials}, compute, error::{io_error, UserFacingError}, - http, scram, + http, sasl, scram, stream::PqStream, }; use futures::TryFutureExt; -use serde::{Deserialize, Serialize}; +use reqwest::StatusCode as HttpStatusCode; +use serde::Deserialize; use std::future::Future; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{error, info, info_span}; +use tracing::{error, info, info_span, warn, Instrument}; +/// A go-to error message which doesn't leak any detail. const REQUEST_FAILED: &str = "Console request failed"; +/// Common console API error. #[derive(Debug, Error)] -#[error("{}", REQUEST_FAILED)] -pub struct TransportError(#[from] std::io::Error); +pub enum ApiError { + /// Error returned by the console itself. + #[error("{REQUEST_FAILED} with {}: {}", .status, .text)] + Console { + status: HttpStatusCode, + text: Box, + }, -impl UserFacingError for TransportError {} + /// Various IO errors like broken pipe or malformed payload. + #[error("{REQUEST_FAILED}: {0}")] + Transport(#[from] std::io::Error), +} + +impl ApiError { + /// Returns HTTP status code if it's the reason for failure. + fn http_status_code(&self) -> Option { + use ApiError::*; + match self { + Console { status, .. } => Some(*status), + _ => None, + } + } +} + +impl UserFacingError for ApiError { + fn to_string_client(&self) -> String { + use ApiError::*; + match self { + // To minimize risks, only select errors are forwarded to users. + // Ask @neondatabase/control-plane for review before adding more. + Console { status, .. } => match *status { + HttpStatusCode::NOT_FOUND => { + // Status 404: failed to get a project-related resource. + format!("{REQUEST_FAILED}: endpoint cannot be found") + } + HttpStatusCode::NOT_ACCEPTABLE => { + // Status 406: endpoint is disabled (we don't allow connections). + format!("{REQUEST_FAILED}: endpoint is disabled") + } + HttpStatusCode::LOCKED => { + // Status 423: project might be in maintenance mode (or bad state). + format!("{REQUEST_FAILED}: endpoint is temporary unavailable") + } + _ => REQUEST_FAILED.to_owned(), + }, + _ => REQUEST_FAILED.to_owned(), + } + } +} // Helps eliminate graceless `.map_err` calls without introducing another ctor. -impl From for TransportError { +impl From for ApiError { fn from(e: reqwest::Error) -> Self { io_error(e).into() } @@ -37,61 +85,73 @@ pub enum GetAuthInfoError { BadSecret, #[error(transparent)] - Transport(TransportError), + ApiError(ApiError), +} + +// This allows more useful interactions than `#[from]`. +impl> From for GetAuthInfoError { + fn from(e: E) -> Self { + Self::ApiError(e.into()) + } } impl UserFacingError for GetAuthInfoError { fn to_string_client(&self) -> String { use GetAuthInfoError::*; match self { + // We absolutely should not leak any secrets! BadSecret => REQUEST_FAILED.to_owned(), - Transport(e) => e.to_string_client(), + // However, API might return a meaningful error. + ApiError(e) => e.to_string_client(), } } } -impl> From for GetAuthInfoError { - fn from(e: E) -> Self { - Self::Transport(e.into()) - } -} - #[derive(Debug, Error)] pub enum WakeComputeError { - // We shouldn't show users the address even if it's broken. #[error("Console responded with a malformed compute address: {0}")] - BadComputeAddress(String), + BadComputeAddress(Box), #[error(transparent)] - Transport(TransportError), + ApiError(ApiError), +} + +// This allows more useful interactions than `#[from]`. +impl> From for WakeComputeError { + fn from(e: E) -> Self { + Self::ApiError(e.into()) + } } impl UserFacingError for WakeComputeError { fn to_string_client(&self) -> String { use WakeComputeError::*; match self { + // We shouldn't show user the address even if it's broken. + // Besides, user is unlikely to care about this detail. BadComputeAddress(_) => REQUEST_FAILED.to_owned(), - Transport(e) => e.to_string_client(), + // However, API might return a meaningful error. + ApiError(e) => e.to_string_client(), } } } -impl> From for WakeComputeError { - fn from(e: E) -> Self { - Self::Transport(e.into()) - } +/// Console's response which holds client's auth secret. +#[derive(Deserialize, Debug)] +struct GetRoleSecret { + role_secret: Box, } -// TODO: convert into an enum with "error" -#[derive(Serialize, Deserialize, Debug)] -struct GetRoleSecretResponse { - role_secret: String, +/// Console's response which holds compute node's `host:port` pair. +#[derive(Deserialize, Debug)] +struct WakeCompute { + address: Box, } -// TODO: convert into an enum with "error" -#[derive(Serialize, Deserialize, Debug)] -struct GetWakeComputeResponse { - address: String, +/// Console's error response with human-readable description. +#[derive(Deserialize, Debug)] +struct ConsoleError { + error: Box, } /// Auth secret which is managed by the cloud. @@ -110,6 +170,12 @@ pub(super) struct Api<'a> { creds: &'a ClientCredentials<'a>, } +impl<'a> AsRef> for Api<'a> { + fn as_ref(&self) -> &ClientCredentials<'a> { + self.creds + } +} + impl<'a> Api<'a> { /// Construct an API object containing the auth parameters. pub(super) fn new( @@ -126,83 +192,88 @@ impl<'a> Api<'a> { /// Authenticate the existing user or throw an error. pub(super) async fn handle_user( - self, + &'a self, client: &mut PqStream, ) -> auth::Result> { - handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await + handle_user(client, self, Self::get_auth_info, Self::wake_compute).await } +} - async fn get_auth_info(&self) -> Result { +impl Api<'_> { + async fn get_auth_info(&self) -> Result, GetAuthInfoError> { let request_id = uuid::Uuid::new_v4().to_string(); - let req = self - .endpoint - .get("proxy_get_role_secret") - .header("X-Request-ID", &request_id) - .query(&[("session_id", self.extra.session_id)]) - .query(&[ - ("application_name", self.extra.application_name), - ("project", Some(self.creds.project().expect("impossible"))), - ("role", Some(self.creds.user)), - ]) - .build()?; + async { + let request = self + .endpoint + .get("proxy_get_role_secret") + .header("X-Request-ID", &request_id) + .query(&[("session_id", self.extra.session_id)]) + .query(&[ + ("application_name", self.extra.application_name), + ("project", Some(self.creds.project().expect("impossible"))), + ("role", Some(self.creds.user)), + ]) + .build()?; - let span = info_span!("http", id = request_id, url = req.url().as_str()); - info!(parent: &span, "request auth info"); - let msg = self - .endpoint - .checked_execute(req) - .and_then(|r| r.json::()) - .await - .map_err(|e| { - error!(parent: &span, "{e}"); - e - })?; + info!(url = request.url().as_str(), "sending http request"); + let response = self.endpoint.execute(request).await?; + let body = match parse_body::(response).await { + Ok(body) => body, + // Error 404 is special: it's ok not to have a secret. + Err(e) => match e.http_status_code() { + Some(HttpStatusCode::NOT_FOUND) => return Ok(None), + _otherwise => return Err(e.into()), + }, + }; - scram::ServerSecret::parse(&msg.role_secret) - .map(AuthInfo::Scram) - .ok_or(GetAuthInfoError::BadSecret) + let secret = scram::ServerSecret::parse(&body.role_secret) + .map(AuthInfo::Scram) + .ok_or(GetAuthInfoError::BadSecret)?; + + Ok(Some(secret)) + } + .map_err(crate::error::log_error) + .instrument(info_span!("get_auth_info", id = request_id)) + .await } /// Wake up the compute node and return the corresponding connection info. - pub(super) async fn wake_compute(&self) -> Result { + pub async fn wake_compute(&self) -> Result { let request_id = uuid::Uuid::new_v4().to_string(); - let req = self - .endpoint - .get("proxy_wake_compute") - .header("X-Request-ID", &request_id) - .query(&[("session_id", self.extra.session_id)]) - .query(&[ - ("application_name", self.extra.application_name), - ("project", Some(self.creds.project().expect("impossible"))), - ]) - .build()?; + async { + let request = self + .endpoint + .get("proxy_wake_compute") + .header("X-Request-ID", &request_id) + .query(&[("session_id", self.extra.session_id)]) + .query(&[ + ("application_name", self.extra.application_name), + ("project", Some(self.creds.project().expect("impossible"))), + ]) + .build()?; - let span = info_span!("http", id = request_id, url = req.url().as_str()); - info!(parent: &span, "request wake-up"); - let msg = self - .endpoint - .checked_execute(req) - .and_then(|r| r.json::()) - .await - .map_err(|e| { - error!(parent: &span, "{e}"); - e - })?; + info!(url = request.url().as_str(), "sending http request"); + let response = self.endpoint.execute(request).await?; + let body = parse_body::(response).await?; - // Unfortunately, ownership won't let us use `Option::ok_or` here. - let (host, port) = match parse_host_port(&msg.address) { - None => return Err(WakeComputeError::BadComputeAddress(msg.address)), - Some(x) => x, - }; + // Unfortunately, ownership won't let us use `Option::ok_or` here. + let (host, port) = match parse_host_port(&body.address) { + None => return Err(WakeComputeError::BadComputeAddress(body.address)), + Some(x) => x, + }; - let mut config = compute::ConnCfg::new(); - config - .host(host) - .port(port) - .dbname(self.creds.dbname) - .user(self.creds.user); + let mut config = compute::ConnCfg::new(); + config + .host(host) + .port(port) + .dbname(self.creds.dbname) + .user(self.creds.user); - Ok(config) + Ok(config) + } + .map_err(crate::error::log_error) + .instrument(info_span!("wake_compute", id = request_id)) + .await } } @@ -215,24 +286,40 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>( wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute, ) -> auth::Result> where - GetAuthInfo: Future>, + Endpoint: AsRef>, + GetAuthInfo: Future, GetAuthInfoError>>, WakeCompute: Future>, { + let creds = endpoint.as_ref(); + info!("fetching user's authentication info"); - let auth_info = get_auth_info(endpoint).await?; + let info = get_auth_info(endpoint).await?.unwrap_or_else(|| { + // If we don't have an authentication secret, we mock one to + // prevent malicious probing (possible due to missing protocol steps). + // This mocked secret will never lead to successful authentication. + info!("authentication info not found, mocking it"); + AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random())) + }); let flow = AuthFlow::new(client); - let scram_keys = match auth_info { + let scram_keys = match info { AuthInfo::Md5(_) => { - // TODO: decide if we should support MD5 in api v2 info!("auth endpoint chooses MD5"); return Err(auth::AuthError::bad_auth_method("MD5")); } AuthInfo::Scram(secret) => { info!("auth endpoint chooses SCRAM"); let scram = auth::Scram(&secret); + let client_key = match flow.begin(scram).await?.authenticate().await? { + sasl::Outcome::Success(key) => key, + sasl::Outcome::Failure(reason) => { + info!("auth backend failed with an error: {reason}"); + return Err(auth::AuthError::auth_failed(creds.user)); + } + }; + Some(compute::ScramKeys { - client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(), + client_key: client_key.as_bytes(), server_key: secret.server_key.as_bytes(), }) } @@ -249,6 +336,31 @@ where }) } +/// Parse http response body, taking status code into account. +async fn parse_body Deserialize<'a>>( + response: reqwest::Response, +) -> Result { + let status = response.status(); + if status.is_success() { + // We shouldn't log raw body because it may contain secrets. + info!("request succeeded, processing the body"); + return Ok(response.json().await?); + } + + // Don't throw an error here because it's not as important + // as the fact that the request itself has failed. + let body = response.json().await.unwrap_or_else(|e| { + warn!("failed to parse error body: {e}"); + ConsoleError { + error: "reason unclear (malformed error message)".into(), + } + }); + + let text = body.error; + error!("console responded with an error ({status}): {text}"); + Err(ApiError::Console { status, text }) +} + fn parse_host_port(input: &str) -> Option<(&str, u16)> { let (host, port) = input.split_once(':')?; Some((host, port.parse().ok()?)) diff --git a/proxy/src/auth/backend/postgres.rs b/proxy/src/auth/backend/postgres.rs index e56b62622a..8f16dc9fa8 100644 --- a/proxy/src/auth/backend/postgres.rs +++ b/proxy/src/auth/backend/postgres.rs @@ -1,7 +1,7 @@ //! Local mock of Cloud API V2. use super::{ - console::{self, AuthInfo, GetAuthInfoError, TransportError, WakeComputeError}, + console::{self, AuthInfo, GetAuthInfoError, WakeComputeError}, AuthSuccess, }; use crate::{ @@ -12,7 +12,28 @@ use crate::{ stream::PqStream, url::ApiUrl, }; +use futures::TryFutureExt; +use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{info, info_span, warn, Instrument}; + +#[derive(Debug, Error)] +enum MockApiError { + #[error("Failed to read password: {0}")] + PasswordNotSet(tokio_postgres::Error), +} + +impl From for console::ApiError { + fn from(e: MockApiError) -> Self { + io_error(e).into() + } +} + +impl From for console::ApiError { + fn from(e: tokio_postgres::Error) -> Self { + io_error(e).into() + } +} #[must_use] pub(super) struct Api<'a> { @@ -20,10 +41,9 @@ pub(super) struct Api<'a> { creds: &'a ClientCredentials<'a>, } -// Helps eliminate graceless `.map_err` calls without introducing another ctor. -impl From for TransportError { - fn from(e: tokio_postgres::Error) -> Self { - io_error(e).into() +impl<'a> AsRef> for Api<'a> { + fn as_ref(&self) -> &ClientCredentials<'a> { + self.creds } } @@ -35,54 +55,55 @@ impl<'a> Api<'a> { /// Authenticate the existing user or throw an error. pub(super) async fn handle_user( - self, + &'a self, client: &mut PqStream, ) -> auth::Result> { // We reuse user handling logic from a production module. - console::handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await + console::handle_user(client, self, Self::get_auth_info, Self::wake_compute).await } +} +impl Api<'_> { /// This implementation fetches the auth info from a local postgres instance. - async fn get_auth_info(&self) -> Result { - // Perhaps we could persist this connection, but then we'd have to - // write more code for reopening it if it got closed, which doesn't - // seem worth it. - let (client, connection) = - tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; + async fn get_auth_info(&self) -> Result, GetAuthInfoError> { + async { + // Perhaps we could persist this connection, but then we'd have to + // write more code for reopening it if it got closed, which doesn't + // seem worth it. + let (client, connection) = + tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; - tokio::spawn(connection); - let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1"; - let rows = client.query(query, &[&self.creds.user]).await?; + tokio::spawn(connection); + let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1"; + let rows = client.query(query, &[&self.creds.user]).await?; - match &rows[..] { - // We can't get a secret if there's no such user. - [] => Err(io_error(format!("unknown user '{}'", self.creds.user)).into()), + // We can get at most one row, because `rolname` is unique. + let row = match rows.get(0) { + Some(row) => row, + // This means that the user doesn't exist, so there can be no secret. + // However, this is still a *valid* outcome which is very similar + // to getting `404 Not found` from the Neon console. + None => { + warn!("user '{}' does not exist", self.creds.user); + return Ok(None); + } + }; - // We shouldn't get more than one row anyway. - [row, ..] => { - let entry = row - .try_get("rolpassword") - .map_err(|e| io_error(format!("failed to read user's password: {e}")))?; + let entry = row + .try_get("rolpassword") + .map_err(MockApiError::PasswordNotSet)?; - scram::ServerSecret::parse(entry) - .map(AuthInfo::Scram) - .or_else(|| { - // It could be an md5 hash if it's not a SCRAM secret. - let text = entry.strip_prefix("md5")?; - Some(AuthInfo::Md5({ - let mut bytes = [0u8; 16]; - hex::decode_to_slice(text, &mut bytes).ok()?; - bytes - })) - }) - // Putting the secret into this message is a security hazard! - .ok_or(GetAuthInfoError::BadSecret) - } + info!("got a secret: {entry}"); // safe since it's not a prod scenario + let secret = scram::ServerSecret::parse(entry).map(AuthInfo::Scram); + Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5))) } + .map_err(crate::error::log_error) + .instrument(info_span!("get_auth_info", mock = self.endpoint.as_str())) + .await } /// We don't need to wake anything locally, so we just return the connection info. - pub(super) async fn wake_compute(&self) -> Result { + pub async fn wake_compute(&self) -> Result { let mut config = compute::ConnCfg::new(); config .host(self.endpoint.host_str().unwrap_or("localhost")) @@ -93,3 +114,12 @@ impl<'a> Api<'a> { Ok(config) } } + +fn parse_md5(input: &str) -> Option<[u8; 16]> { + let text = input.strip_prefix("md5")?; + + let mut bytes = [0u8; 16]; + hex::decode_to_slice(text, &mut bytes).ok()?; + + Some(bytes) +} diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 865af4d2e5..d9ee50894d 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -89,7 +89,7 @@ impl AuthFlow<'_, S, PasswordHack> { /// Stream wrapper for handling [SCRAM](crate::scram) auth. impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn authenticate(self) -> super::Result { + pub async fn authenticate(self) -> super::Result> { // Initial client message contains the chosen auth method's name. let msg = self.stream.read_password_message().await?; let sasl = sasl::FirstMessage::parse(&msg) @@ -101,10 +101,10 @@ impl AuthFlow<'_, S, Scram<'_>> { } let secret = self.state.0; - let key = sasl::SaslStream::new(self.stream, sasl.message) + let outcome = sasl::SaslStream::new(self.stream, sasl.message) .authenticate(scram::Exchange::new(secret, rand::random, None)) .await?; - Ok(key) + Ok(outcome) } } diff --git a/proxy/src/error.rs b/proxy/src/error.rs index 0e376a37cd..f1cb44b1a8 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -1,4 +1,15 @@ -use std::io; +use std::{error::Error as StdError, fmt, io}; + +/// Upcast (almost) any error into an opaque [`io::Error`]. +pub fn io_error(e: impl Into>) -> io::Error { + io::Error::new(io::ErrorKind::Other, e) +} + +/// A small combinator for pluggable error logging. +pub fn log_error(e: E) -> E { + tracing::error!("{e}"); + e +} /// Marks errors that may be safely shown to a client. /// This trait can be seen as a specialized version of [`ToString`]. @@ -6,7 +17,7 @@ use std::io; /// NOTE: This trait should not be implemented for [`anyhow::Error`], since it /// is way too convenient and tends to proliferate all across the codebase, /// ultimately leading to accidental leaks of sensitive data. -pub trait UserFacingError: ToString { +pub trait UserFacingError: fmt::Display { /// Format the error for client, stripping all sensitive info. /// /// Although this might be a no-op for many types, it's highly @@ -17,8 +28,3 @@ pub trait UserFacingError: ToString { self.to_string() } } - -/// Upcast (almost) any error into an opaque [`io::Error`]. -pub fn io_error(e: impl Into>) -> io::Error { - io::Error::new(io::ErrorKind::Other, e) -} diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 6f9145678b..096a33d73d 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -37,16 +37,6 @@ impl Endpoint { ) -> Result { self.client.execute(request).await } - - /// Execute a [request](reqwest::Request) and raise an error if status != 200. - pub async fn checked_execute( - &self, - request: reqwest::Request, - ) -> Result { - self.execute(request) - .await - .and_then(|r| r.error_for_status()) - } } #[cfg(test)] diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 411893fee5..da3cb144e3 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -49,17 +49,6 @@ static NUM_BYTES_PROXIED_COUNTER: Lazy = Lazy::new(|| { .unwrap() }); -/// A small combinator for pluggable error logging. -async fn log_error(future: F) -> F::Output -where - F: std::future::Future>, -{ - future.await.map_err(|err| { - error!("{err}"); - err - }) -} - pub async fn task_main( config: &'static ProxyConfig, listener: tokio::net::TcpListener, @@ -80,7 +69,7 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancel_map = Arc::clone(&cancel_map); tokio::spawn( - log_error(async move { + async move { info!("spawned a task for {peer_addr}"); socket @@ -88,6 +77,10 @@ pub async fn task_main( .context("failed to set socket option")?; handle_client(config, &cancel_map, session_id, socket).await + } + .unwrap_or_else(|e| { + // Acknowledge that the task has finished with an error. + error!("per-client task finished with an error: {e:#}"); }) .instrument(info_span!("client", session = format_args!("{session_id}"))), ); diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 3d74dbae5a..24fbc57b99 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -1,6 +1,6 @@ ///! A group of high-level tests for connection establishing logic and auth. use super::*; -use crate::{auth, scram}; +use crate::{auth, sasl, scram}; use async_trait::async_trait; use rstest::rstest; use tokio_postgres::config::SslMode; @@ -100,8 +100,7 @@ impl Scram { } fn mock(user: &str) -> Self { - let salt = rand::random::<[u8; 32]>(); - Scram(scram::ServerSecret::mock(user, &salt)) + Scram(scram::ServerSecret::mock(user, rand::random())) } } @@ -111,13 +110,17 @@ impl TestAuth for Scram { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - auth::AuthFlow::new(stream) + let outcome = auth::AuthFlow::new(stream) .begin(auth::Scram(&self.0)) .await? .authenticate() .await?; - Ok(()) + use sasl::Outcome::*; + match outcome { + Success(_) => Ok(()), + Failure(reason) => bail!("autentication failed with an error: {reason}"), + } } } diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs index 689fca6049..6d1dd9fba5 100644 --- a/proxy/src/sasl.rs +++ b/proxy/src/sasl.rs @@ -16,22 +16,19 @@ use thiserror::Error; pub use channel_binding::ChannelBinding; pub use messages::FirstMessage; -pub use stream::SaslStream; +pub use stream::{Outcome, SaslStream}; /// Fine-grained auth errors help in writing tests. #[derive(Error, Debug)] pub enum Error { - #[error("Failed to authenticate client: {0}")] - AuthenticationFailed(&'static str), - #[error("Channel binding failed: {0}")] ChannelBindingFailed(&'static str), #[error("Unsupported channel binding method: {0}")] ChannelBindingBadMethod(Box), - #[error("Bad client message")] - BadClientMessage, + #[error("Bad client message: {0}")] + BadClientMessage(&'static str), #[error(transparent)] Io(#[from] io::Error), @@ -41,8 +38,6 @@ impl UserFacingError for Error { fn to_string_client(&self) -> String { use Error::*; match self { - // This constructor contains the reason why auth has failed. - AuthenticationFailed(s) => s.to_string(), // TODO: add support for channel binding ChannelBindingFailed(_) => "channel binding is not supported yet".to_string(), ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"), @@ -55,11 +50,14 @@ impl UserFacingError for Error { pub type Result = std::result::Result; /// A result of one SASL exchange. +#[must_use] pub enum Step { /// We should continue exchanging messages. - Continue(T), + Continue(T, String), /// The client has been authenticated successfully. - Authenticated(R), + Success(R, String), + /// Authentication failed (reason attached). + Failure(&'static str), } /// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait. @@ -69,5 +67,5 @@ pub trait Mechanism: Sized { /// Produce a server challenge to be sent to the client. /// This is how this method is called in PostgreSQL (`libpq/sasl.h`). - fn exchange(self, input: &str) -> Result<(Step, String)>; + fn exchange(self, input: &str) -> Result>; } diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index 0e782c5f29..b24cc4bf44 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -48,28 +48,41 @@ impl SaslStream<'_, S> { } } +/// SASL authentication outcome. +/// It's much easier to match on those two variants +/// than to peek into a noisy protocol error type. +#[must_use = "caller must explicitly check for success"] +pub enum Outcome { + /// Authentication succeeded and produced some value. + Success(R), + /// Authentication failed (reason attached). + Failure(&'static str), +} + impl SaslStream<'_, S> { /// Perform SASL message exchange according to the underlying algorithm /// until user is either authenticated or denied access. pub async fn authenticate( mut self, mut mechanism: M, - ) -> super::Result { + ) -> super::Result> { loop { let input = self.recv().await?; - let (moved, reply) = mechanism.exchange(input)?; + let step = mechanism.exchange(input)?; - use super::Step::*; - match moved { - Continue(moved) => { + use super::Step; + return Ok(match step { + Step::Continue(moved_mechanism, reply) => { self.send(&ServerMessage::Continue(&reply)).await?; - mechanism = moved; + mechanism = moved_mechanism; + continue; } - Authenticated(result) => { + Step::Success(result, reply) => { self.send(&ServerMessage::Final(&reply)).await?; - return Ok(result); + Outcome::Success(result) } - } + Step::Failure(reason) => Outcome::Failure(reason), + }); } } } diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index fca5585b25..882769a70d 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -64,12 +64,12 @@ impl<'a> Exchange<'a> { impl sasl::Mechanism for Exchange<'_> { type Output = super::ScramKey; - fn exchange(mut self, input: &str) -> sasl::Result<(sasl::Step, String)> { + fn exchange(mut self, input: &str) -> sasl::Result> { use {sasl::Step::*, ExchangeState::*}; match &self.state { Initial => { - let client_first_message = - ClientFirstMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + let client_first_message = ClientFirstMessage::parse(input) + .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?; let server_first_message = client_first_message.build_server_first_message( &(self.nonce)(), @@ -84,15 +84,15 @@ impl sasl::Mechanism for Exchange<'_> { server_first_message, }; - Ok((Continue(self), msg)) + Ok(Continue(self, msg)) } SaltSent { cbind_flag, client_first_message_bare, server_first_message, } => { - let client_final_message = - ClientFinalMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + let client_final_message = ClientFinalMessage::parse(input) + .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?; let channel_binding = cbind_flag.encode(|_| { self.cert_digest @@ -106,9 +106,7 @@ impl sasl::Mechanism for Exchange<'_> { } if client_final_message.nonce != server_first_message.nonce() { - return Err(SaslError::AuthenticationFailed( - "combined nonce doesn't match", - )); + return Err(SaslError::BadClientMessage("combined nonce doesn't match")); } let signature_builder = SignatureBuilder { @@ -121,14 +119,15 @@ impl sasl::Mechanism for Exchange<'_> { .build(&self.secret.stored_key) .derive_client_key(&client_final_message.proof); - if client_key.sha256() != self.secret.stored_key { - return Err(SaslError::AuthenticationFailed("password doesn't match")); + // Auth fails either if keys don't match or it's pre-determined to fail. + if client_key.sha256() != self.secret.stored_key || self.secret.doomed { + return Ok(Failure("password doesn't match")); } let msg = client_final_message .build_server_final_message(signature_builder, &self.secret.server_key); - Ok((Authenticated(client_key), msg)) + Ok(Success(client_key, msg)) } } } diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index 765aef4443..89668465fa 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -14,6 +14,9 @@ pub struct ServerSecret { pub stored_key: ScramKey, /// Used by client to verify server's signature. pub server_key: ScramKey, + /// Should auth fail no matter what? + /// This is exactly the case for mocked secrets. + pub doomed: bool, } impl ServerSecret { @@ -30,6 +33,7 @@ impl ServerSecret { salt_base64: salt.to_owned(), stored_key: base64_decode_array(stored_key)?.into(), server_key: base64_decode_array(server_key)?.into(), + doomed: false, }; Some(secret) @@ -38,16 +42,16 @@ impl ServerSecret { /// To avoid revealing information to an attacker, we use a /// mocked server secret even if the user doesn't exist. /// See `auth-scram.c : mock_scram_secret` for details. - #[allow(dead_code)] - pub fn mock(user: &str, nonce: &[u8; 32]) -> Self { + pub fn mock(user: &str, nonce: [u8; 32]) -> Self { // Refer to `auth-scram.c : scram_mock_salt`. - let mocked_salt = super::sha256([user.as_bytes(), nonce]); + let mocked_salt = super::sha256([user.as_bytes(), &nonce]); Self { iterations: 4096, salt_base64: base64::encode(&mocked_salt), stored_key: ScramKey::default(), server_key: ScramKey::default(), + doomed: true, } } @@ -67,6 +71,7 @@ impl ServerSecret { salt_base64: base64::encode(&salt), stored_key: password.client_key().sha256(), server_key: password.server_key(), + doomed: false, }) } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 8e4084775c..19e1479068 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -109,8 +109,9 @@ impl PqStream { /// Write the error message using [`Self::write_message`], then re-throw it. /// Allowing string literals is safe under the assumption they might not contain any runtime info. + /// This method exists due to `&str` not implementing `Into`. pub async fn throw_error_str(&mut self, error: &'static str) -> anyhow::Result { - // This method exists due to `&str` not implementing `Into` + tracing::info!("forwarding error to user: {error}"); self.write_message(&BeMessage::ErrorResponse(error)).await?; bail!(error) } @@ -122,6 +123,7 @@ impl PqStream { E: UserFacingError + Into, { let msg = error.to_string_client(); + tracing::info!("forwarding error to user: {msg}"); self.write_message(&BeMessage::ErrorResponse(&msg)).await?; bail!(error) } diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 0d64ca6d65..e3f8247274 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2092,62 +2092,73 @@ class PSQL: class NeonProxy(PgProtocol): + link_auth_uri: str = "http://dummy-uri" + + class AuthBackend(abc.ABC): + """All auth backends must inherit from this class""" + + @property + def default_conn_url(self) -> Optional[str]: + return None + + @abc.abstractmethod + def extra_args(self) -> list[str]: + pass + + class Link(AuthBackend): + def extra_args(self) -> list[str]: + return [ + # Link auth backend params + *["--auth-backend", "link"], + *["--uri", NeonProxy.link_auth_uri], + ] + + @dataclass(frozen=True) + class Postgres(AuthBackend): + pg_conn_url: str + + @property + def default_conn_url(self) -> Optional[str]: + return self.pg_conn_url + + def extra_args(self) -> list[str]: + return [ + # Postgres auth backend params + *["--auth-backend", "postgres"], + *["--auth-endpoint", self.pg_conn_url], + ] + def __init__( self, + neon_binpath: Path, proxy_port: int, http_port: int, mgmt_port: int, - neon_binpath: Path, - auth_endpoint=None, + auth_backend: NeonProxy.AuthBackend, ): - super().__init__(dsn=auth_endpoint, port=proxy_port) - self.host = "127.0.0.1" + host = "127.0.0.1" + super().__init__(dsn=auth_backend.default_conn_url, host=host, port=proxy_port) + + self.host = host self.http_port = http_port self.neon_binpath = neon_binpath self.proxy_port = proxy_port self.mgmt_port = mgmt_port - self.auth_endpoint = auth_endpoint + self.auth_backend = auth_backend self._popen: Optional[subprocess.Popen[bytes]] = None - self.link_auth_uri_prefix = "http://dummy-uri" - def start(self): - """ - Starts a proxy with option '--auth-backend postgres' and a postgres instance - already provided though '--auth-endpoint '." - """ + def start(self) -> NeonProxy: assert self._popen is None - assert self.auth_endpoint is not None - - # Start proxy args = [ str(self.neon_binpath / "proxy"), *["--http", f"{self.host}:{self.http_port}"], *["--proxy", f"{self.host}:{self.proxy_port}"], *["--mgmt", f"{self.host}:{self.mgmt_port}"], - *["--auth-backend", "postgres"], - *["--auth-endpoint", self.auth_endpoint], + *self.auth_backend.extra_args(), ] self._popen = subprocess.Popen(args) self._wait_until_ready() - - def start_with_link_auth(self): - """ - Starts a proxy with option '--auth-backend link' and a dummy authentication link '--uri dummy-auth-link'." - """ - assert self._popen is None - - # Start proxy - bin_proxy = str(self.neon_binpath / "proxy") - args = [bin_proxy] - args.extend(["--http", f"{self.host}:{self.http_port}"]) - args.extend(["--proxy", f"{self.host}:{self.proxy_port}"]) - args.extend(["--mgmt", f"{self.host}:{self.mgmt_port}"]) - args.extend(["--auth-backend", "link"]) - args.extend(["--uri", self.link_auth_uri_prefix]) - arg_str = " ".join(args) - log.info(f"starting proxy with command line ::: {arg_str}") - self._popen = subprocess.Popen(args, stdout=subprocess.PIPE) - self._wait_until_ready() + return self @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) def _wait_until_ready(self): @@ -2158,7 +2169,7 @@ class NeonProxy(PgProtocol): request_result.raise_for_status() return request_result.text - def __enter__(self) -> "NeonProxy": + def __enter__(self) -> NeonProxy: return self def __exit__( @@ -2176,11 +2187,19 @@ class NeonProxy(PgProtocol): @pytest.fixture(scope="function") def link_proxy(port_distributor: PortDistributor, neon_binpath: Path) -> Iterator[NeonProxy]: """Neon proxy that routes through link auth.""" + http_port = port_distributor.get_port() proxy_port = port_distributor.get_port() mgmt_port = port_distributor.get_port() - with NeonProxy(proxy_port, http_port, neon_binpath=neon_binpath, mgmt_port=mgmt_port) as proxy: - proxy.start_with_link_auth() + + with NeonProxy( + neon_binpath=neon_binpath, + proxy_port=proxy_port, + http_port=http_port, + mgmt_port=mgmt_port, + auth_backend=NeonProxy.Link(), + ) as proxy: + proxy.start() yield proxy @@ -2204,11 +2223,11 @@ def static_proxy( http_port = port_distributor.get_port() with NeonProxy( + neon_binpath=neon_binpath, proxy_port=proxy_port, http_port=http_port, mgmt_port=mgmt_port, - neon_binpath=neon_binpath, - auth_endpoint=auth_endpoint, + auth_backend=NeonProxy.Postgres(auth_endpoint), ) as proxy: proxy.start() yield proxy diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index e868d6b616..eab9505fbb 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -28,61 +28,58 @@ def test_password_hack(static_proxy: NeonProxy): static_proxy.safe_psql("select 1", sslsni=0, user=user, password=magic) -def get_session_id(uri_prefix, uri_line): - assert uri_prefix in uri_line - - url_parts = urlparse(uri_line) - psql_session_id = url_parts.path[1:] - assert psql_session_id.isalnum(), "session_id should only contain alphanumeric chars" - - return psql_session_id - - -async def find_auth_link(link_auth_uri_prefix, proc): - for _ in range(100): - line = (await proc.stderr.readline()).decode("utf-8").strip() - log.info(f"psql line: {line}") - if link_auth_uri_prefix in line: - log.info(f"SUCCESS, found auth url: {line}") - return line - - -async def activate_link_auth(local_vanilla_pg, link_proxy, psql_session_id): - pg_user = "proxy" - - log.info("creating a new user for link auth test") - local_vanilla_pg.start() - local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser") - - db_info = json.dumps( - { - "session_id": psql_session_id, - "result": { - "Success": { - "host": local_vanilla_pg.default_options["host"], - "port": local_vanilla_pg.default_options["port"], - "dbname": local_vanilla_pg.default_options["dbname"], - "user": pg_user, - "project": "irrelevant", - } - }, - } - ) - - log.info("sending session activation message") - psql = await PSQL(host=link_proxy.host, port=link_proxy.mgmt_port).run(db_info) - out = (await psql.stdout.read()).decode("utf-8").strip() - assert out == "ok" - - @pytest.mark.asyncio async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy): + def get_session_id(uri_prefix, uri_line): + assert uri_prefix in uri_line + + url_parts = urlparse(uri_line) + psql_session_id = url_parts.path[1:] + assert psql_session_id.isalnum(), "session_id should only contain alphanumeric chars" + + return psql_session_id + + async def find_auth_link(link_auth_uri, proc): + for _ in range(100): + line = (await proc.stderr.readline()).decode("utf-8").strip() + log.info(f"psql line: {line}") + if link_auth_uri in line: + log.info(f"SUCCESS, found auth url: {line}") + return line + + async def activate_link_auth(local_vanilla_pg, link_proxy, psql_session_id): + pg_user = "proxy" + + log.info("creating a new user for link auth test") + local_vanilla_pg.start() + local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser") + + db_info = json.dumps( + { + "session_id": psql_session_id, + "result": { + "Success": { + "host": local_vanilla_pg.default_options["host"], + "port": local_vanilla_pg.default_options["port"], + "dbname": local_vanilla_pg.default_options["dbname"], + "user": pg_user, + "project": "irrelevant", + } + }, + } + ) + + log.info("sending session activation message") + psql = await PSQL(host=link_proxy.host, port=link_proxy.mgmt_port).run(db_info) + out = (await psql.stdout.read()).decode("utf-8").strip() + assert out == "ok" + psql = await PSQL(host=link_proxy.host, port=link_proxy.proxy_port).run("select 42") - uri_prefix = link_proxy.link_auth_uri_prefix - link = await find_auth_link(uri_prefix, psql) + base_uri = link_proxy.link_auth_uri + link = await find_auth_link(base_uri, psql) - psql_session_id = get_session_id(uri_prefix, link) + psql_session_id = get_session_id(base_uri, link) await activate_link_auth(vanilla_pg, link_proxy, psql_session_id) assert psql.stdout is not None @@ -97,3 +94,31 @@ def test_proxy_options(static_proxy: NeonProxy): cur.execute("SHOW proxytest.option") value = cur.fetchall()[0][0] assert value == "value" + + +def test_auth_errors(static_proxy: NeonProxy): + # User does not exist + with pytest.raises(psycopg2.Error) as exprinfo: + static_proxy.connect(user="pinocchio", options="project=irrelevant") + text = str(exprinfo.value).strip() + assert text.endswith("password authentication failed for user 'pinocchio'") + + static_proxy.safe_psql( + "create role pinocchio with login password 'magic'", options="project=irrelevant" + ) + + # User exists, but password is missing + with pytest.raises(psycopg2.Error) as exprinfo: + static_proxy.connect(user="pinocchio", password=None, options="project=irrelevant") + text = str(exprinfo.value).strip() + assert text.endswith("password authentication failed for user 'pinocchio'") + + # User exists, but password is wrong + with pytest.raises(psycopg2.Error) as exprinfo: + static_proxy.connect(user="pinocchio", password="bad", options="project=irrelevant") + text = str(exprinfo.value).strip() + assert text.endswith("password authentication failed for user 'pinocchio'") + + # Finally, check that the user can connect + with static_proxy.connect(user="pinocchio", password="magic", options="project=irrelevant"): + pass