[proxy] Propagate some errors to user (#1329)

* [proxy] Propagate most errors to user

This change enables propagation of most errors to the user
(e.g. auth and connectivity errors). Some of them will be
stripped of sensitive information.

As a side effect, most occurrences of `anyhow::Error` were
replaced with concrete error types.

* [proxy] Box weighty errors
This commit is contained in:
Dmitry Ivanov
2022-03-16 21:20:04 +03:00
committed by GitHub
parent 9c1a9a1d9f
commit 705f51db27
14 changed files with 481 additions and 166 deletions

2
Cargo.lock generated
View File

@@ -1739,6 +1739,7 @@ dependencies = [
"anyhow",
"bytes",
"clap 3.0.14",
"fail",
"futures",
"hashbrown 0.11.2",
"hex",
@@ -1754,6 +1755,7 @@ dependencies = [
"scopeguard",
"serde",
"serde_json",
"thiserror",
"tokio",
"tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
"tokio-postgres-rustls",

View File

@@ -7,6 +7,7 @@ edition = "2021"
anyhow = "1.0"
bytes = { version = "1.0.1", features = ['serde'] }
clap = "3.0"
fail = "0.5.0"
futures = "0.3.13"
hashbrown = "0.11.2"
hex = "0.4.3"
@@ -21,6 +22,7 @@ rustls = "0.19.1"
scopeguard = "1.1.0"
serde = "1"
serde_json = "1"
thiserror = "1.0"
tokio = { version = "1.11", features = ["macros"] }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
tokio-rustls = "0.22.0"

View File

@@ -1,11 +1,79 @@
use crate::compute::DatabaseInfo;
use crate::config::ProxyConfig;
use crate::cplane_api::{self, CPlaneApi};
use crate::error::UserFacingError;
use crate::stream::PqStream;
use anyhow::{anyhow, bail, Context};
use crate::waiters;
use std::collections::HashMap;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe};
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
/// Common authentication error.
#[derive(Debug, Error)]
pub enum AuthErrorImpl {
/// Authentication error reported by the console.
#[error(transparent)]
Console(#[from] cplane_api::AuthError),
/// For passwords that couldn't be processed by [`parse_password`].
#[error("Malformed password message")]
MalformedPassword,
/// Errors produced by [`PqStream`].
#[error(transparent)]
Io(#[from] std::io::Error),
}
impl AuthErrorImpl {
pub fn auth_failed(msg: impl Into<String>) -> Self {
AuthErrorImpl::Console(cplane_api::AuthError::auth_failed(msg))
}
}
impl From<waiters::RegisterError> for AuthErrorImpl {
fn from(e: waiters::RegisterError) -> Self {
AuthErrorImpl::Console(cplane_api::AuthError::from(e))
}
}
impl From<waiters::WaitError> for AuthErrorImpl {
fn from(e: waiters::WaitError) -> Self {
AuthErrorImpl::Console(cplane_api::AuthError::from(e))
}
}
#[derive(Debug, Error)]
#[error(transparent)]
pub struct AuthError(Box<AuthErrorImpl>);
impl<T> From<T> for AuthError
where
AuthErrorImpl: From<T>,
{
fn from(e: T) -> Self {
AuthError(Box::new(e.into()))
}
}
impl UserFacingError for AuthError {
fn to_string_client(&self) -> String {
use AuthErrorImpl::*;
match self.0.as_ref() {
Console(e) => e.to_string_client(),
MalformedPassword => self.to_string(),
_ => "Internal error".to_string(),
}
}
}
#[derive(Debug, Error)]
pub enum ClientCredsParseError {
#[error("Parameter `{0}` is missing in startup packet")]
MissingKey(&'static str),
}
impl UserFacingError for ClientCredsParseError {}
/// Various client credentials which we use for authentication.
#[derive(Debug, PartialEq, Eq)]
@@ -15,13 +83,13 @@ pub struct ClientCredentials {
}
impl TryFrom<HashMap<String, String>> for ClientCredentials {
type Error = anyhow::Error;
type Error = ClientCredsParseError;
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
let mut get_param = |key| {
value
.remove(key)
.with_context(|| format!("{} is missing in startup packet", key))
.ok_or(ClientCredsParseError::MissingKey(key))
};
let user = get_param("user")?;
@@ -37,10 +105,14 @@ impl ClientCredentials {
self,
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<DatabaseInfo> {
) -> Result<DatabaseInfo, AuthError> {
fail::fail_point!("proxy-authenticate", |_| {
Err(AuthError::auth_failed("failpoint triggered"))
});
use crate::config::ClientAuthMethod::*;
use crate::config::RouterConfig::*;
let db_info = match &config.router_config {
match &config.router_config {
Static { host, port } => handle_static(host.clone(), *port, client, self).await,
Dynamic(Mixed) => {
if self.user.ends_with("@zenith") {
@@ -51,9 +123,7 @@ impl ClientCredentials {
}
Dynamic(Password) => handle_existing_user(config, client, self).await,
Dynamic(Link) => handle_new_user(config, client).await,
};
db_info.context("failed to authenticate client")
}
}
}
@@ -66,18 +136,14 @@ async fn handle_static(
port: u16,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> anyhow::Result<DatabaseInfo> {
) -> Result<DatabaseInfo, AuthError> {
client
.write_message(&Be::AuthenticationCleartextPassword)
.await?;
// Read client's password bytes
let msg = match client.read_message().await? {
Fe::PasswordMessage(msg) => msg,
bad => bail!("unexpected message type: {:?}", bad),
};
let cleartext_password = std::str::from_utf8(&msg)?.split('\0').next().unwrap();
let msg = client.read_password_message().await?;
let cleartext_password = parse_password(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
let db_info = DatabaseInfo {
host,
@@ -98,7 +164,7 @@ async fn handle_existing_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> anyhow::Result<DatabaseInfo> {
) -> Result<DatabaseInfo, AuthError> {
let psql_session_id = new_psql_session_id();
let md5_salt = rand::random();
@@ -107,18 +173,12 @@ async fn handle_existing_user(
.await?;
// Read client's password hash
let msg = match client.read_message().await? {
Fe::PasswordMessage(msg) => msg,
bad => bail!("unexpected message type: {:?}", bad),
};
let msg = client.read_password_message().await?;
let md5_response = parse_password(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
let (_trailing_null, md5_response) = msg
.split_last()
.ok_or_else(|| anyhow!("unexpected password message"))?;
let cplane = CPlaneApi::new(&config.auth_endpoint);
let cplane = CPlaneApi::new(config.auth_endpoint.clone());
let db_info = cplane
.authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id)
.authenticate_proxy_client(creds, md5_response, &md5_salt, &psql_session_id)
.await?;
client
@@ -131,7 +191,7 @@ async fn handle_existing_user(
async fn handle_new_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<DatabaseInfo> {
) -> Result<DatabaseInfo, AuthError> {
let psql_session_id = new_psql_session_id();
let greeting = hello_message(&config.redirect_uri, &psql_session_id);
@@ -143,8 +203,8 @@ async fn handle_new_user(
.write_message(&Be::NoticeResponse(greeting))
.await?;
// Wait for web console response
waiter.await?.map_err(|e| anyhow!(e))
// Wait for web console response (see `mgmt`)
waiter.await?.map_err(AuthErrorImpl::auth_failed)
})
.await?;
@@ -153,6 +213,10 @@ async fn handle_new_user(
Ok(db_info)
}
fn parse_password(bytes: &[u8]) -> Option<&str> {
std::str::from_utf8(bytes).ok()?.strip_suffix('\0')
}
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
format!(
concat![

View File

@@ -6,7 +6,7 @@ use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use zenith_utils::pq_proto::CancelKeyData;
/// Enables serving CancelRequests.
/// Enables serving `CancelRequest`s.
#[derive(Default)]
pub struct CancelMap(Mutex<HashMap<CancelKeyData, Option<CancelClosure>>>);

View File

@@ -1,6 +1,27 @@
use anyhow::Context;
use crate::cancellation::CancelClosure;
use crate::error::UserFacingError;
use serde::{Deserialize, Serialize};
use std::net::{SocketAddr, ToSocketAddrs};
use std::io;
use std::net::SocketAddr;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::NoTls;
#[derive(Debug, Error)]
pub enum ConnectionError {
/// This error doesn't seem to reveal any secrets; for instance,
/// [`tokio_postgres::error::Kind`] doesn't contain ip addresses and such.
#[error("Failed to connect to the compute node: {0}")]
Postgres(#[from] tokio_postgres::Error),
#[error("Failed to connect to the compute node")]
FailedToConnectToCompute,
#[error("Failed to fetch compute node version")]
FailedToFetchPgVersion,
}
impl UserFacingError for ConnectionError {}
/// Compute node connection params.
#[derive(Serialize, Deserialize, Debug, Default)]
@@ -12,14 +33,38 @@ pub struct DatabaseInfo {
pub password: Option<String>,
}
/// PostgreSQL version as [`String`].
pub type Version = String;
impl DatabaseInfo {
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> {
let host_port = format!("{}:{}", self.host, self.port);
host_port
.to_socket_addrs()
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
.next()
.context("cannot resolve at least one SocketAddr")
let socket = TcpStream::connect(host_port).await?;
let socket_addr = socket.peer_addr()?;
Ok((socket_addr, socket))
}
/// Connect to a corresponding compute node.
pub async fn connect(self) -> Result<(TcpStream, Version, CancelClosure), ConnectionError> {
let (socket_addr, mut socket) = self
.connect_raw()
.await
.map_err(|_| ConnectionError::FailedToConnectToCompute)?;
// TODO: establish a secure connection to the DB
let (client, conn) = tokio_postgres::Config::from(self)
.connect_raw(&mut socket, NoTls)
.await?;
let version = conn
.parameter("server_version")
.ok_or(ConnectionError::FailedToFetchPgVersion)?
.into();
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
Ok((socket, version, cancel_closure))
}
}

View File

@@ -1,4 +1,4 @@
use anyhow::{anyhow, ensure, Context};
use anyhow::{anyhow, bail, ensure, Context};
use rustls::{internal::pemfile, NoClientAuth, ProtocolVersion, ServerConfig};
use std::net::SocketAddr;
use std::str::FromStr;
@@ -29,7 +29,7 @@ impl FromStr for ClientAuthMethod {
"password" => Ok(Password),
"link" => Ok(Link),
"mixed" => Ok(Mixed),
_ => Err(anyhow::anyhow!("Invlid option for router")),
_ => bail!("Invalid option for router: `{}`", s),
}
}
}
@@ -53,7 +53,7 @@ pub struct ProxyConfig {
pub redirect_uri: String,
/// control plane address where we would check auth.
pub auth_endpoint: String,
pub auth_endpoint: reqwest::Url,
pub tls_config: Option<TlsConfig>,
}

View File

@@ -1,52 +1,113 @@
use crate::auth::ClientCredentials;
use crate::compute::DatabaseInfo;
use crate::waiters::{Waiter, Waiters};
use anyhow::{anyhow, bail};
use crate::error::UserFacingError;
use crate::mgmt;
use crate::waiters::{self, Waiter, Waiters};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use thiserror::Error;
lazy_static! {
static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfo, String>> = Default::default();
static ref CPLANE_WAITERS: Waiters<mgmt::ComputeReady> = Default::default();
}
/// Give caller an opportunity to wait for cplane's reply.
pub async fn with_waiter<F, R, T>(psql_session_id: impl Into<String>, f: F) -> anyhow::Result<T>
pub async fn with_waiter<R, T, E>(
psql_session_id: impl Into<String>,
action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R,
) -> Result<T, E>
where
F: FnOnce(Waiter<'static, Result<DatabaseInfo, String>>) -> R,
R: std::future::Future<Output = anyhow::Result<T>>,
R: std::future::Future<Output = Result<T, E>>,
E: From<waiters::RegisterError>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
f(waiter).await
action(waiter).await
}
pub fn notify(psql_session_id: &str, msg: Result<DatabaseInfo, String>) -> anyhow::Result<()> {
pub fn notify(
psql_session_id: &str,
msg: Result<DatabaseInfo, String>,
) -> Result<(), waiters::NotifyError> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}
/// Zenith console API wrapper.
pub struct CPlaneApi<'a> {
auth_endpoint: &'a str,
pub struct CPlaneApi {
auth_endpoint: reqwest::Url,
}
impl<'a> CPlaneApi<'a> {
pub fn new(auth_endpoint: &'a str) -> Self {
impl CPlaneApi {
pub fn new(auth_endpoint: reqwest::Url) -> Self {
Self { auth_endpoint }
}
}
impl CPlaneApi<'_> {
pub async fn authenticate_proxy_request(
#[derive(Debug, Error)]
pub enum AuthErrorImpl {
/// Authentication error reported by the console.
#[error("Authentication failed: {0}")]
AuthFailed(String),
/// HTTP status (other than 200) returned by the console.
#[error("Console responded with an HTTP status: {0}")]
HttpStatus(reqwest::StatusCode),
#[error("Console responded with a malformed JSON: {0}")]
MalformedResponse(#[from] serde_json::Error),
#[error(transparent)]
Transport(#[from] reqwest::Error),
#[error(transparent)]
WaiterRegister(#[from] waiters::RegisterError),
#[error(transparent)]
WaiterWait(#[from] waiters::WaitError),
}
#[derive(Debug, Error)]
#[error(transparent)]
pub struct AuthError(Box<AuthErrorImpl>);
impl AuthError {
/// Smart constructor for authentication error reported by `mgmt`.
pub fn auth_failed(msg: impl Into<String>) -> Self {
AuthError(Box::new(AuthErrorImpl::AuthFailed(msg.into())))
}
}
impl<T> From<T> for AuthError
where
AuthErrorImpl: From<T>,
{
fn from(e: T) -> Self {
AuthError(Box::new(e.into()))
}
}
impl UserFacingError for AuthError {
fn to_string_client(&self) -> String {
use AuthErrorImpl::*;
match self.0.as_ref() {
AuthFailed(_) | HttpStatus(_) => self.to_string(),
_ => "Internal error".to_string(),
}
}
}
impl CPlaneApi {
pub async fn authenticate_proxy_client(
&self,
creds: ClientCredentials,
md5_response: &[u8],
md5_response: &str,
salt: &[u8; 4],
psql_session_id: &str,
) -> anyhow::Result<DatabaseInfo> {
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
) -> Result<DatabaseInfo, AuthError> {
let mut url = self.auth_endpoint.clone();
url.query_pairs_mut()
.append_pair("login", &creds.user)
.append_pair("database", &creds.dbname)
.append_pair("md5response", std::str::from_utf8(md5_response)?)
.append_pair("md5response", md5_response)
.append_pair("salt", &hex::encode(salt))
.append_pair("psql_session_id", psql_session_id);
@@ -55,18 +116,20 @@ impl CPlaneApi<'_> {
// TODO: leverage `reqwest::Client` to reuse connections
let resp = reqwest::get(url).await?;
if !resp.status().is_success() {
bail!("Auth failed: {}", resp.status())
return Err(AuthErrorImpl::HttpStatus(resp.status()).into());
}
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
println!("got auth info: #{:?}", auth_info);
use ProxyAuthResponse::*;
match auth_info {
Ready { conn_info } => Ok(conn_info),
Error { error } => bail!(error),
NotReady { .. } => waiter.await?.map_err(|e| anyhow!(e)),
}
let db_info = match auth_info {
Ready { conn_info } => conn_info,
Error { error } => return Err(AuthErrorImpl::AuthFailed(error).into()),
NotReady { .. } => waiter.await?.map_err(AuthErrorImpl::AuthFailed)?,
};
Ok(db_info)
})
.await
}

17
proxy/src/error.rs Normal file
View File

@@ -0,0 +1,17 @@
/// Marks errors that may be safely shown to a client.
/// This trait can be seen as a specialized version of [`ToString`].
///
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
/// is way too convenient and tends to proliferate all across the codebase,
/// ultimately leading to accidental leaks of sensitive data.
pub trait UserFacingError: ToString {
/// Format the error for client, stripping all sensitive info.
///
/// Although this might be a no-op for many types, it's highly
/// recommended to override the default impl in case error type
/// contains anything sensitive: various IDs, IP addresses etc.
#[inline(always)]
fn to_string_client(&self) -> String {
self.to_string()
}
}

View File

@@ -7,7 +7,7 @@ use zenith_utils::http::json::json_response;
use zenith_utils::http::{RouterBuilder, RouterService};
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
Ok(json_response(StatusCode::OK, "")?)
json_response(StatusCode::OK, "")
}
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {

View File

@@ -20,13 +20,14 @@ mod cancellation;
mod compute;
mod config;
mod cplane_api;
mod error;
mod http;
mod mgmt;
mod proxy;
mod stream;
mod waiters;
/// Flattens Result<Result<T>> into Result<T>.
/// Flattens `Result<Result<T>>` into `Result<T>`.
async fn flatten_err(
f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
) -> anyhow::Result<()> {

View File

@@ -79,6 +79,18 @@ enum PsqlSessionResult {
Failure(String),
}
/// A message received by `mgmt` when a compute node is ready.
pub type ComputeReady = Result<DatabaseInfo, String>;
impl PsqlSessionResult {
fn into_compute_ready(self) -> ComputeReady {
match self {
Self::Success(db_info) => Ok(db_info),
Self::Failure(message) => Err(message),
}
}
}
impl postgres_backend::Handler for MgmtHandler {
fn process_query(
&mut self,
@@ -99,13 +111,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R
let resp: PsqlSessionResponse = serde_json::from_str(query_string)?;
use PsqlSessionResult::*;
let msg = match resp.result {
Success(db_info) => Ok(db_info),
Failure(message) => Err(message),
};
match cplane_api::notify(&resp.session_id, msg) {
match cplane_api::notify(&resp.session_id, resp.result.into_compute_ready()) {
Ok(()) => {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?

View File

@@ -1,17 +1,18 @@
use crate::auth;
use crate::cancellation::{self, CancelClosure, CancelMap};
use crate::compute::DatabaseInfo;
use crate::cancellation::{self, CancelMap};
use crate::config::{ProxyConfig, TlsConfig};
use crate::stream::{MetricsStream, PqStream, Stream};
use anyhow::{bail, Context};
use futures::TryFutureExt;
use lazy_static::lazy_static;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_postgres::NoTls;
use zenith_metrics::{new_common_metric_name, register_int_counter, IntCounter};
use zenith_utils::pq_proto::{BeMessage as Be, *};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";
lazy_static! {
static ref NUM_CONNECTIONS_ACCEPTED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_accepted"),
@@ -30,6 +31,7 @@ lazy_static! {
.unwrap();
}
/// A small combinator for pluggable error logging.
async fn log_error<R, F>(future: F) -> F::Output
where
F: std::future::Future<Output = anyhow::Result<R>>,
@@ -76,20 +78,21 @@ async fn handle_client(
}
let tls = config.tls_config.clone();
if let Some((client, creds)) = handshake(stream, tls, cancel_map).await? {
cancel_map
.with_session(|session| async {
connect_client_to_db(config, session, client, creds).await
})
.await?;
}
let (stream, creds) = match handshake(stream, tls, cancel_map).await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
Ok(())
let client = Client::new(stream, creds);
cancel_map
.with_session(|session| client.connect_to_db(config, session))
.await
}
/// Handle a connection from one client.
/// For better testing experience, `stream` can be
/// any object satisfying the traits.
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to updgrade it to TLS;
/// we also take an extra care of propagating only the select handshake errors to client.
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<TlsConfig>,
@@ -119,7 +122,7 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream = PqStream::new(stream.into_inner().upgrade(tls).await?);
}
}
_ => bail!("protocol violation"),
_ => bail!(ERR_PROTO_VIOLATION),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
@@ -128,18 +131,21 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!("protocol violation"),
_ => bail!(ERR_PROTO_VIOLATION),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
let msg = "connection is insecure (try using `sslmode=require`)";
stream.write_message(&Be::ErrorResponse(msg)).await?;
bail!(msg);
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}
break Ok(Some((stream, params.try_into()?)));
// Here and forth: `or_else` demands that we use a future here
let creds = async { params.try_into() }
.or_else(|e| stream.throw_error(e))
.await?;
break Ok(Some((stream, creds)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
@@ -150,58 +156,60 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
}
async fn connect_client_to_db(
config: &ProxyConfig,
session: cancellation::Session<'_>,
mut client: PqStream<impl AsyncRead + AsyncWrite + Unpin>,
/// Thin connection context.
struct Client<S> {
/// The underlying libpq protocol stream.
stream: PqStream<S>,
/// Client credentials that we care about.
creds: auth::ClientCredentials,
) -> anyhow::Result<()> {
let db_info = creds.authenticate(config, &mut client).await?;
let (db, version, cancel_closure) = connect_to_db(db_info).await?;
let cancel_key_data = session.enable_cancellation(cancel_closure);
client
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion(&version),
))?
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&BeMessage::ReadyForQuery)
.await?;
// This function will be called for writes to either direction.
fn inc_proxied(cnt: usize) {
// Consider inventing something more sophisticated
// if this ever becomes a bottleneck (cacheline bouncing).
NUM_BYTES_PROXIED_COUNTER.inc_by(cnt as u64);
}
let mut db = MetricsStream::new(db, inc_proxied);
let mut client = MetricsStream::new(client.into_inner(), inc_proxied);
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
Ok(())
}
/// Connect to a corresponding compute node.
async fn connect_to_db(
db_info: DatabaseInfo,
) -> anyhow::Result<(TcpStream, String, CancelClosure)> {
// TODO: establish a secure connection to the DB
let socket_addr = db_info.socket_addr()?;
let mut socket = TcpStream::connect(socket_addr).await?;
impl<S> Client<S> {
/// Construct a new connection context.
fn new(stream: PqStream<S>, creds: auth::ClientCredentials) -> Self {
Self { stream, creds }
}
}
let (client, conn) = tokio_postgres::Config::from(db_info)
.connect_raw(&mut socket, NoTls)
.await?;
impl<S: AsyncRead + AsyncWrite + Unpin> Client<S> {
/// Let the client authenticate and connect to the designated compute node.
async fn connect_to_db(
self,
config: &ProxyConfig,
session: cancellation::Session<'_>,
) -> anyhow::Result<()> {
let Self { mut stream, creds } = self;
let version = conn
.parameter("server_version")
.context("failed to fetch postgres server version")?
.into();
// Authenticate and connect to a compute node.
let auth = creds.authenticate(config, &mut stream).await;
let db_info = async { auth }.or_else(|e| stream.throw_error(e)).await?;
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
let (db, version, cancel_closure) =
db_info.connect().or_else(|e| stream.throw_error(e)).await?;
let cancel_key_data = session.enable_cancellation(cancel_closure);
Ok((socket, version, cancel_closure))
stream
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion(&version),
))?
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&BeMessage::ReadyForQuery)
.await?;
/// This function will be called for writes to either direction.
fn inc_proxied(cnt: usize) {
// Consider inventing something more sophisticated
// if this ever becomes a bottleneck (cacheline bouncing).
NUM_BYTES_PROXIED_COUNTER.inc_by(cnt as u64);
}
// Starting from here we only proxy the client's traffic.
let mut db = MetricsStream::new(db, inc_proxied);
let mut client = MetricsStream::new(stream.into_inner(), inc_proxied);
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
Ok(())
}
}
#[cfg(test)]
@@ -210,7 +218,7 @@ mod tests {
use tokio::io::DuplexStream;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::MakeRustlsConnect;
async fn dummy_proxy(
@@ -264,7 +272,7 @@ mod tests {
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
tokio_postgres::Config::new()
let client_err = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Disable)
@@ -273,11 +281,15 @@ mod tests {
.err() // -> Option<E>
.context("client shouldn't be able to connect")?;
proxy
assert!(client_err.to_string().contains(ERR_INSECURE_CONNECTION));
let server_err = proxy
.await?
.err() // -> Option<E>
.context("server shouldn't accept client")?;
assert!(client_err.to_string().contains(&server_err.to_string()));
Ok(())
}
@@ -329,4 +341,30 @@ mod tests {
proxy.await?
}
#[tokio::test]
async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let proxy = tokio::spawn(dummy_proxy(client, None));
let client_err = tokio_postgres::Config::new()
.ssl_mode(SslMode::Disable)
.connect_raw(server, NoTls)
.await
.err() // -> Option<E>
.context("client shouldn't be able to connect")?;
// TODO: this is ugly, but `format!` won't allow us to extract fmt string
assert!(client_err.to_string().contains("missing in startup packet"));
let server_err = proxy
.await?
.err() // -> Option<E>
.context("server shouldn't accept client")?;
assert!(client_err.to_string().contains(&server_err.to_string()));
Ok(())
}
}

View File

@@ -1,10 +1,12 @@
use anyhow::Context;
use crate::error::UserFacingError;
use anyhow::bail;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use rustls::ServerConfig;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_rustls::server::TlsStream;
use zenith_utils::pq_proto::{BeMessage, FeMessage, FeStartupPacket};
@@ -35,38 +37,63 @@ impl<S> PqStream<S> {
self.stream
}
/// Get a reference to the underlying stream.
/// Get a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
}
fn err_connection() -> io::Error {
io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
}
// TODO: change error type of `FeMessage::read_fut`
fn from_anyhow(e: anyhow::Error) -> io::Error {
io::Error::new(io::ErrorKind::Other, e.to_string())
}
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
pub async fn read_startup_packet(&mut self) -> anyhow::Result<FeStartupPacket> {
match FeStartupPacket::read_fut(&mut self.stream).await? {
Some(FeMessage::StartupPacket(packet)) => Ok(packet),
None => anyhow::bail!("connection is lost"),
other => anyhow::bail!("bad message type: {:?}", other),
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
// TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket`
let msg = FeStartupPacket::read_fut(&mut self.stream)
.await
.map_err(from_anyhow)?
.ok_or_else(err_connection)?;
match msg {
FeMessage::StartupPacket(packet) => Ok(packet),
_ => panic!("unreachable state"),
}
}
pub async fn read_message(&mut self) -> anyhow::Result<FeMessage> {
pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
match self.read_message().await? {
FeMessage::PasswordMessage(msg) => Ok(msg),
bad => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unexpected message type: {:?}", bad),
)),
}
}
async fn read_message(&mut self) -> io::Result<FeMessage> {
FeMessage::read_fut(&mut self.stream)
.await?
.context("connection is lost")
.await
.map_err(from_anyhow)?
.ok_or_else(err_connection)
}
}
impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the message into an internal buffer, but don't flush the underlying stream.
pub fn write_message_noflush<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buffer, message)?;
Ok(self)
}
/// Write the message into an internal buffer and flush it.
pub async fn write_message<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
@@ -79,6 +106,25 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
self.stream.flush().await?;
Ok(self)
}
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
// This method exists due to `&str` not implementing `Into<anyhow::Error>`
self.write_message(&BeMessage::ErrorResponse(error)).await?;
bail!(error)
}
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Trait [`UserFacingError`] acts as an allowlist for error types.
pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
where
E: UserFacingError + Into<anyhow::Error>,
{
let msg = error.to_string_client();
self.write_message(&BeMessage::ErrorResponse(&msg)).await?;
bail!(error)
}
}
pin_project! {
@@ -101,15 +147,25 @@ impl<S> Stream<S> {
}
}
#[derive(Debug, Error)]
#[error("Can't upgrade TLS stream")]
pub enum StreamUpgradeError {
#[error("Bad state reached: can't upgrade TLS stream")]
AlreadyTls,
#[error("Can't upgrade stream: IO error: {0}")]
Io(#[from] io::Error),
}
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> anyhow::Result<Self> {
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<Self, StreamUpgradeError> {
match self {
Stream::Raw { raw } => {
let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
Ok(Stream::Tls { tls })
}
Stream::Tls { .. } => anyhow::bail!("can't upgrade TLS stream"),
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
}
}
}

View File

@@ -1,11 +1,32 @@
use anyhow::{anyhow, Context};
use hashbrown::HashMap;
use parking_lot::Mutex;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task;
use thiserror::Error;
use tokio::sync::oneshot;
#[derive(Debug, Error)]
pub enum RegisterError {
#[error("Waiter `{0}` already registered")]
Occupied(String),
}
#[derive(Debug, Error)]
pub enum NotifyError {
#[error("Notify failed: waiter `{0}` not registered")]
NotFound(String),
#[error("Notify failed: channel hangup")]
Hangup,
}
#[derive(Debug, Error)]
pub enum WaitError {
#[error("Wait failed: channel hangup")]
Hangup,
}
pub struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
impl<T> Default for Waiters<T> {
@@ -15,13 +36,13 @@ impl<T> Default for Waiters<T> {
}
impl<T> Waiters<T> {
pub fn register(&self, key: String) -> anyhow::Result<Waiter<T>> {
pub fn register(&self, key: String) -> Result<Waiter<T>, RegisterError> {
let (tx, rx) = oneshot::channel();
self.0
.lock()
.try_insert(key.clone(), tx)
.map_err(|_| anyhow!("waiter already registered"))?;
.map_err(|e| RegisterError::Occupied(e.entry.key().clone()))?;
Ok(Waiter {
receiver: rx,
@@ -32,7 +53,7 @@ impl<T> Waiters<T> {
})
}
pub fn notify(&self, key: &str, value: T) -> anyhow::Result<()>
pub fn notify(&self, key: &str, value: T) -> Result<(), NotifyError>
where
T: Send + Sync,
{
@@ -40,9 +61,9 @@ impl<T> Waiters<T> {
.0
.lock()
.remove(key)
.with_context(|| format!("key {} not found", key))?;
.ok_or_else(|| NotifyError::NotFound(key.to_string()))?;
tx.send(value).map_err(|_| anyhow!("waiter channel hangup"))
tx.send(value).map_err(|_| NotifyError::Hangup)
}
}
@@ -66,13 +87,13 @@ pin_project! {
}
impl<T> std::future::Future for Waiter<'_, T> {
type Output = anyhow::Result<T>;
type Output = Result<T, WaitError>;
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
self.project()
.receiver
.poll(cx)
.map_err(|_| anyhow!("channel hangup"))
.map_err(|_| WaitError::Hangup)
}
}