diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index f272f9adc1..5355946beb 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -49,6 +49,9 @@ pub enum AuthErrorImpl { )] MissingProjectName, + #[error("password authentication failed for user '{0}'")] + AuthFailed(Box), + /// Errors produced by e.g. [`crate::stream::PqStream`]. #[error(transparent)] Io(#[from] io::Error), @@ -62,6 +65,10 @@ impl AuthError { pub fn bad_auth_method(name: impl Into>) -> Self { AuthErrorImpl::BadAuthMethod(name.into()).into() } + + pub fn auth_failed(user: impl Into>) -> Self { + AuthErrorImpl::AuthFailed(user.into()).into() + } } impl> From for AuthError { @@ -78,10 +85,11 @@ impl UserFacingError for AuthError { GetAuthInfo(e) => e.to_string_client(), WakeCompute(e) => e.to_string_client(), Sasl(e) => e.to_string_client(), + AuthFailed(_) => self.to_string(), BadAuthMethod(_) => self.to_string(), MalformedPassword(_) => self.to_string(), MissingProjectName => self.to_string(), - _ => "Internal error".to_string(), + Io(_) => "Internal error".to_string(), } } } diff --git a/proxy/src/auth/backend/console.rs b/proxy/src/auth/backend/console.rs index 929dfb33f7..040870fc8e 100644 --- a/proxy/src/auth/backend/console.rs +++ b/proxy/src/auth/backend/console.rs @@ -5,26 +5,74 @@ use crate::{ auth::{self, AuthFlow, ClientCredentials}, compute, error::{io_error, UserFacingError}, - http, scram, + http, sasl, scram, stream::PqStream, }; use futures::TryFutureExt; -use serde::{Deserialize, Serialize}; +use reqwest::StatusCode as HttpStatusCode; +use serde::Deserialize; use std::future::Future; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{error, info, info_span}; +use tracing::{error, info, info_span, warn, Instrument}; +/// A go-to error message which doesn't leak any detail. const REQUEST_FAILED: &str = "Console request failed"; +/// Common console API error. #[derive(Debug, Error)] -#[error("{}", REQUEST_FAILED)] -pub struct TransportError(#[from] std::io::Error); +pub enum ApiError { + /// Error returned by the console itself. + #[error("{REQUEST_FAILED} with {}: {}", .status, .text)] + Console { + status: HttpStatusCode, + text: Box, + }, -impl UserFacingError for TransportError {} + /// Various IO errors like broken pipe or malformed payload. + #[error("{REQUEST_FAILED}: {0}")] + Transport(#[from] std::io::Error), +} + +impl ApiError { + /// Returns HTTP status code if it's the reason for failure. + fn http_status_code(&self) -> Option { + use ApiError::*; + match self { + Console { status, .. } => Some(*status), + _ => None, + } + } +} + +impl UserFacingError for ApiError { + fn to_string_client(&self) -> String { + use ApiError::*; + match self { + // To minimize risks, only select errors are forwarded to users. + // Ask @neondatabase/control-plane for review before adding more. + Console { status, .. } => match *status { + HttpStatusCode::NOT_FOUND => { + // Status 404: failed to get a project-related resource. + format!("{REQUEST_FAILED}: endpoint cannot be found") + } + HttpStatusCode::NOT_ACCEPTABLE => { + // Status 406: endpoint is disabled (we don't allow connections). + format!("{REQUEST_FAILED}: endpoint is disabled") + } + HttpStatusCode::LOCKED => { + // Status 423: project might be in maintenance mode (or bad state). + format!("{REQUEST_FAILED}: endpoint is temporary unavailable") + } + _ => REQUEST_FAILED.to_owned(), + }, + _ => REQUEST_FAILED.to_owned(), + } + } +} // Helps eliminate graceless `.map_err` calls without introducing another ctor. -impl From for TransportError { +impl From for ApiError { fn from(e: reqwest::Error) -> Self { io_error(e).into() } @@ -37,61 +85,73 @@ pub enum GetAuthInfoError { BadSecret, #[error(transparent)] - Transport(TransportError), + ApiError(ApiError), +} + +// This allows more useful interactions than `#[from]`. +impl> From for GetAuthInfoError { + fn from(e: E) -> Self { + Self::ApiError(e.into()) + } } impl UserFacingError for GetAuthInfoError { fn to_string_client(&self) -> String { use GetAuthInfoError::*; match self { + // We absolutely should not leak any secrets! BadSecret => REQUEST_FAILED.to_owned(), - Transport(e) => e.to_string_client(), + // However, API might return a meaningful error. + ApiError(e) => e.to_string_client(), } } } -impl> From for GetAuthInfoError { - fn from(e: E) -> Self { - Self::Transport(e.into()) - } -} - #[derive(Debug, Error)] pub enum WakeComputeError { - // We shouldn't show users the address even if it's broken. #[error("Console responded with a malformed compute address: {0}")] - BadComputeAddress(String), + BadComputeAddress(Box), #[error(transparent)] - Transport(TransportError), + ApiError(ApiError), +} + +// This allows more useful interactions than `#[from]`. +impl> From for WakeComputeError { + fn from(e: E) -> Self { + Self::ApiError(e.into()) + } } impl UserFacingError for WakeComputeError { fn to_string_client(&self) -> String { use WakeComputeError::*; match self { + // We shouldn't show user the address even if it's broken. + // Besides, user is unlikely to care about this detail. BadComputeAddress(_) => REQUEST_FAILED.to_owned(), - Transport(e) => e.to_string_client(), + // However, API might return a meaningful error. + ApiError(e) => e.to_string_client(), } } } -impl> From for WakeComputeError { - fn from(e: E) -> Self { - Self::Transport(e.into()) - } +/// Console's response which holds client's auth secret. +#[derive(Deserialize, Debug)] +struct GetRoleSecret { + role_secret: Box, } -// TODO: convert into an enum with "error" -#[derive(Serialize, Deserialize, Debug)] -struct GetRoleSecretResponse { - role_secret: String, +/// Console's response which holds compute node's `host:port` pair. +#[derive(Deserialize, Debug)] +struct WakeCompute { + address: Box, } -// TODO: convert into an enum with "error" -#[derive(Serialize, Deserialize, Debug)] -struct GetWakeComputeResponse { - address: String, +/// Console's error response with human-readable description. +#[derive(Deserialize, Debug)] +struct ConsoleError { + error: Box, } /// Auth secret which is managed by the cloud. @@ -110,6 +170,12 @@ pub(super) struct Api<'a> { creds: &'a ClientCredentials<'a>, } +impl<'a> AsRef> for Api<'a> { + fn as_ref(&self) -> &ClientCredentials<'a> { + self.creds + } +} + impl<'a> Api<'a> { /// Construct an API object containing the auth parameters. pub(super) fn new( @@ -126,83 +192,88 @@ impl<'a> Api<'a> { /// Authenticate the existing user or throw an error. pub(super) async fn handle_user( - self, + &'a self, client: &mut PqStream, ) -> auth::Result> { - handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await + handle_user(client, self, Self::get_auth_info, Self::wake_compute).await } +} - async fn get_auth_info(&self) -> Result { +impl Api<'_> { + async fn get_auth_info(&self) -> Result, GetAuthInfoError> { let request_id = uuid::Uuid::new_v4().to_string(); - let req = self - .endpoint - .get("proxy_get_role_secret") - .header("X-Request-ID", &request_id) - .query(&[("session_id", self.extra.session_id)]) - .query(&[ - ("application_name", self.extra.application_name), - ("project", Some(self.creds.project().expect("impossible"))), - ("role", Some(self.creds.user)), - ]) - .build()?; + async { + let request = self + .endpoint + .get("proxy_get_role_secret") + .header("X-Request-ID", &request_id) + .query(&[("session_id", self.extra.session_id)]) + .query(&[ + ("application_name", self.extra.application_name), + ("project", Some(self.creds.project().expect("impossible"))), + ("role", Some(self.creds.user)), + ]) + .build()?; - let span = info_span!("http", id = request_id, url = req.url().as_str()); - info!(parent: &span, "request auth info"); - let msg = self - .endpoint - .checked_execute(req) - .and_then(|r| r.json::()) - .await - .map_err(|e| { - error!(parent: &span, "{e}"); - e - })?; + info!(url = request.url().as_str(), "sending http request"); + let response = self.endpoint.execute(request).await?; + let body = match parse_body::(response).await { + Ok(body) => body, + // Error 404 is special: it's ok not to have a secret. + Err(e) => match e.http_status_code() { + Some(HttpStatusCode::NOT_FOUND) => return Ok(None), + _otherwise => return Err(e.into()), + }, + }; - scram::ServerSecret::parse(&msg.role_secret) - .map(AuthInfo::Scram) - .ok_or(GetAuthInfoError::BadSecret) + let secret = scram::ServerSecret::parse(&body.role_secret) + .map(AuthInfo::Scram) + .ok_or(GetAuthInfoError::BadSecret)?; + + Ok(Some(secret)) + } + .map_err(crate::error::log_error) + .instrument(info_span!("get_auth_info", id = request_id)) + .await } /// Wake up the compute node and return the corresponding connection info. - pub(super) async fn wake_compute(&self) -> Result { + pub async fn wake_compute(&self) -> Result { let request_id = uuid::Uuid::new_v4().to_string(); - let req = self - .endpoint - .get("proxy_wake_compute") - .header("X-Request-ID", &request_id) - .query(&[("session_id", self.extra.session_id)]) - .query(&[ - ("application_name", self.extra.application_name), - ("project", Some(self.creds.project().expect("impossible"))), - ]) - .build()?; + async { + let request = self + .endpoint + .get("proxy_wake_compute") + .header("X-Request-ID", &request_id) + .query(&[("session_id", self.extra.session_id)]) + .query(&[ + ("application_name", self.extra.application_name), + ("project", Some(self.creds.project().expect("impossible"))), + ]) + .build()?; - let span = info_span!("http", id = request_id, url = req.url().as_str()); - info!(parent: &span, "request wake-up"); - let msg = self - .endpoint - .checked_execute(req) - .and_then(|r| r.json::()) - .await - .map_err(|e| { - error!(parent: &span, "{e}"); - e - })?; + info!(url = request.url().as_str(), "sending http request"); + let response = self.endpoint.execute(request).await?; + let body = parse_body::(response).await?; - // Unfortunately, ownership won't let us use `Option::ok_or` here. - let (host, port) = match parse_host_port(&msg.address) { - None => return Err(WakeComputeError::BadComputeAddress(msg.address)), - Some(x) => x, - }; + // Unfortunately, ownership won't let us use `Option::ok_or` here. + let (host, port) = match parse_host_port(&body.address) { + None => return Err(WakeComputeError::BadComputeAddress(body.address)), + Some(x) => x, + }; - let mut config = compute::ConnCfg::new(); - config - .host(host) - .port(port) - .dbname(self.creds.dbname) - .user(self.creds.user); + let mut config = compute::ConnCfg::new(); + config + .host(host) + .port(port) + .dbname(self.creds.dbname) + .user(self.creds.user); - Ok(config) + Ok(config) + } + .map_err(crate::error::log_error) + .instrument(info_span!("wake_compute", id = request_id)) + .await } } @@ -215,24 +286,40 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>( wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute, ) -> auth::Result> where - GetAuthInfo: Future>, + Endpoint: AsRef>, + GetAuthInfo: Future, GetAuthInfoError>>, WakeCompute: Future>, { + let creds = endpoint.as_ref(); + info!("fetching user's authentication info"); - let auth_info = get_auth_info(endpoint).await?; + let info = get_auth_info(endpoint).await?.unwrap_or_else(|| { + // If we don't have an authentication secret, we mock one to + // prevent malicious probing (possible due to missing protocol steps). + // This mocked secret will never lead to successful authentication. + info!("authentication info not found, mocking it"); + AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random())) + }); let flow = AuthFlow::new(client); - let scram_keys = match auth_info { + let scram_keys = match info { AuthInfo::Md5(_) => { - // TODO: decide if we should support MD5 in api v2 info!("auth endpoint chooses MD5"); return Err(auth::AuthError::bad_auth_method("MD5")); } AuthInfo::Scram(secret) => { info!("auth endpoint chooses SCRAM"); let scram = auth::Scram(&secret); + let client_key = match flow.begin(scram).await?.authenticate().await? { + sasl::Outcome::Success(key) => key, + sasl::Outcome::Failure(reason) => { + info!("auth backend failed with an error: {reason}"); + return Err(auth::AuthError::auth_failed(creds.user)); + } + }; + Some(compute::ScramKeys { - client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(), + client_key: client_key.as_bytes(), server_key: secret.server_key.as_bytes(), }) } @@ -249,6 +336,31 @@ where }) } +/// Parse http response body, taking status code into account. +async fn parse_body Deserialize<'a>>( + response: reqwest::Response, +) -> Result { + let status = response.status(); + if status.is_success() { + // We shouldn't log raw body because it may contain secrets. + info!("request succeeded, processing the body"); + return Ok(response.json().await?); + } + + // Don't throw an error here because it's not as important + // as the fact that the request itself has failed. + let body = response.json().await.unwrap_or_else(|e| { + warn!("failed to parse error body: {e}"); + ConsoleError { + error: "reason unclear (malformed error message)".into(), + } + }); + + let text = body.error; + error!("console responded with an error ({status}): {text}"); + Err(ApiError::Console { status, text }) +} + fn parse_host_port(input: &str) -> Option<(&str, u16)> { let (host, port) = input.split_once(':')?; Some((host, port.parse().ok()?)) diff --git a/proxy/src/auth/backend/postgres.rs b/proxy/src/auth/backend/postgres.rs index e56b62622a..8f16dc9fa8 100644 --- a/proxy/src/auth/backend/postgres.rs +++ b/proxy/src/auth/backend/postgres.rs @@ -1,7 +1,7 @@ //! Local mock of Cloud API V2. use super::{ - console::{self, AuthInfo, GetAuthInfoError, TransportError, WakeComputeError}, + console::{self, AuthInfo, GetAuthInfoError, WakeComputeError}, AuthSuccess, }; use crate::{ @@ -12,7 +12,28 @@ use crate::{ stream::PqStream, url::ApiUrl, }; +use futures::TryFutureExt; +use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{info, info_span, warn, Instrument}; + +#[derive(Debug, Error)] +enum MockApiError { + #[error("Failed to read password: {0}")] + PasswordNotSet(tokio_postgres::Error), +} + +impl From for console::ApiError { + fn from(e: MockApiError) -> Self { + io_error(e).into() + } +} + +impl From for console::ApiError { + fn from(e: tokio_postgres::Error) -> Self { + io_error(e).into() + } +} #[must_use] pub(super) struct Api<'a> { @@ -20,10 +41,9 @@ pub(super) struct Api<'a> { creds: &'a ClientCredentials<'a>, } -// Helps eliminate graceless `.map_err` calls without introducing another ctor. -impl From for TransportError { - fn from(e: tokio_postgres::Error) -> Self { - io_error(e).into() +impl<'a> AsRef> for Api<'a> { + fn as_ref(&self) -> &ClientCredentials<'a> { + self.creds } } @@ -35,54 +55,55 @@ impl<'a> Api<'a> { /// Authenticate the existing user or throw an error. pub(super) async fn handle_user( - self, + &'a self, client: &mut PqStream, ) -> auth::Result> { // We reuse user handling logic from a production module. - console::handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await + console::handle_user(client, self, Self::get_auth_info, Self::wake_compute).await } +} +impl Api<'_> { /// This implementation fetches the auth info from a local postgres instance. - async fn get_auth_info(&self) -> Result { - // Perhaps we could persist this connection, but then we'd have to - // write more code for reopening it if it got closed, which doesn't - // seem worth it. - let (client, connection) = - tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; + async fn get_auth_info(&self) -> Result, GetAuthInfoError> { + async { + // Perhaps we could persist this connection, but then we'd have to + // write more code for reopening it if it got closed, which doesn't + // seem worth it. + let (client, connection) = + tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; - tokio::spawn(connection); - let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1"; - let rows = client.query(query, &[&self.creds.user]).await?; + tokio::spawn(connection); + let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1"; + let rows = client.query(query, &[&self.creds.user]).await?; - match &rows[..] { - // We can't get a secret if there's no such user. - [] => Err(io_error(format!("unknown user '{}'", self.creds.user)).into()), + // We can get at most one row, because `rolname` is unique. + let row = match rows.get(0) { + Some(row) => row, + // This means that the user doesn't exist, so there can be no secret. + // However, this is still a *valid* outcome which is very similar + // to getting `404 Not found` from the Neon console. + None => { + warn!("user '{}' does not exist", self.creds.user); + return Ok(None); + } + }; - // We shouldn't get more than one row anyway. - [row, ..] => { - let entry = row - .try_get("rolpassword") - .map_err(|e| io_error(format!("failed to read user's password: {e}")))?; + let entry = row + .try_get("rolpassword") + .map_err(MockApiError::PasswordNotSet)?; - scram::ServerSecret::parse(entry) - .map(AuthInfo::Scram) - .or_else(|| { - // It could be an md5 hash if it's not a SCRAM secret. - let text = entry.strip_prefix("md5")?; - Some(AuthInfo::Md5({ - let mut bytes = [0u8; 16]; - hex::decode_to_slice(text, &mut bytes).ok()?; - bytes - })) - }) - // Putting the secret into this message is a security hazard! - .ok_or(GetAuthInfoError::BadSecret) - } + info!("got a secret: {entry}"); // safe since it's not a prod scenario + let secret = scram::ServerSecret::parse(entry).map(AuthInfo::Scram); + Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5))) } + .map_err(crate::error::log_error) + .instrument(info_span!("get_auth_info", mock = self.endpoint.as_str())) + .await } /// We don't need to wake anything locally, so we just return the connection info. - pub(super) async fn wake_compute(&self) -> Result { + pub async fn wake_compute(&self) -> Result { let mut config = compute::ConnCfg::new(); config .host(self.endpoint.host_str().unwrap_or("localhost")) @@ -93,3 +114,12 @@ impl<'a> Api<'a> { Ok(config) } } + +fn parse_md5(input: &str) -> Option<[u8; 16]> { + let text = input.strip_prefix("md5")?; + + let mut bytes = [0u8; 16]; + hex::decode_to_slice(text, &mut bytes).ok()?; + + Some(bytes) +} diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 865af4d2e5..d9ee50894d 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -89,7 +89,7 @@ impl AuthFlow<'_, S, PasswordHack> { /// Stream wrapper for handling [SCRAM](crate::scram) auth. impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn authenticate(self) -> super::Result { + pub async fn authenticate(self) -> super::Result> { // Initial client message contains the chosen auth method's name. let msg = self.stream.read_password_message().await?; let sasl = sasl::FirstMessage::parse(&msg) @@ -101,10 +101,10 @@ impl AuthFlow<'_, S, Scram<'_>> { } let secret = self.state.0; - let key = sasl::SaslStream::new(self.stream, sasl.message) + let outcome = sasl::SaslStream::new(self.stream, sasl.message) .authenticate(scram::Exchange::new(secret, rand::random, None)) .await?; - Ok(key) + Ok(outcome) } } diff --git a/proxy/src/error.rs b/proxy/src/error.rs index 0e376a37cd..f1cb44b1a8 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -1,4 +1,15 @@ -use std::io; +use std::{error::Error as StdError, fmt, io}; + +/// Upcast (almost) any error into an opaque [`io::Error`]. +pub fn io_error(e: impl Into>) -> io::Error { + io::Error::new(io::ErrorKind::Other, e) +} + +/// A small combinator for pluggable error logging. +pub fn log_error(e: E) -> E { + tracing::error!("{e}"); + e +} /// Marks errors that may be safely shown to a client. /// This trait can be seen as a specialized version of [`ToString`]. @@ -6,7 +17,7 @@ use std::io; /// NOTE: This trait should not be implemented for [`anyhow::Error`], since it /// is way too convenient and tends to proliferate all across the codebase, /// ultimately leading to accidental leaks of sensitive data. -pub trait UserFacingError: ToString { +pub trait UserFacingError: fmt::Display { /// Format the error for client, stripping all sensitive info. /// /// Although this might be a no-op for many types, it's highly @@ -17,8 +28,3 @@ pub trait UserFacingError: ToString { self.to_string() } } - -/// Upcast (almost) any error into an opaque [`io::Error`]. -pub fn io_error(e: impl Into>) -> io::Error { - io::Error::new(io::ErrorKind::Other, e) -} diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 6f9145678b..096a33d73d 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -37,16 +37,6 @@ impl Endpoint { ) -> Result { self.client.execute(request).await } - - /// Execute a [request](reqwest::Request) and raise an error if status != 200. - pub async fn checked_execute( - &self, - request: reqwest::Request, - ) -> Result { - self.execute(request) - .await - .and_then(|r| r.error_for_status()) - } } #[cfg(test)] diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 411893fee5..da3cb144e3 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -49,17 +49,6 @@ static NUM_BYTES_PROXIED_COUNTER: Lazy = Lazy::new(|| { .unwrap() }); -/// A small combinator for pluggable error logging. -async fn log_error(future: F) -> F::Output -where - F: std::future::Future>, -{ - future.await.map_err(|err| { - error!("{err}"); - err - }) -} - pub async fn task_main( config: &'static ProxyConfig, listener: tokio::net::TcpListener, @@ -80,7 +69,7 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancel_map = Arc::clone(&cancel_map); tokio::spawn( - log_error(async move { + async move { info!("spawned a task for {peer_addr}"); socket @@ -88,6 +77,10 @@ pub async fn task_main( .context("failed to set socket option")?; handle_client(config, &cancel_map, session_id, socket).await + } + .unwrap_or_else(|e| { + // Acknowledge that the task has finished with an error. + error!("per-client task finished with an error: {e:#}"); }) .instrument(info_span!("client", session = format_args!("{session_id}"))), ); diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 3d74dbae5a..24fbc57b99 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -1,6 +1,6 @@ ///! A group of high-level tests for connection establishing logic and auth. use super::*; -use crate::{auth, scram}; +use crate::{auth, sasl, scram}; use async_trait::async_trait; use rstest::rstest; use tokio_postgres::config::SslMode; @@ -100,8 +100,7 @@ impl Scram { } fn mock(user: &str) -> Self { - let salt = rand::random::<[u8; 32]>(); - Scram(scram::ServerSecret::mock(user, &salt)) + Scram(scram::ServerSecret::mock(user, rand::random())) } } @@ -111,13 +110,17 @@ impl TestAuth for Scram { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - auth::AuthFlow::new(stream) + let outcome = auth::AuthFlow::new(stream) .begin(auth::Scram(&self.0)) .await? .authenticate() .await?; - Ok(()) + use sasl::Outcome::*; + match outcome { + Success(_) => Ok(()), + Failure(reason) => bail!("autentication failed with an error: {reason}"), + } } } diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs index 689fca6049..6d1dd9fba5 100644 --- a/proxy/src/sasl.rs +++ b/proxy/src/sasl.rs @@ -16,22 +16,19 @@ use thiserror::Error; pub use channel_binding::ChannelBinding; pub use messages::FirstMessage; -pub use stream::SaslStream; +pub use stream::{Outcome, SaslStream}; /// Fine-grained auth errors help in writing tests. #[derive(Error, Debug)] pub enum Error { - #[error("Failed to authenticate client: {0}")] - AuthenticationFailed(&'static str), - #[error("Channel binding failed: {0}")] ChannelBindingFailed(&'static str), #[error("Unsupported channel binding method: {0}")] ChannelBindingBadMethod(Box), - #[error("Bad client message")] - BadClientMessage, + #[error("Bad client message: {0}")] + BadClientMessage(&'static str), #[error(transparent)] Io(#[from] io::Error), @@ -41,8 +38,6 @@ impl UserFacingError for Error { fn to_string_client(&self) -> String { use Error::*; match self { - // This constructor contains the reason why auth has failed. - AuthenticationFailed(s) => s.to_string(), // TODO: add support for channel binding ChannelBindingFailed(_) => "channel binding is not supported yet".to_string(), ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"), @@ -55,11 +50,14 @@ impl UserFacingError for Error { pub type Result = std::result::Result; /// A result of one SASL exchange. +#[must_use] pub enum Step { /// We should continue exchanging messages. - Continue(T), + Continue(T, String), /// The client has been authenticated successfully. - Authenticated(R), + Success(R, String), + /// Authentication failed (reason attached). + Failure(&'static str), } /// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait. @@ -69,5 +67,5 @@ pub trait Mechanism: Sized { /// Produce a server challenge to be sent to the client. /// This is how this method is called in PostgreSQL (`libpq/sasl.h`). - fn exchange(self, input: &str) -> Result<(Step, String)>; + fn exchange(self, input: &str) -> Result>; } diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index 0e782c5f29..b24cc4bf44 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -48,28 +48,41 @@ impl SaslStream<'_, S> { } } +/// SASL authentication outcome. +/// It's much easier to match on those two variants +/// than to peek into a noisy protocol error type. +#[must_use = "caller must explicitly check for success"] +pub enum Outcome { + /// Authentication succeeded and produced some value. + Success(R), + /// Authentication failed (reason attached). + Failure(&'static str), +} + impl SaslStream<'_, S> { /// Perform SASL message exchange according to the underlying algorithm /// until user is either authenticated or denied access. pub async fn authenticate( mut self, mut mechanism: M, - ) -> super::Result { + ) -> super::Result> { loop { let input = self.recv().await?; - let (moved, reply) = mechanism.exchange(input)?; + let step = mechanism.exchange(input)?; - use super::Step::*; - match moved { - Continue(moved) => { + use super::Step; + return Ok(match step { + Step::Continue(moved_mechanism, reply) => { self.send(&ServerMessage::Continue(&reply)).await?; - mechanism = moved; + mechanism = moved_mechanism; + continue; } - Authenticated(result) => { + Step::Success(result, reply) => { self.send(&ServerMessage::Final(&reply)).await?; - return Ok(result); + Outcome::Success(result) } - } + Step::Failure(reason) => Outcome::Failure(reason), + }); } } } diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index fca5585b25..882769a70d 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -64,12 +64,12 @@ impl<'a> Exchange<'a> { impl sasl::Mechanism for Exchange<'_> { type Output = super::ScramKey; - fn exchange(mut self, input: &str) -> sasl::Result<(sasl::Step, String)> { + fn exchange(mut self, input: &str) -> sasl::Result> { use {sasl::Step::*, ExchangeState::*}; match &self.state { Initial => { - let client_first_message = - ClientFirstMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + let client_first_message = ClientFirstMessage::parse(input) + .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?; let server_first_message = client_first_message.build_server_first_message( &(self.nonce)(), @@ -84,15 +84,15 @@ impl sasl::Mechanism for Exchange<'_> { server_first_message, }; - Ok((Continue(self), msg)) + Ok(Continue(self, msg)) } SaltSent { cbind_flag, client_first_message_bare, server_first_message, } => { - let client_final_message = - ClientFinalMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + let client_final_message = ClientFinalMessage::parse(input) + .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?; let channel_binding = cbind_flag.encode(|_| { self.cert_digest @@ -106,9 +106,7 @@ impl sasl::Mechanism for Exchange<'_> { } if client_final_message.nonce != server_first_message.nonce() { - return Err(SaslError::AuthenticationFailed( - "combined nonce doesn't match", - )); + return Err(SaslError::BadClientMessage("combined nonce doesn't match")); } let signature_builder = SignatureBuilder { @@ -121,14 +119,15 @@ impl sasl::Mechanism for Exchange<'_> { .build(&self.secret.stored_key) .derive_client_key(&client_final_message.proof); - if client_key.sha256() != self.secret.stored_key { - return Err(SaslError::AuthenticationFailed("password doesn't match")); + // Auth fails either if keys don't match or it's pre-determined to fail. + if client_key.sha256() != self.secret.stored_key || self.secret.doomed { + return Ok(Failure("password doesn't match")); } let msg = client_final_message .build_server_final_message(signature_builder, &self.secret.server_key); - Ok((Authenticated(client_key), msg)) + Ok(Success(client_key, msg)) } } } diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index 765aef4443..89668465fa 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -14,6 +14,9 @@ pub struct ServerSecret { pub stored_key: ScramKey, /// Used by client to verify server's signature. pub server_key: ScramKey, + /// Should auth fail no matter what? + /// This is exactly the case for mocked secrets. + pub doomed: bool, } impl ServerSecret { @@ -30,6 +33,7 @@ impl ServerSecret { salt_base64: salt.to_owned(), stored_key: base64_decode_array(stored_key)?.into(), server_key: base64_decode_array(server_key)?.into(), + doomed: false, }; Some(secret) @@ -38,16 +42,16 @@ impl ServerSecret { /// To avoid revealing information to an attacker, we use a /// mocked server secret even if the user doesn't exist. /// See `auth-scram.c : mock_scram_secret` for details. - #[allow(dead_code)] - pub fn mock(user: &str, nonce: &[u8; 32]) -> Self { + pub fn mock(user: &str, nonce: [u8; 32]) -> Self { // Refer to `auth-scram.c : scram_mock_salt`. - let mocked_salt = super::sha256([user.as_bytes(), nonce]); + let mocked_salt = super::sha256([user.as_bytes(), &nonce]); Self { iterations: 4096, salt_base64: base64::encode(&mocked_salt), stored_key: ScramKey::default(), server_key: ScramKey::default(), + doomed: true, } } @@ -67,6 +71,7 @@ impl ServerSecret { salt_base64: base64::encode(&salt), stored_key: password.client_key().sha256(), server_key: password.server_key(), + doomed: false, }) } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 8e4084775c..19e1479068 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -109,8 +109,9 @@ impl PqStream { /// Write the error message using [`Self::write_message`], then re-throw it. /// Allowing string literals is safe under the assumption they might not contain any runtime info. + /// This method exists due to `&str` not implementing `Into`. pub async fn throw_error_str(&mut self, error: &'static str) -> anyhow::Result { - // This method exists due to `&str` not implementing `Into` + tracing::info!("forwarding error to user: {error}"); self.write_message(&BeMessage::ErrorResponse(error)).await?; bail!(error) } @@ -122,6 +123,7 @@ impl PqStream { E: UserFacingError + Into, { let msg = error.to_string_client(); + tracing::info!("forwarding error to user: {msg}"); self.write_message(&BeMessage::ErrorResponse(&msg)).await?; bail!(error) } diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 0d64ca6d65..e3f8247274 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2092,62 +2092,73 @@ class PSQL: class NeonProxy(PgProtocol): + link_auth_uri: str = "http://dummy-uri" + + class AuthBackend(abc.ABC): + """All auth backends must inherit from this class""" + + @property + def default_conn_url(self) -> Optional[str]: + return None + + @abc.abstractmethod + def extra_args(self) -> list[str]: + pass + + class Link(AuthBackend): + def extra_args(self) -> list[str]: + return [ + # Link auth backend params + *["--auth-backend", "link"], + *["--uri", NeonProxy.link_auth_uri], + ] + + @dataclass(frozen=True) + class Postgres(AuthBackend): + pg_conn_url: str + + @property + def default_conn_url(self) -> Optional[str]: + return self.pg_conn_url + + def extra_args(self) -> list[str]: + return [ + # Postgres auth backend params + *["--auth-backend", "postgres"], + *["--auth-endpoint", self.pg_conn_url], + ] + def __init__( self, + neon_binpath: Path, proxy_port: int, http_port: int, mgmt_port: int, - neon_binpath: Path, - auth_endpoint=None, + auth_backend: NeonProxy.AuthBackend, ): - super().__init__(dsn=auth_endpoint, port=proxy_port) - self.host = "127.0.0.1" + host = "127.0.0.1" + super().__init__(dsn=auth_backend.default_conn_url, host=host, port=proxy_port) + + self.host = host self.http_port = http_port self.neon_binpath = neon_binpath self.proxy_port = proxy_port self.mgmt_port = mgmt_port - self.auth_endpoint = auth_endpoint + self.auth_backend = auth_backend self._popen: Optional[subprocess.Popen[bytes]] = None - self.link_auth_uri_prefix = "http://dummy-uri" - def start(self): - """ - Starts a proxy with option '--auth-backend postgres' and a postgres instance - already provided though '--auth-endpoint '." - """ + def start(self) -> NeonProxy: assert self._popen is None - assert self.auth_endpoint is not None - - # Start proxy args = [ str(self.neon_binpath / "proxy"), *["--http", f"{self.host}:{self.http_port}"], *["--proxy", f"{self.host}:{self.proxy_port}"], *["--mgmt", f"{self.host}:{self.mgmt_port}"], - *["--auth-backend", "postgres"], - *["--auth-endpoint", self.auth_endpoint], + *self.auth_backend.extra_args(), ] self._popen = subprocess.Popen(args) self._wait_until_ready() - - def start_with_link_auth(self): - """ - Starts a proxy with option '--auth-backend link' and a dummy authentication link '--uri dummy-auth-link'." - """ - assert self._popen is None - - # Start proxy - bin_proxy = str(self.neon_binpath / "proxy") - args = [bin_proxy] - args.extend(["--http", f"{self.host}:{self.http_port}"]) - args.extend(["--proxy", f"{self.host}:{self.proxy_port}"]) - args.extend(["--mgmt", f"{self.host}:{self.mgmt_port}"]) - args.extend(["--auth-backend", "link"]) - args.extend(["--uri", self.link_auth_uri_prefix]) - arg_str = " ".join(args) - log.info(f"starting proxy with command line ::: {arg_str}") - self._popen = subprocess.Popen(args, stdout=subprocess.PIPE) - self._wait_until_ready() + return self @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) def _wait_until_ready(self): @@ -2158,7 +2169,7 @@ class NeonProxy(PgProtocol): request_result.raise_for_status() return request_result.text - def __enter__(self) -> "NeonProxy": + def __enter__(self) -> NeonProxy: return self def __exit__( @@ -2176,11 +2187,19 @@ class NeonProxy(PgProtocol): @pytest.fixture(scope="function") def link_proxy(port_distributor: PortDistributor, neon_binpath: Path) -> Iterator[NeonProxy]: """Neon proxy that routes through link auth.""" + http_port = port_distributor.get_port() proxy_port = port_distributor.get_port() mgmt_port = port_distributor.get_port() - with NeonProxy(proxy_port, http_port, neon_binpath=neon_binpath, mgmt_port=mgmt_port) as proxy: - proxy.start_with_link_auth() + + with NeonProxy( + neon_binpath=neon_binpath, + proxy_port=proxy_port, + http_port=http_port, + mgmt_port=mgmt_port, + auth_backend=NeonProxy.Link(), + ) as proxy: + proxy.start() yield proxy @@ -2204,11 +2223,11 @@ def static_proxy( http_port = port_distributor.get_port() with NeonProxy( + neon_binpath=neon_binpath, proxy_port=proxy_port, http_port=http_port, mgmt_port=mgmt_port, - neon_binpath=neon_binpath, - auth_endpoint=auth_endpoint, + auth_backend=NeonProxy.Postgres(auth_endpoint), ) as proxy: proxy.start() yield proxy diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index e868d6b616..eab9505fbb 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -28,61 +28,58 @@ def test_password_hack(static_proxy: NeonProxy): static_proxy.safe_psql("select 1", sslsni=0, user=user, password=magic) -def get_session_id(uri_prefix, uri_line): - assert uri_prefix in uri_line - - url_parts = urlparse(uri_line) - psql_session_id = url_parts.path[1:] - assert psql_session_id.isalnum(), "session_id should only contain alphanumeric chars" - - return psql_session_id - - -async def find_auth_link(link_auth_uri_prefix, proc): - for _ in range(100): - line = (await proc.stderr.readline()).decode("utf-8").strip() - log.info(f"psql line: {line}") - if link_auth_uri_prefix in line: - log.info(f"SUCCESS, found auth url: {line}") - return line - - -async def activate_link_auth(local_vanilla_pg, link_proxy, psql_session_id): - pg_user = "proxy" - - log.info("creating a new user for link auth test") - local_vanilla_pg.start() - local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser") - - db_info = json.dumps( - { - "session_id": psql_session_id, - "result": { - "Success": { - "host": local_vanilla_pg.default_options["host"], - "port": local_vanilla_pg.default_options["port"], - "dbname": local_vanilla_pg.default_options["dbname"], - "user": pg_user, - "project": "irrelevant", - } - }, - } - ) - - log.info("sending session activation message") - psql = await PSQL(host=link_proxy.host, port=link_proxy.mgmt_port).run(db_info) - out = (await psql.stdout.read()).decode("utf-8").strip() - assert out == "ok" - - @pytest.mark.asyncio async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy): + def get_session_id(uri_prefix, uri_line): + assert uri_prefix in uri_line + + url_parts = urlparse(uri_line) + psql_session_id = url_parts.path[1:] + assert psql_session_id.isalnum(), "session_id should only contain alphanumeric chars" + + return psql_session_id + + async def find_auth_link(link_auth_uri, proc): + for _ in range(100): + line = (await proc.stderr.readline()).decode("utf-8").strip() + log.info(f"psql line: {line}") + if link_auth_uri in line: + log.info(f"SUCCESS, found auth url: {line}") + return line + + async def activate_link_auth(local_vanilla_pg, link_proxy, psql_session_id): + pg_user = "proxy" + + log.info("creating a new user for link auth test") + local_vanilla_pg.start() + local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser") + + db_info = json.dumps( + { + "session_id": psql_session_id, + "result": { + "Success": { + "host": local_vanilla_pg.default_options["host"], + "port": local_vanilla_pg.default_options["port"], + "dbname": local_vanilla_pg.default_options["dbname"], + "user": pg_user, + "project": "irrelevant", + } + }, + } + ) + + log.info("sending session activation message") + psql = await PSQL(host=link_proxy.host, port=link_proxy.mgmt_port).run(db_info) + out = (await psql.stdout.read()).decode("utf-8").strip() + assert out == "ok" + psql = await PSQL(host=link_proxy.host, port=link_proxy.proxy_port).run("select 42") - uri_prefix = link_proxy.link_auth_uri_prefix - link = await find_auth_link(uri_prefix, psql) + base_uri = link_proxy.link_auth_uri + link = await find_auth_link(base_uri, psql) - psql_session_id = get_session_id(uri_prefix, link) + psql_session_id = get_session_id(base_uri, link) await activate_link_auth(vanilla_pg, link_proxy, psql_session_id) assert psql.stdout is not None @@ -97,3 +94,31 @@ def test_proxy_options(static_proxy: NeonProxy): cur.execute("SHOW proxytest.option") value = cur.fetchall()[0][0] assert value == "value" + + +def test_auth_errors(static_proxy: NeonProxy): + # User does not exist + with pytest.raises(psycopg2.Error) as exprinfo: + static_proxy.connect(user="pinocchio", options="project=irrelevant") + text = str(exprinfo.value).strip() + assert text.endswith("password authentication failed for user 'pinocchio'") + + static_proxy.safe_psql( + "create role pinocchio with login password 'magic'", options="project=irrelevant" + ) + + # User exists, but password is missing + with pytest.raises(psycopg2.Error) as exprinfo: + static_proxy.connect(user="pinocchio", password=None, options="project=irrelevant") + text = str(exprinfo.value).strip() + assert text.endswith("password authentication failed for user 'pinocchio'") + + # User exists, but password is wrong + with pytest.raises(psycopg2.Error) as exprinfo: + static_proxy.connect(user="pinocchio", password="bad", options="project=irrelevant") + text = str(exprinfo.value).strip() + assert text.endswith("password authentication failed for user 'pinocchio'") + + # Finally, check that the user can connect + with static_proxy.connect(user="pinocchio", password="magic", options="project=irrelevant"): + pass