[proxy] Propagate more console API errors to the user

This patch aims to fix some of the inconsistencies in error reporting,
for example "Internal error" or "Console request failed" instead of
"password authentication failed for user '<NAME>'".
This commit is contained in:
Dmitry Ivanov
2022-11-03 18:07:16 +03:00
parent e5d523c86a
commit 607c0facfc
15 changed files with 504 additions and 301 deletions

View File

@@ -49,6 +49,9 @@ pub enum AuthErrorImpl {
)]
MissingProjectName,
#[error("password authentication failed for user '{0}'")]
AuthFailed(Box<str>),
/// Errors produced by e.g. [`crate::stream::PqStream`].
#[error(transparent)]
Io(#[from] io::Error),
@@ -62,6 +65,10 @@ impl AuthError {
pub fn bad_auth_method(name: impl Into<Box<str>>) -> Self {
AuthErrorImpl::BadAuthMethod(name.into()).into()
}
pub fn auth_failed(user: impl Into<Box<str>>) -> Self {
AuthErrorImpl::AuthFailed(user.into()).into()
}
}
impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
@@ -78,10 +85,11 @@ impl UserFacingError for AuthError {
GetAuthInfo(e) => e.to_string_client(),
WakeCompute(e) => e.to_string_client(),
Sasl(e) => e.to_string_client(),
AuthFailed(_) => self.to_string(),
BadAuthMethod(_) => self.to_string(),
MalformedPassword(_) => self.to_string(),
MissingProjectName => self.to_string(),
_ => "Internal error".to_string(),
Io(_) => "Internal error".to_string(),
}
}
}

View File

@@ -5,26 +5,74 @@ use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
error::{io_error, UserFacingError},
http, scram,
http, sasl, scram,
stream::PqStream,
};
use futures::TryFutureExt;
use serde::{Deserialize, Serialize};
use reqwest::StatusCode as HttpStatusCode;
use serde::Deserialize;
use std::future::Future;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, info_span};
use tracing::{error, info, info_span, warn, Instrument};
/// A go-to error message which doesn't leak any detail.
const REQUEST_FAILED: &str = "Console request failed";
/// Common console API error.
#[derive(Debug, Error)]
#[error("{}", REQUEST_FAILED)]
pub struct TransportError(#[from] std::io::Error);
pub enum ApiError {
/// Error returned by the console itself.
#[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
Console {
status: HttpStatusCode,
text: Box<str>,
},
impl UserFacingError for TransportError {}
/// Various IO errors like broken pipe or malformed payload.
#[error("{REQUEST_FAILED}: {0}")]
Transport(#[from] std::io::Error),
}
impl ApiError {
/// Returns HTTP status code if it's the reason for failure.
fn http_status_code(&self) -> Option<HttpStatusCode> {
use ApiError::*;
match self {
Console { status, .. } => Some(*status),
_ => None,
}
}
}
impl UserFacingError for ApiError {
fn to_string_client(&self) -> String {
use ApiError::*;
match self {
// To minimize risks, only select errors are forwarded to users.
// Ask @neondatabase/control-plane for review before adding more.
Console { status, .. } => match *status {
HttpStatusCode::NOT_FOUND => {
// Status 404: failed to get a project-related resource.
format!("{REQUEST_FAILED}: endpoint cannot be found")
}
HttpStatusCode::NOT_ACCEPTABLE => {
// Status 406: endpoint is disabled (we don't allow connections).
format!("{REQUEST_FAILED}: endpoint is disabled")
}
HttpStatusCode::LOCKED => {
// Status 423: project might be in maintenance mode (or bad state).
format!("{REQUEST_FAILED}: endpoint is temporary unavailable")
}
_ => REQUEST_FAILED.to_owned(),
},
_ => REQUEST_FAILED.to_owned(),
}
}
}
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
impl From<reqwest::Error> for TransportError {
impl From<reqwest::Error> for ApiError {
fn from(e: reqwest::Error) -> Self {
io_error(e).into()
}
@@ -37,61 +85,73 @@ pub enum GetAuthInfoError {
BadSecret,
#[error(transparent)]
Transport(TransportError),
ApiError(ApiError),
}
// This allows more useful interactions than `#[from]`.
impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
fn from(e: E) -> Self {
Self::ApiError(e.into())
}
}
impl UserFacingError for GetAuthInfoError {
fn to_string_client(&self) -> String {
use GetAuthInfoError::*;
match self {
// We absolutely should not leak any secrets!
BadSecret => REQUEST_FAILED.to_owned(),
Transport(e) => e.to_string_client(),
// However, API might return a meaningful error.
ApiError(e) => e.to_string_client(),
}
}
}
impl<E: Into<TransportError>> From<E> for GetAuthInfoError {
fn from(e: E) -> Self {
Self::Transport(e.into())
}
}
#[derive(Debug, Error)]
pub enum WakeComputeError {
// We shouldn't show users the address even if it's broken.
#[error("Console responded with a malformed compute address: {0}")]
BadComputeAddress(String),
BadComputeAddress(Box<str>),
#[error(transparent)]
Transport(TransportError),
ApiError(ApiError),
}
// This allows more useful interactions than `#[from]`.
impl<E: Into<ApiError>> From<E> for WakeComputeError {
fn from(e: E) -> Self {
Self::ApiError(e.into())
}
}
impl UserFacingError for WakeComputeError {
fn to_string_client(&self) -> String {
use WakeComputeError::*;
match self {
// We shouldn't show user the address even if it's broken.
// Besides, user is unlikely to care about this detail.
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
Transport(e) => e.to_string_client(),
// However, API might return a meaningful error.
ApiError(e) => e.to_string_client(),
}
}
}
impl<E: Into<TransportError>> From<E> for WakeComputeError {
fn from(e: E) -> Self {
Self::Transport(e.into())
}
/// Console's response which holds client's auth secret.
#[derive(Deserialize, Debug)]
struct GetRoleSecret {
role_secret: Box<str>,
}
// TODO: convert into an enum with "error"
#[derive(Serialize, Deserialize, Debug)]
struct GetRoleSecretResponse {
role_secret: String,
/// Console's response which holds compute node's `host:port` pair.
#[derive(Deserialize, Debug)]
struct WakeCompute {
address: Box<str>,
}
// TODO: convert into an enum with "error"
#[derive(Serialize, Deserialize, Debug)]
struct GetWakeComputeResponse {
address: String,
/// Console's error response with human-readable description.
#[derive(Deserialize, Debug)]
struct ConsoleError {
error: Box<str>,
}
/// Auth secret which is managed by the cloud.
@@ -110,6 +170,12 @@ pub(super) struct Api<'a> {
creds: &'a ClientCredentials<'a>,
}
impl<'a> AsRef<ClientCredentials<'a>> for Api<'a> {
fn as_ref(&self) -> &ClientCredentials<'a> {
self.creds
}
}
impl<'a> Api<'a> {
/// Construct an API object containing the auth parameters.
pub(super) fn new(
@@ -126,83 +192,88 @@ impl<'a> Api<'a> {
/// Authenticate the existing user or throw an error.
pub(super) async fn handle_user(
self,
&'a self,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<AuthSuccess<compute::ConnCfg>> {
handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await
handle_user(client, self, Self::get_auth_info, Self::wake_compute).await
}
}
async fn get_auth_info(&self) -> Result<AuthInfo, GetAuthInfoError> {
impl Api<'_> {
async fn get_auth_info(&self) -> Result<Option<AuthInfo>, GetAuthInfoError> {
let request_id = uuid::Uuid::new_v4().to_string();
let req = self
.endpoint
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
.query(&[("session_id", self.extra.session_id)])
.query(&[
("application_name", self.extra.application_name),
("project", Some(self.creds.project().expect("impossible"))),
("role", Some(self.creds.user)),
])
.build()?;
async {
let request = self
.endpoint
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
.query(&[("session_id", self.extra.session_id)])
.query(&[
("application_name", self.extra.application_name),
("project", Some(self.creds.project().expect("impossible"))),
("role", Some(self.creds.user)),
])
.build()?;
let span = info_span!("http", id = request_id, url = req.url().as_str());
info!(parent: &span, "request auth info");
let msg = self
.endpoint
.checked_execute(req)
.and_then(|r| r.json::<GetRoleSecretResponse>())
.await
.map_err(|e| {
error!(parent: &span, "{e}");
e
})?;
info!(url = request.url().as_str(), "sending http request");
let response = self.endpoint.execute(request).await?;
let body = match parse_body::<GetRoleSecret>(response).await {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
Err(e) => match e.http_status_code() {
Some(HttpStatusCode::NOT_FOUND) => return Ok(None),
_otherwise => return Err(e.into()),
},
};
scram::ServerSecret::parse(&msg.role_secret)
.map(AuthInfo::Scram)
.ok_or(GetAuthInfoError::BadSecret)
let secret = scram::ServerSecret::parse(&body.role_secret)
.map(AuthInfo::Scram)
.ok_or(GetAuthInfoError::BadSecret)?;
Ok(Some(secret))
}
.map_err(crate::error::log_error)
.instrument(info_span!("get_auth_info", id = request_id))
.await
}
/// Wake up the compute node and return the corresponding connection info.
pub(super) async fn wake_compute(&self) -> Result<compute::ConnCfg, WakeComputeError> {
pub async fn wake_compute(&self) -> Result<compute::ConnCfg, WakeComputeError> {
let request_id = uuid::Uuid::new_v4().to_string();
let req = self
.endpoint
.get("proxy_wake_compute")
.header("X-Request-ID", &request_id)
.query(&[("session_id", self.extra.session_id)])
.query(&[
("application_name", self.extra.application_name),
("project", Some(self.creds.project().expect("impossible"))),
])
.build()?;
async {
let request = self
.endpoint
.get("proxy_wake_compute")
.header("X-Request-ID", &request_id)
.query(&[("session_id", self.extra.session_id)])
.query(&[
("application_name", self.extra.application_name),
("project", Some(self.creds.project().expect("impossible"))),
])
.build()?;
let span = info_span!("http", id = request_id, url = req.url().as_str());
info!(parent: &span, "request wake-up");
let msg = self
.endpoint
.checked_execute(req)
.and_then(|r| r.json::<GetWakeComputeResponse>())
.await
.map_err(|e| {
error!(parent: &span, "{e}");
e
})?;
info!(url = request.url().as_str(), "sending http request");
let response = self.endpoint.execute(request).await?;
let body = parse_body::<WakeCompute>(response).await?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&msg.address) {
None => return Err(WakeComputeError::BadComputeAddress(msg.address)),
Some(x) => x,
};
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
Some(x) => x,
};
let mut config = compute::ConnCfg::new();
config
.host(host)
.port(port)
.dbname(self.creds.dbname)
.user(self.creds.user);
let mut config = compute::ConnCfg::new();
config
.host(host)
.port(port)
.dbname(self.creds.dbname)
.user(self.creds.user);
Ok(config)
Ok(config)
}
.map_err(crate::error::log_error)
.instrument(info_span!("wake_compute", id = request_id))
.await
}
}
@@ -215,24 +286,40 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>(
wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute,
) -> auth::Result<AuthSuccess<compute::ConnCfg>>
where
GetAuthInfo: Future<Output = Result<AuthInfo, GetAuthInfoError>>,
Endpoint: AsRef<ClientCredentials<'a>>,
GetAuthInfo: Future<Output = Result<Option<AuthInfo>, GetAuthInfoError>>,
WakeCompute: Future<Output = Result<compute::ConnCfg, WakeComputeError>>,
{
let creds = endpoint.as_ref();
info!("fetching user's authentication info");
let auth_info = get_auth_info(endpoint).await?;
let info = get_auth_info(endpoint).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
});
let flow = AuthFlow::new(client);
let scram_keys = match auth_info {
let scram_keys = match info {
AuthInfo::Md5(_) => {
// TODO: decide if we should support MD5 in api v2
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthInfo::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret);
let client_key = match flow.begin(scram).await?.authenticate().await? {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(creds.user));
}
};
Some(compute::ScramKeys {
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
@@ -249,6 +336,31 @@ where
})
}
/// Parse http response body, taking status code into account.
async fn parse_body<T: for<'a> Deserialize<'a>>(
response: reqwest::Response,
) -> Result<T, ApiError> {
let status = response.status();
if status.is_success() {
// We shouldn't log raw body because it may contain secrets.
info!("request succeeded, processing the body");
return Ok(response.json().await?);
}
// Don't throw an error here because it's not as important
// as the fact that the request itself has failed.
let body = response.json().await.unwrap_or_else(|e| {
warn!("failed to parse error body: {e}");
ConsoleError {
error: "reason unclear (malformed error message)".into(),
}
});
let text = body.error;
error!("console responded with an error ({status}): {text}");
Err(ApiError::Console { status, text })
}
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
let (host, port) = input.split_once(':')?;
Some((host, port.parse().ok()?))

View File

@@ -1,7 +1,7 @@
//! Local mock of Cloud API V2.
use super::{
console::{self, AuthInfo, GetAuthInfoError, TransportError, WakeComputeError},
console::{self, AuthInfo, GetAuthInfoError, WakeComputeError},
AuthSuccess,
};
use crate::{
@@ -12,7 +12,28 @@ use crate::{
stream::PqStream,
url::ApiUrl,
};
use futures::TryFutureExt;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span, warn, Instrument};
#[derive(Debug, Error)]
enum MockApiError {
#[error("Failed to read password: {0}")]
PasswordNotSet(tokio_postgres::Error),
}
impl From<MockApiError> for console::ApiError {
fn from(e: MockApiError) -> Self {
io_error(e).into()
}
}
impl From<tokio_postgres::Error> for console::ApiError {
fn from(e: tokio_postgres::Error) -> Self {
io_error(e).into()
}
}
#[must_use]
pub(super) struct Api<'a> {
@@ -20,10 +41,9 @@ pub(super) struct Api<'a> {
creds: &'a ClientCredentials<'a>,
}
// Helps eliminate graceless `.map_err` calls without introducing another ctor.
impl From<tokio_postgres::Error> for TransportError {
fn from(e: tokio_postgres::Error) -> Self {
io_error(e).into()
impl<'a> AsRef<ClientCredentials<'a>> for Api<'a> {
fn as_ref(&self) -> &ClientCredentials<'a> {
self.creds
}
}
@@ -35,54 +55,55 @@ impl<'a> Api<'a> {
/// Authenticate the existing user or throw an error.
pub(super) async fn handle_user(
self,
&'a self,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<AuthSuccess<compute::ConnCfg>> {
// We reuse user handling logic from a production module.
console::handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await
console::handle_user(client, self, Self::get_auth_info, Self::wake_compute).await
}
}
impl Api<'_> {
/// This implementation fetches the auth info from a local postgres instance.
async fn get_auth_info(&self) -> Result<AuthInfo, GetAuthInfoError> {
// Perhaps we could persist this connection, but then we'd have to
// write more code for reopening it if it got closed, which doesn't
// seem worth it.
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
async fn get_auth_info(&self) -> Result<Option<AuthInfo>, GetAuthInfoError> {
async {
// Perhaps we could persist this connection, but then we'd have to
// write more code for reopening it if it got closed, which doesn't
// seem worth it.
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
tokio::spawn(connection);
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
let rows = client.query(query, &[&self.creds.user]).await?;
tokio::spawn(connection);
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
let rows = client.query(query, &[&self.creds.user]).await?;
match &rows[..] {
// We can't get a secret if there's no such user.
[] => Err(io_error(format!("unknown user '{}'", self.creds.user)).into()),
// We can get at most one row, because `rolname` is unique.
let row = match rows.get(0) {
Some(row) => row,
// This means that the user doesn't exist, so there can be no secret.
// However, this is still a *valid* outcome which is very similar
// to getting `404 Not found` from the Neon console.
None => {
warn!("user '{}' does not exist", self.creds.user);
return Ok(None);
}
};
// We shouldn't get more than one row anyway.
[row, ..] => {
let entry = row
.try_get("rolpassword")
.map_err(|e| io_error(format!("failed to read user's password: {e}")))?;
let entry = row
.try_get("rolpassword")
.map_err(MockApiError::PasswordNotSet)?;
scram::ServerSecret::parse(entry)
.map(AuthInfo::Scram)
.or_else(|| {
// It could be an md5 hash if it's not a SCRAM secret.
let text = entry.strip_prefix("md5")?;
Some(AuthInfo::Md5({
let mut bytes = [0u8; 16];
hex::decode_to_slice(text, &mut bytes).ok()?;
bytes
}))
})
// Putting the secret into this message is a security hazard!
.ok_or(GetAuthInfoError::BadSecret)
}
info!("got a secret: {entry}"); // safe since it's not a prod scenario
let secret = scram::ServerSecret::parse(entry).map(AuthInfo::Scram);
Ok(secret.or_else(|| parse_md5(entry).map(AuthInfo::Md5)))
}
.map_err(crate::error::log_error)
.instrument(info_span!("get_auth_info", mock = self.endpoint.as_str()))
.await
}
/// We don't need to wake anything locally, so we just return the connection info.
pub(super) async fn wake_compute(&self) -> Result<compute::ConnCfg, WakeComputeError> {
pub async fn wake_compute(&self) -> Result<compute::ConnCfg, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
@@ -93,3 +114,12 @@ impl<'a> Api<'a> {
Ok(config)
}
}
fn parse_md5(input: &str) -> Option<[u8; 16]> {
let text = input.strip_prefix("md5")?;
let mut bytes = [0u8; 16];
hex::decode_to_slice(text, &mut bytes).ok()?;
Some(bytes)
}

View File

@@ -89,7 +89,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<scram::ScramKey> {
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
@@ -101,10 +101,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
let secret = self.state.0;
let key = sasl::SaslStream::new(self.stream, sasl.message)
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(secret, rand::random, None))
.await?;
Ok(key)
Ok(outcome)
}
}

View File

@@ -1,4 +1,15 @@
use std::io;
use std::{error::Error as StdError, fmt, io};
/// Upcast (almost) any error into an opaque [`io::Error`].
pub fn io_error(e: impl Into<Box<dyn StdError + Send + Sync>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
}
/// A small combinator for pluggable error logging.
pub fn log_error<E: fmt::Display>(e: E) -> E {
tracing::error!("{e}");
e
}
/// Marks errors that may be safely shown to a client.
/// This trait can be seen as a specialized version of [`ToString`].
@@ -6,7 +17,7 @@ use std::io;
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
/// is way too convenient and tends to proliferate all across the codebase,
/// ultimately leading to accidental leaks of sensitive data.
pub trait UserFacingError: ToString {
pub trait UserFacingError: fmt::Display {
/// Format the error for client, stripping all sensitive info.
///
/// Although this might be a no-op for many types, it's highly
@@ -17,8 +28,3 @@ pub trait UserFacingError: ToString {
self.to_string()
}
}
/// Upcast (almost) any error into an opaque [`io::Error`].
pub fn io_error(e: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
}

View File

@@ -37,16 +37,6 @@ impl Endpoint {
) -> Result<reqwest::Response, reqwest::Error> {
self.client.execute(request).await
}
/// Execute a [request](reqwest::Request) and raise an error if status != 200.
pub async fn checked_execute(
&self,
request: reqwest::Request,
) -> Result<reqwest::Response, reqwest::Error> {
self.execute(request)
.await
.and_then(|r| r.error_for_status())
}
}
#[cfg(test)]

View File

@@ -49,17 +49,6 @@ static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
.unwrap()
});
/// A small combinator for pluggable error logging.
async fn log_error<R, F>(future: F) -> F::Output
where
F: std::future::Future<Output = anyhow::Result<R>>,
{
future.await.map_err(|err| {
error!("{err}");
err
})
}
pub async fn task_main(
config: &'static ProxyConfig,
listener: tokio::net::TcpListener,
@@ -80,7 +69,7 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancel_map = Arc::clone(&cancel_map);
tokio::spawn(
log_error(async move {
async move {
info!("spawned a task for {peer_addr}");
socket
@@ -88,6 +77,10 @@ pub async fn task_main(
.context("failed to set socket option")?;
handle_client(config, &cancel_map, session_id, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
})
.instrument(info_span!("client", session = format_args!("{session_id}"))),
);

View File

@@ -1,6 +1,6 @@
///! A group of high-level tests for connection establishing logic and auth.
use super::*;
use crate::{auth, scram};
use crate::{auth, sasl, scram};
use async_trait::async_trait;
use rstest::rstest;
use tokio_postgres::config::SslMode;
@@ -100,8 +100,7 @@ impl Scram {
}
fn mock(user: &str) -> Self {
let salt = rand::random::<[u8; 32]>();
Scram(scram::ServerSecret::mock(user, &salt))
Scram(scram::ServerSecret::mock(user, rand::random()))
}
}
@@ -111,13 +110,17 @@ impl TestAuth for Scram {
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
auth::AuthFlow::new(stream)
let outcome = auth::AuthFlow::new(stream)
.begin(auth::Scram(&self.0))
.await?
.authenticate()
.await?;
Ok(())
use sasl::Outcome::*;
match outcome {
Success(_) => Ok(()),
Failure(reason) => bail!("autentication failed with an error: {reason}"),
}
}
}

View File

@@ -16,22 +16,19 @@ use thiserror::Error;
pub use channel_binding::ChannelBinding;
pub use messages::FirstMessage;
pub use stream::SaslStream;
pub use stream::{Outcome, SaslStream};
/// Fine-grained auth errors help in writing tests.
#[derive(Error, Debug)]
pub enum Error {
#[error("Failed to authenticate client: {0}")]
AuthenticationFailed(&'static str),
#[error("Channel binding failed: {0}")]
ChannelBindingFailed(&'static str),
#[error("Unsupported channel binding method: {0}")]
ChannelBindingBadMethod(Box<str>),
#[error("Bad client message")]
BadClientMessage,
#[error("Bad client message: {0}")]
BadClientMessage(&'static str),
#[error(transparent)]
Io(#[from] io::Error),
@@ -41,8 +38,6 @@ impl UserFacingError for Error {
fn to_string_client(&self) -> String {
use Error::*;
match self {
// This constructor contains the reason why auth has failed.
AuthenticationFailed(s) => s.to_string(),
// TODO: add support for channel binding
ChannelBindingFailed(_) => "channel binding is not supported yet".to_string(),
ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"),
@@ -55,11 +50,14 @@ impl UserFacingError for Error {
pub type Result<T> = std::result::Result<T, Error>;
/// A result of one SASL exchange.
#[must_use]
pub enum Step<T, R> {
/// We should continue exchanging messages.
Continue(T),
Continue(T, String),
/// The client has been authenticated successfully.
Authenticated(R),
Success(R, String),
/// Authentication failed (reason attached).
Failure(&'static str),
}
/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait.
@@ -69,5 +67,5 @@ pub trait Mechanism: Sized {
/// Produce a server challenge to be sent to the client.
/// This is how this method is called in PostgreSQL (`libpq/sasl.h`).
fn exchange(self, input: &str) -> Result<(Step<Self, Self::Output>, String)>;
fn exchange(self, input: &str) -> Result<Step<Self, Self::Output>>;
}

View File

@@ -48,28 +48,41 @@ impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
}
}
/// SASL authentication outcome.
/// It's much easier to match on those two variants
/// than to peek into a noisy protocol error type.
#[must_use = "caller must explicitly check for success"]
pub enum Outcome<R> {
/// Authentication succeeded and produced some value.
Success(R),
/// Authentication failed (reason attached).
Failure(&'static str),
}
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
/// Perform SASL message exchange according to the underlying algorithm
/// until user is either authenticated or denied access.
pub async fn authenticate<M: Mechanism>(
mut self,
mut mechanism: M,
) -> super::Result<M::Output> {
) -> super::Result<Outcome<M::Output>> {
loop {
let input = self.recv().await?;
let (moved, reply) = mechanism.exchange(input)?;
let step = mechanism.exchange(input)?;
use super::Step::*;
match moved {
Continue(moved) => {
use super::Step;
return Ok(match step {
Step::Continue(moved_mechanism, reply) => {
self.send(&ServerMessage::Continue(&reply)).await?;
mechanism = moved;
mechanism = moved_mechanism;
continue;
}
Authenticated(result) => {
Step::Success(result, reply) => {
self.send(&ServerMessage::Final(&reply)).await?;
return Ok(result);
Outcome::Success(result)
}
}
Step::Failure(reason) => Outcome::Failure(reason),
});
}
}
}

View File

@@ -64,12 +64,12 @@ impl<'a> Exchange<'a> {
impl sasl::Mechanism for Exchange<'_> {
type Output = super::ScramKey;
fn exchange(mut self, input: &str) -> sasl::Result<(sasl::Step<Self, Self::Output>, String)> {
fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
use {sasl::Step::*, ExchangeState::*};
match &self.state {
Initial => {
let client_first_message =
ClientFirstMessage::parse(input).ok_or(SaslError::BadClientMessage)?;
let client_first_message = ClientFirstMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
let server_first_message = client_first_message.build_server_first_message(
&(self.nonce)(),
@@ -84,15 +84,15 @@ impl sasl::Mechanism for Exchange<'_> {
server_first_message,
};
Ok((Continue(self), msg))
Ok(Continue(self, msg))
}
SaltSent {
cbind_flag,
client_first_message_bare,
server_first_message,
} => {
let client_final_message =
ClientFinalMessage::parse(input).ok_or(SaslError::BadClientMessage)?;
let client_final_message = ClientFinalMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
let channel_binding = cbind_flag.encode(|_| {
self.cert_digest
@@ -106,9 +106,7 @@ impl sasl::Mechanism for Exchange<'_> {
}
if client_final_message.nonce != server_first_message.nonce() {
return Err(SaslError::AuthenticationFailed(
"combined nonce doesn't match",
));
return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
}
let signature_builder = SignatureBuilder {
@@ -121,14 +119,15 @@ impl sasl::Mechanism for Exchange<'_> {
.build(&self.secret.stored_key)
.derive_client_key(&client_final_message.proof);
if client_key.sha256() != self.secret.stored_key {
return Err(SaslError::AuthenticationFailed("password doesn't match"));
// Auth fails either if keys don't match or it's pre-determined to fail.
if client_key.sha256() != self.secret.stored_key || self.secret.doomed {
return Ok(Failure("password doesn't match"));
}
let msg = client_final_message
.build_server_final_message(signature_builder, &self.secret.server_key);
Ok((Authenticated(client_key), msg))
Ok(Success(client_key, msg))
}
}
}

View File

@@ -14,6 +14,9 @@ pub struct ServerSecret {
pub stored_key: ScramKey,
/// Used by client to verify server's signature.
pub server_key: ScramKey,
/// Should auth fail no matter what?
/// This is exactly the case for mocked secrets.
pub doomed: bool,
}
impl ServerSecret {
@@ -30,6 +33,7 @@ impl ServerSecret {
salt_base64: salt.to_owned(),
stored_key: base64_decode_array(stored_key)?.into(),
server_key: base64_decode_array(server_key)?.into(),
doomed: false,
};
Some(secret)
@@ -38,16 +42,16 @@ impl ServerSecret {
/// To avoid revealing information to an attacker, we use a
/// mocked server secret even if the user doesn't exist.
/// See `auth-scram.c : mock_scram_secret` for details.
#[allow(dead_code)]
pub fn mock(user: &str, nonce: &[u8; 32]) -> Self {
pub fn mock(user: &str, nonce: [u8; 32]) -> Self {
// Refer to `auth-scram.c : scram_mock_salt`.
let mocked_salt = super::sha256([user.as_bytes(), nonce]);
let mocked_salt = super::sha256([user.as_bytes(), &nonce]);
Self {
iterations: 4096,
salt_base64: base64::encode(&mocked_salt),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
doomed: true,
}
}
@@ -67,6 +71,7 @@ impl ServerSecret {
salt_base64: base64::encode(&salt),
stored_key: password.client_key().sha256(),
server_key: password.server_key(),
doomed: false,
})
}
}

View File

@@ -109,8 +109,9 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
// This method exists due to `&str` not implementing `Into<anyhow::Error>`
tracing::info!("forwarding error to user: {error}");
self.write_message(&BeMessage::ErrorResponse(error)).await?;
bail!(error)
}
@@ -122,6 +123,7 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
E: UserFacingError + Into<anyhow::Error>,
{
let msg = error.to_string_client();
tracing::info!("forwarding error to user: {msg}");
self.write_message(&BeMessage::ErrorResponse(&msg)).await?;
bail!(error)
}

View File

@@ -2092,62 +2092,73 @@ class PSQL:
class NeonProxy(PgProtocol):
link_auth_uri: str = "http://dummy-uri"
class AuthBackend(abc.ABC):
"""All auth backends must inherit from this class"""
@property
def default_conn_url(self) -> Optional[str]:
return None
@abc.abstractmethod
def extra_args(self) -> list[str]:
pass
class Link(AuthBackend):
def extra_args(self) -> list[str]:
return [
# Link auth backend params
*["--auth-backend", "link"],
*["--uri", NeonProxy.link_auth_uri],
]
@dataclass(frozen=True)
class Postgres(AuthBackend):
pg_conn_url: str
@property
def default_conn_url(self) -> Optional[str]:
return self.pg_conn_url
def extra_args(self) -> list[str]:
return [
# Postgres auth backend params
*["--auth-backend", "postgres"],
*["--auth-endpoint", self.pg_conn_url],
]
def __init__(
self,
neon_binpath: Path,
proxy_port: int,
http_port: int,
mgmt_port: int,
neon_binpath: Path,
auth_endpoint=None,
auth_backend: NeonProxy.AuthBackend,
):
super().__init__(dsn=auth_endpoint, port=proxy_port)
self.host = "127.0.0.1"
host = "127.0.0.1"
super().__init__(dsn=auth_backend.default_conn_url, host=host, port=proxy_port)
self.host = host
self.http_port = http_port
self.neon_binpath = neon_binpath
self.proxy_port = proxy_port
self.mgmt_port = mgmt_port
self.auth_endpoint = auth_endpoint
self.auth_backend = auth_backend
self._popen: Optional[subprocess.Popen[bytes]] = None
self.link_auth_uri_prefix = "http://dummy-uri"
def start(self):
"""
Starts a proxy with option '--auth-backend postgres' and a postgres instance
already provided though '--auth-endpoint <postgress-instance>'."
"""
def start(self) -> NeonProxy:
assert self._popen is None
assert self.auth_endpoint is not None
# Start proxy
args = [
str(self.neon_binpath / "proxy"),
*["--http", f"{self.host}:{self.http_port}"],
*["--proxy", f"{self.host}:{self.proxy_port}"],
*["--mgmt", f"{self.host}:{self.mgmt_port}"],
*["--auth-backend", "postgres"],
*["--auth-endpoint", self.auth_endpoint],
*self.auth_backend.extra_args(),
]
self._popen = subprocess.Popen(args)
self._wait_until_ready()
def start_with_link_auth(self):
"""
Starts a proxy with option '--auth-backend link' and a dummy authentication link '--uri dummy-auth-link'."
"""
assert self._popen is None
# Start proxy
bin_proxy = str(self.neon_binpath / "proxy")
args = [bin_proxy]
args.extend(["--http", f"{self.host}:{self.http_port}"])
args.extend(["--proxy", f"{self.host}:{self.proxy_port}"])
args.extend(["--mgmt", f"{self.host}:{self.mgmt_port}"])
args.extend(["--auth-backend", "link"])
args.extend(["--uri", self.link_auth_uri_prefix])
arg_str = " ".join(args)
log.info(f"starting proxy with command line ::: {arg_str}")
self._popen = subprocess.Popen(args, stdout=subprocess.PIPE)
self._wait_until_ready()
return self
@backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10)
def _wait_until_ready(self):
@@ -2158,7 +2169,7 @@ class NeonProxy(PgProtocol):
request_result.raise_for_status()
return request_result.text
def __enter__(self) -> "NeonProxy":
def __enter__(self) -> NeonProxy:
return self
def __exit__(
@@ -2176,11 +2187,19 @@ class NeonProxy(PgProtocol):
@pytest.fixture(scope="function")
def link_proxy(port_distributor: PortDistributor, neon_binpath: Path) -> Iterator[NeonProxy]:
"""Neon proxy that routes through link auth."""
http_port = port_distributor.get_port()
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
with NeonProxy(proxy_port, http_port, neon_binpath=neon_binpath, mgmt_port=mgmt_port) as proxy:
proxy.start_with_link_auth()
with NeonProxy(
neon_binpath=neon_binpath,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
auth_backend=NeonProxy.Link(),
) as proxy:
proxy.start()
yield proxy
@@ -2204,11 +2223,11 @@ def static_proxy(
http_port = port_distributor.get_port()
with NeonProxy(
neon_binpath=neon_binpath,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
neon_binpath=neon_binpath,
auth_endpoint=auth_endpoint,
auth_backend=NeonProxy.Postgres(auth_endpoint),
) as proxy:
proxy.start()
yield proxy

View File

@@ -28,61 +28,58 @@ def test_password_hack(static_proxy: NeonProxy):
static_proxy.safe_psql("select 1", sslsni=0, user=user, password=magic)
def get_session_id(uri_prefix, uri_line):
assert uri_prefix in uri_line
url_parts = urlparse(uri_line)
psql_session_id = url_parts.path[1:]
assert psql_session_id.isalnum(), "session_id should only contain alphanumeric chars"
return psql_session_id
async def find_auth_link(link_auth_uri_prefix, proc):
for _ in range(100):
line = (await proc.stderr.readline()).decode("utf-8").strip()
log.info(f"psql line: {line}")
if link_auth_uri_prefix in line:
log.info(f"SUCCESS, found auth url: {line}")
return line
async def activate_link_auth(local_vanilla_pg, link_proxy, psql_session_id):
pg_user = "proxy"
log.info("creating a new user for link auth test")
local_vanilla_pg.start()
local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser")
db_info = json.dumps(
{
"session_id": psql_session_id,
"result": {
"Success": {
"host": local_vanilla_pg.default_options["host"],
"port": local_vanilla_pg.default_options["port"],
"dbname": local_vanilla_pg.default_options["dbname"],
"user": pg_user,
"project": "irrelevant",
}
},
}
)
log.info("sending session activation message")
psql = await PSQL(host=link_proxy.host, port=link_proxy.mgmt_port).run(db_info)
out = (await psql.stdout.read()).decode("utf-8").strip()
assert out == "ok"
@pytest.mark.asyncio
async def test_psql_session_id(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy):
def get_session_id(uri_prefix, uri_line):
assert uri_prefix in uri_line
url_parts = urlparse(uri_line)
psql_session_id = url_parts.path[1:]
assert psql_session_id.isalnum(), "session_id should only contain alphanumeric chars"
return psql_session_id
async def find_auth_link(link_auth_uri, proc):
for _ in range(100):
line = (await proc.stderr.readline()).decode("utf-8").strip()
log.info(f"psql line: {line}")
if link_auth_uri in line:
log.info(f"SUCCESS, found auth url: {line}")
return line
async def activate_link_auth(local_vanilla_pg, link_proxy, psql_session_id):
pg_user = "proxy"
log.info("creating a new user for link auth test")
local_vanilla_pg.start()
local_vanilla_pg.safe_psql(f"create user {pg_user} with login superuser")
db_info = json.dumps(
{
"session_id": psql_session_id,
"result": {
"Success": {
"host": local_vanilla_pg.default_options["host"],
"port": local_vanilla_pg.default_options["port"],
"dbname": local_vanilla_pg.default_options["dbname"],
"user": pg_user,
"project": "irrelevant",
}
},
}
)
log.info("sending session activation message")
psql = await PSQL(host=link_proxy.host, port=link_proxy.mgmt_port).run(db_info)
out = (await psql.stdout.read()).decode("utf-8").strip()
assert out == "ok"
psql = await PSQL(host=link_proxy.host, port=link_proxy.proxy_port).run("select 42")
uri_prefix = link_proxy.link_auth_uri_prefix
link = await find_auth_link(uri_prefix, psql)
base_uri = link_proxy.link_auth_uri
link = await find_auth_link(base_uri, psql)
psql_session_id = get_session_id(uri_prefix, link)
psql_session_id = get_session_id(base_uri, link)
await activate_link_auth(vanilla_pg, link_proxy, psql_session_id)
assert psql.stdout is not None
@@ -97,3 +94,31 @@ def test_proxy_options(static_proxy: NeonProxy):
cur.execute("SHOW proxytest.option")
value = cur.fetchall()[0][0]
assert value == "value"
def test_auth_errors(static_proxy: NeonProxy):
# User does not exist
with pytest.raises(psycopg2.Error) as exprinfo:
static_proxy.connect(user="pinocchio", options="project=irrelevant")
text = str(exprinfo.value).strip()
assert text.endswith("password authentication failed for user 'pinocchio'")
static_proxy.safe_psql(
"create role pinocchio with login password 'magic'", options="project=irrelevant"
)
# User exists, but password is missing
with pytest.raises(psycopg2.Error) as exprinfo:
static_proxy.connect(user="pinocchio", password=None, options="project=irrelevant")
text = str(exprinfo.value).strip()
assert text.endswith("password authentication failed for user 'pinocchio'")
# User exists, but password is wrong
with pytest.raises(psycopg2.Error) as exprinfo:
static_proxy.connect(user="pinocchio", password="bad", options="project=irrelevant")
text = str(exprinfo.value).strip()
assert text.endswith("password authentication failed for user 'pinocchio'")
# Finally, check that the user can connect
with static_proxy.connect(user="pinocchio", password="magic", options="project=irrelevant"):
pass