diff --git a/Cargo.lock b/Cargo.lock index b1ebe6c07a..750ac0edc2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1739,6 +1739,7 @@ dependencies = [ "anyhow", "bytes", "clap 3.0.14", + "fail", "futures", "hashbrown 0.11.2", "hex", @@ -1754,6 +1755,7 @@ dependencies = [ "scopeguard", "serde", "serde_json", + "thiserror", "tokio", "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "tokio-postgres-rustls", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index d8d5cbe5bf..dda018a1d8 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" anyhow = "1.0" bytes = { version = "1.0.1", features = ['serde'] } clap = "3.0" +fail = "0.5.0" futures = "0.3.13" hashbrown = "0.11.2" hex = "0.4.3" @@ -21,6 +22,7 @@ rustls = "0.19.1" scopeguard = "1.1.0" serde = "1" serde_json = "1" +thiserror = "1.0" tokio = { version = "1.11", features = ["macros"] } tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" } tokio-rustls = "0.22.0" diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index a5bdaeaeca..5e6357fe80 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,11 +1,79 @@ use crate::compute::DatabaseInfo; use crate::config::ProxyConfig; use crate::cplane_api::{self, CPlaneApi}; +use crate::error::UserFacingError; use crate::stream::PqStream; -use anyhow::{anyhow, bail, Context}; +use crate::waiters; use std::collections::HashMap; +use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; -use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe}; +use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; + +/// Common authentication error. +#[derive(Debug, Error)] +pub enum AuthErrorImpl { + /// Authentication error reported by the console. + #[error(transparent)] + Console(#[from] cplane_api::AuthError), + + /// For passwords that couldn't be processed by [`parse_password`]. + #[error("Malformed password message")] + MalformedPassword, + + /// Errors produced by [`PqStream`]. + #[error(transparent)] + Io(#[from] std::io::Error), +} + +impl AuthErrorImpl { + pub fn auth_failed(msg: impl Into) -> Self { + AuthErrorImpl::Console(cplane_api::AuthError::auth_failed(msg)) + } +} + +impl From for AuthErrorImpl { + fn from(e: waiters::RegisterError) -> Self { + AuthErrorImpl::Console(cplane_api::AuthError::from(e)) + } +} + +impl From for AuthErrorImpl { + fn from(e: waiters::WaitError) -> Self { + AuthErrorImpl::Console(cplane_api::AuthError::from(e)) + } +} + +#[derive(Debug, Error)] +#[error(transparent)] +pub struct AuthError(Box); + +impl From for AuthError +where + AuthErrorImpl: From, +{ + fn from(e: T) -> Self { + AuthError(Box::new(e.into())) + } +} + +impl UserFacingError for AuthError { + fn to_string_client(&self) -> String { + use AuthErrorImpl::*; + match self.0.as_ref() { + Console(e) => e.to_string_client(), + MalformedPassword => self.to_string(), + _ => "Internal error".to_string(), + } + } +} + +#[derive(Debug, Error)] +pub enum ClientCredsParseError { + #[error("Parameter `{0}` is missing in startup packet")] + MissingKey(&'static str), +} + +impl UserFacingError for ClientCredsParseError {} /// Various client credentials which we use for authentication. #[derive(Debug, PartialEq, Eq)] @@ -15,13 +83,13 @@ pub struct ClientCredentials { } impl TryFrom> for ClientCredentials { - type Error = anyhow::Error; + type Error = ClientCredsParseError; fn try_from(mut value: HashMap) -> Result { let mut get_param = |key| { value .remove(key) - .with_context(|| format!("{} is missing in startup packet", key)) + .ok_or(ClientCredsParseError::MissingKey(key)) }; let user = get_param("user")?; @@ -37,10 +105,14 @@ impl ClientCredentials { self, config: &ProxyConfig, client: &mut PqStream, - ) -> anyhow::Result { + ) -> Result { + fail::fail_point!("proxy-authenticate", |_| { + Err(AuthError::auth_failed("failpoint triggered")) + }); + use crate::config::ClientAuthMethod::*; use crate::config::RouterConfig::*; - let db_info = match &config.router_config { + match &config.router_config { Static { host, port } => handle_static(host.clone(), *port, client, self).await, Dynamic(Mixed) => { if self.user.ends_with("@zenith") { @@ -51,9 +123,7 @@ impl ClientCredentials { } Dynamic(Password) => handle_existing_user(config, client, self).await, Dynamic(Link) => handle_new_user(config, client).await, - }; - - db_info.context("failed to authenticate client") + } } } @@ -66,18 +136,14 @@ async fn handle_static( port: u16, client: &mut PqStream, creds: ClientCredentials, -) -> anyhow::Result { +) -> Result { client .write_message(&Be::AuthenticationCleartextPassword) .await?; // Read client's password bytes - let msg = match client.read_message().await? { - Fe::PasswordMessage(msg) => msg, - bad => bail!("unexpected message type: {:?}", bad), - }; - - let cleartext_password = std::str::from_utf8(&msg)?.split('\0').next().unwrap(); + let msg = client.read_password_message().await?; + let cleartext_password = parse_password(&msg).ok_or(AuthErrorImpl::MalformedPassword)?; let db_info = DatabaseInfo { host, @@ -98,7 +164,7 @@ async fn handle_existing_user( config: &ProxyConfig, client: &mut PqStream, creds: ClientCredentials, -) -> anyhow::Result { +) -> Result { let psql_session_id = new_psql_session_id(); let md5_salt = rand::random(); @@ -107,18 +173,12 @@ async fn handle_existing_user( .await?; // Read client's password hash - let msg = match client.read_message().await? { - Fe::PasswordMessage(msg) => msg, - bad => bail!("unexpected message type: {:?}", bad), - }; + let msg = client.read_password_message().await?; + let md5_response = parse_password(&msg).ok_or(AuthErrorImpl::MalformedPassword)?; - let (_trailing_null, md5_response) = msg - .split_last() - .ok_or_else(|| anyhow!("unexpected password message"))?; - - let cplane = CPlaneApi::new(&config.auth_endpoint); + let cplane = CPlaneApi::new(config.auth_endpoint.clone()); let db_info = cplane - .authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id) + .authenticate_proxy_client(creds, md5_response, &md5_salt, &psql_session_id) .await?; client @@ -131,7 +191,7 @@ async fn handle_existing_user( async fn handle_new_user( config: &ProxyConfig, client: &mut PqStream, -) -> anyhow::Result { +) -> Result { let psql_session_id = new_psql_session_id(); let greeting = hello_message(&config.redirect_uri, &psql_session_id); @@ -143,8 +203,8 @@ async fn handle_new_user( .write_message(&Be::NoticeResponse(greeting)) .await?; - // Wait for web console response - waiter.await?.map_err(|e| anyhow!(e)) + // Wait for web console response (see `mgmt`) + waiter.await?.map_err(AuthErrorImpl::auth_failed) }) .await?; @@ -153,6 +213,10 @@ async fn handle_new_user( Ok(db_info) } +fn parse_password(bytes: &[u8]) -> Option<&str> { + std::str::from_utf8(bytes).ok()?.strip_suffix('\0') +} + fn hello_message(redirect_uri: &str, session_id: &str) -> String { format!( concat![ diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index c1a7e81be9..07d3bcc71a 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -6,7 +6,7 @@ use tokio::net::TcpStream; use tokio_postgres::{CancelToken, NoTls}; use zenith_utils::pq_proto::CancelKeyData; -/// Enables serving CancelRequests. +/// Enables serving `CancelRequest`s. #[derive(Default)] pub struct CancelMap(Mutex>>); diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 7c294bd488..64ce5d0a5a 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,6 +1,27 @@ -use anyhow::Context; +use crate::cancellation::CancelClosure; +use crate::error::UserFacingError; use serde::{Deserialize, Serialize}; -use std::net::{SocketAddr, ToSocketAddrs}; +use std::io; +use std::net::SocketAddr; +use thiserror::Error; +use tokio::net::TcpStream; +use tokio_postgres::NoTls; + +#[derive(Debug, Error)] +pub enum ConnectionError { + /// This error doesn't seem to reveal any secrets; for instance, + /// [`tokio_postgres::error::Kind`] doesn't contain ip addresses and such. + #[error("Failed to connect to the compute node: {0}")] + Postgres(#[from] tokio_postgres::Error), + + #[error("Failed to connect to the compute node")] + FailedToConnectToCompute, + + #[error("Failed to fetch compute node version")] + FailedToFetchPgVersion, +} + +impl UserFacingError for ConnectionError {} /// Compute node connection params. #[derive(Serialize, Deserialize, Debug, Default)] @@ -12,14 +33,38 @@ pub struct DatabaseInfo { pub password: Option, } +/// PostgreSQL version as [`String`]. +pub type Version = String; + impl DatabaseInfo { - pub fn socket_addr(&self) -> anyhow::Result { + async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> { let host_port = format!("{}:{}", self.host, self.port); - host_port - .to_socket_addrs() - .with_context(|| format!("cannot resolve {} to SocketAddr", host_port))? - .next() - .context("cannot resolve at least one SocketAddr") + let socket = TcpStream::connect(host_port).await?; + let socket_addr = socket.peer_addr()?; + + Ok((socket_addr, socket)) + } + + /// Connect to a corresponding compute node. + pub async fn connect(self) -> Result<(TcpStream, Version, CancelClosure), ConnectionError> { + let (socket_addr, mut socket) = self + .connect_raw() + .await + .map_err(|_| ConnectionError::FailedToConnectToCompute)?; + + // TODO: establish a secure connection to the DB + let (client, conn) = tokio_postgres::Config::from(self) + .connect_raw(&mut socket, NoTls) + .await?; + + let version = conn + .parameter("server_version") + .ok_or(ConnectionError::FailedToFetchPgVersion)? + .into(); + + let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token()); + + Ok((socket, version, cancel_closure)) } } diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 9ab64db795..077ff02898 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, ensure, Context}; +use anyhow::{anyhow, bail, ensure, Context}; use rustls::{internal::pemfile, NoClientAuth, ProtocolVersion, ServerConfig}; use std::net::SocketAddr; use std::str::FromStr; @@ -29,7 +29,7 @@ impl FromStr for ClientAuthMethod { "password" => Ok(Password), "link" => Ok(Link), "mixed" => Ok(Mixed), - _ => Err(anyhow::anyhow!("Invlid option for router")), + _ => bail!("Invalid option for router: `{}`", s), } } } @@ -53,7 +53,7 @@ pub struct ProxyConfig { pub redirect_uri: String, /// control plane address where we would check auth. - pub auth_endpoint: String, + pub auth_endpoint: reqwest::Url, pub tls_config: Option, } diff --git a/proxy/src/cplane_api.rs b/proxy/src/cplane_api.rs index 187809717f..21fce79df3 100644 --- a/proxy/src/cplane_api.rs +++ b/proxy/src/cplane_api.rs @@ -1,52 +1,113 @@ use crate::auth::ClientCredentials; use crate::compute::DatabaseInfo; -use crate::waiters::{Waiter, Waiters}; -use anyhow::{anyhow, bail}; +use crate::error::UserFacingError; +use crate::mgmt; +use crate::waiters::{self, Waiter, Waiters}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; +use thiserror::Error; lazy_static! { - static ref CPLANE_WAITERS: Waiters> = Default::default(); + static ref CPLANE_WAITERS: Waiters = Default::default(); } /// Give caller an opportunity to wait for cplane's reply. -pub async fn with_waiter(psql_session_id: impl Into, f: F) -> anyhow::Result +pub async fn with_waiter( + psql_session_id: impl Into, + action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R, +) -> Result where - F: FnOnce(Waiter<'static, Result>) -> R, - R: std::future::Future>, + R: std::future::Future>, + E: From, { let waiter = CPLANE_WAITERS.register(psql_session_id.into())?; - f(waiter).await + action(waiter).await } -pub fn notify(psql_session_id: &str, msg: Result) -> anyhow::Result<()> { +pub fn notify( + psql_session_id: &str, + msg: Result, +) -> Result<(), waiters::NotifyError> { CPLANE_WAITERS.notify(psql_session_id, msg) } /// Zenith console API wrapper. -pub struct CPlaneApi<'a> { - auth_endpoint: &'a str, +pub struct CPlaneApi { + auth_endpoint: reqwest::Url, } -impl<'a> CPlaneApi<'a> { - pub fn new(auth_endpoint: &'a str) -> Self { +impl CPlaneApi { + pub fn new(auth_endpoint: reqwest::Url) -> Self { Self { auth_endpoint } } } -impl CPlaneApi<'_> { - pub async fn authenticate_proxy_request( +#[derive(Debug, Error)] +pub enum AuthErrorImpl { + /// Authentication error reported by the console. + #[error("Authentication failed: {0}")] + AuthFailed(String), + + /// HTTP status (other than 200) returned by the console. + #[error("Console responded with an HTTP status: {0}")] + HttpStatus(reqwest::StatusCode), + + #[error("Console responded with a malformed JSON: {0}")] + MalformedResponse(#[from] serde_json::Error), + + #[error(transparent)] + Transport(#[from] reqwest::Error), + + #[error(transparent)] + WaiterRegister(#[from] waiters::RegisterError), + + #[error(transparent)] + WaiterWait(#[from] waiters::WaitError), +} + +#[derive(Debug, Error)] +#[error(transparent)] +pub struct AuthError(Box); + +impl AuthError { + /// Smart constructor for authentication error reported by `mgmt`. + pub fn auth_failed(msg: impl Into) -> Self { + AuthError(Box::new(AuthErrorImpl::AuthFailed(msg.into()))) + } +} + +impl From for AuthError +where + AuthErrorImpl: From, +{ + fn from(e: T) -> Self { + AuthError(Box::new(e.into())) + } +} + +impl UserFacingError for AuthError { + fn to_string_client(&self) -> String { + use AuthErrorImpl::*; + match self.0.as_ref() { + AuthFailed(_) | HttpStatus(_) => self.to_string(), + _ => "Internal error".to_string(), + } + } +} + +impl CPlaneApi { + pub async fn authenticate_proxy_client( &self, creds: ClientCredentials, - md5_response: &[u8], + md5_response: &str, salt: &[u8; 4], psql_session_id: &str, - ) -> anyhow::Result { - let mut url = reqwest::Url::parse(self.auth_endpoint)?; + ) -> Result { + let mut url = self.auth_endpoint.clone(); url.query_pairs_mut() .append_pair("login", &creds.user) .append_pair("database", &creds.dbname) - .append_pair("md5response", std::str::from_utf8(md5_response)?) + .append_pair("md5response", md5_response) .append_pair("salt", &hex::encode(salt)) .append_pair("psql_session_id", psql_session_id); @@ -55,18 +116,20 @@ impl CPlaneApi<'_> { // TODO: leverage `reqwest::Client` to reuse connections let resp = reqwest::get(url).await?; if !resp.status().is_success() { - bail!("Auth failed: {}", resp.status()) + return Err(AuthErrorImpl::HttpStatus(resp.status()).into()); } let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?; println!("got auth info: #{:?}", auth_info); use ProxyAuthResponse::*; - match auth_info { - Ready { conn_info } => Ok(conn_info), - Error { error } => bail!(error), - NotReady { .. } => waiter.await?.map_err(|e| anyhow!(e)), - } + let db_info = match auth_info { + Ready { conn_info } => conn_info, + Error { error } => return Err(AuthErrorImpl::AuthFailed(error).into()), + NotReady { .. } => waiter.await?.map_err(AuthErrorImpl::AuthFailed)?, + }; + + Ok(db_info) }) .await } diff --git a/proxy/src/error.rs b/proxy/src/error.rs new file mode 100644 index 0000000000..e98e553f83 --- /dev/null +++ b/proxy/src/error.rs @@ -0,0 +1,17 @@ +/// Marks errors that may be safely shown to a client. +/// This trait can be seen as a specialized version of [`ToString`]. +/// +/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it +/// is way too convenient and tends to proliferate all across the codebase, +/// ultimately leading to accidental leaks of sensitive data. +pub trait UserFacingError: ToString { + /// Format the error for client, stripping all sensitive info. + /// + /// Although this might be a no-op for many types, it's highly + /// recommended to override the default impl in case error type + /// contains anything sensitive: various IDs, IP addresses etc. + #[inline(always)] + fn to_string_client(&self) -> String { + self.to_string() + } +} diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 0b693d88dd..33d134678f 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -7,7 +7,7 @@ use zenith_utils::http::json::json_response; use zenith_utils::http::{RouterBuilder, RouterService}; async fn status_handler(_: Request) -> Result, ApiError> { - Ok(json_response(StatusCode::OK, "")?) + json_response(StatusCode::OK, "") } fn make_router() -> RouterBuilder { diff --git a/proxy/src/main.rs b/proxy/src/main.rs index de618ccde9..bd99d0a639 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -20,13 +20,14 @@ mod cancellation; mod compute; mod config; mod cplane_api; +mod error; mod http; mod mgmt; mod proxy; mod stream; mod waiters; -/// Flattens Result> into Result. +/// Flattens `Result>` into `Result`. async fn flatten_err( f: impl Future, JoinError>>, ) -> anyhow::Result<()> { diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index 55b49b441f..e53542dfd2 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -79,6 +79,18 @@ enum PsqlSessionResult { Failure(String), } +/// A message received by `mgmt` when a compute node is ready. +pub type ComputeReady = Result; + +impl PsqlSessionResult { + fn into_compute_ready(self) -> ComputeReady { + match self { + Self::Success(db_info) => Ok(db_info), + Self::Failure(message) => Err(message), + } + } +} + impl postgres_backend::Handler for MgmtHandler { fn process_query( &mut self, @@ -99,13 +111,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R let resp: PsqlSessionResponse = serde_json::from_str(query_string)?; - use PsqlSessionResult::*; - let msg = match resp.result { - Success(db_info) => Ok(db_info), - Failure(message) => Err(message), - }; - - match cplane_api::notify(&resp.session_id, msg) { + match cplane_api::notify(&resp.session_id, resp.result.into_compute_ready()) { Ok(()) => { pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 1dc301b792..3c7f59bc26 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,17 +1,18 @@ use crate::auth; -use crate::cancellation::{self, CancelClosure, CancelMap}; -use crate::compute::DatabaseInfo; +use crate::cancellation::{self, CancelMap}; use crate::config::{ProxyConfig, TlsConfig}; use crate::stream::{MetricsStream, PqStream, Stream}; use anyhow::{bail, Context}; +use futures::TryFutureExt; use lazy_static::lazy_static; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::TcpStream; -use tokio_postgres::NoTls; use zenith_metrics::{new_common_metric_name, register_int_counter, IntCounter}; use zenith_utils::pq_proto::{BeMessage as Be, *}; +const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; +const ERR_PROTO_VIOLATION: &str = "protocol violation"; + lazy_static! { static ref NUM_CONNECTIONS_ACCEPTED_COUNTER: IntCounter = register_int_counter!( new_common_metric_name("num_connections_accepted"), @@ -30,6 +31,7 @@ lazy_static! { .unwrap(); } +/// A small combinator for pluggable error logging. async fn log_error(future: F) -> F::Output where F: std::future::Future>, @@ -76,20 +78,21 @@ async fn handle_client( } let tls = config.tls_config.clone(); - if let Some((client, creds)) = handshake(stream, tls, cancel_map).await? { - cancel_map - .with_session(|session| async { - connect_client_to_db(config, session, client, creds).await - }) - .await?; - } + let (stream, creds) = match handshake(stream, tls, cancel_map).await? { + Some(x) => x, + None => return Ok(()), // it's a cancellation request + }; - Ok(()) + let client = Client::new(stream, creds); + cancel_map + .with_session(|session| client.connect_to_db(config, session)) + .await } -/// Handle a connection from one client. -/// For better testing experience, `stream` can be -/// any object satisfying the traits. +/// Establish a (most probably, secure) connection with the client. +/// For better testing experience, `stream` can be any object satisfying the traits. +/// It's easier to work with owned `stream` here as we need to updgrade it to TLS; +/// we also take an extra care of propagating only the select handshake errors to client. async fn handshake( stream: S, mut tls: Option, @@ -119,7 +122,7 @@ async fn handshake( stream = PqStream::new(stream.into_inner().upgrade(tls).await?); } } - _ => bail!("protocol violation"), + _ => bail!(ERR_PROTO_VIOLATION), }, GssEncRequest => match stream.get_ref() { Stream::Raw { .. } if !tried_gss => { @@ -128,18 +131,21 @@ async fn handshake( // Currently, we don't support GSSAPI stream.write_message(&Be::EncryptionResponse(false)).await?; } - _ => bail!("protocol violation"), + _ => bail!(ERR_PROTO_VIOLATION), }, StartupMessage { params, .. } => { // Check that the config has been consumed during upgrade // OR we didn't provide it at all (for dev purposes). if tls.is_some() { - let msg = "connection is insecure (try using `sslmode=require`)"; - stream.write_message(&Be::ErrorResponse(msg)).await?; - bail!(msg); + stream.throw_error_str(ERR_INSECURE_CONNECTION).await?; } - break Ok(Some((stream, params.try_into()?))); + // Here and forth: `or_else` demands that we use a future here + let creds = async { params.try_into() } + .or_else(|e| stream.throw_error(e)) + .await?; + + break Ok(Some((stream, creds))); } CancelRequest(cancel_key_data) => { cancel_map.cancel_session(cancel_key_data).await?; @@ -150,58 +156,60 @@ async fn handshake( } } -async fn connect_client_to_db( - config: &ProxyConfig, - session: cancellation::Session<'_>, - mut client: PqStream, +/// Thin connection context. +struct Client { + /// The underlying libpq protocol stream. + stream: PqStream, + /// Client credentials that we care about. creds: auth::ClientCredentials, -) -> anyhow::Result<()> { - let db_info = creds.authenticate(config, &mut client).await?; - let (db, version, cancel_closure) = connect_to_db(db_info).await?; - let cancel_key_data = session.enable_cancellation(cancel_closure); - - client - .write_message_noflush(&BeMessage::ParameterStatus( - BeParameterStatusMessage::ServerVersion(&version), - ))? - .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? - .write_message(&BeMessage::ReadyForQuery) - .await?; - - // This function will be called for writes to either direction. - fn inc_proxied(cnt: usize) { - // Consider inventing something more sophisticated - // if this ever becomes a bottleneck (cacheline bouncing). - NUM_BYTES_PROXIED_COUNTER.inc_by(cnt as u64); - } - - let mut db = MetricsStream::new(db, inc_proxied); - let mut client = MetricsStream::new(client.into_inner(), inc_proxied); - let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?; - - Ok(()) } -/// Connect to a corresponding compute node. -async fn connect_to_db( - db_info: DatabaseInfo, -) -> anyhow::Result<(TcpStream, String, CancelClosure)> { - // TODO: establish a secure connection to the DB - let socket_addr = db_info.socket_addr()?; - let mut socket = TcpStream::connect(socket_addr).await?; +impl Client { + /// Construct a new connection context. + fn new(stream: PqStream, creds: auth::ClientCredentials) -> Self { + Self { stream, creds } + } +} - let (client, conn) = tokio_postgres::Config::from(db_info) - .connect_raw(&mut socket, NoTls) - .await?; +impl Client { + /// Let the client authenticate and connect to the designated compute node. + async fn connect_to_db( + self, + config: &ProxyConfig, + session: cancellation::Session<'_>, + ) -> anyhow::Result<()> { + let Self { mut stream, creds } = self; - let version = conn - .parameter("server_version") - .context("failed to fetch postgres server version")? - .into(); + // Authenticate and connect to a compute node. + let auth = creds.authenticate(config, &mut stream).await; + let db_info = async { auth }.or_else(|e| stream.throw_error(e)).await?; - let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token()); + let (db, version, cancel_closure) = + db_info.connect().or_else(|e| stream.throw_error(e)).await?; + let cancel_key_data = session.enable_cancellation(cancel_closure); - Ok((socket, version, cancel_closure)) + stream + .write_message_noflush(&BeMessage::ParameterStatus( + BeParameterStatusMessage::ServerVersion(&version), + ))? + .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? + .write_message(&BeMessage::ReadyForQuery) + .await?; + + /// This function will be called for writes to either direction. + fn inc_proxied(cnt: usize) { + // Consider inventing something more sophisticated + // if this ever becomes a bottleneck (cacheline bouncing). + NUM_BYTES_PROXIED_COUNTER.inc_by(cnt as u64); + } + + // Starting from here we only proxy the client's traffic. + let mut db = MetricsStream::new(db, inc_proxied); + let mut client = MetricsStream::new(stream.into_inner(), inc_proxied); + let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?; + + Ok(()) + } } #[cfg(test)] @@ -210,7 +218,7 @@ mod tests { use tokio::io::DuplexStream; use tokio_postgres::config::SslMode; - use tokio_postgres::tls::MakeTlsConnect; + use tokio_postgres::tls::{MakeTlsConnect, NoTls}; use tokio_postgres_rustls::MakeRustlsConnect; async fn dummy_proxy( @@ -264,7 +272,7 @@ mod tests { let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into()))); - tokio_postgres::Config::new() + let client_err = tokio_postgres::Config::new() .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Disable) @@ -273,11 +281,15 @@ mod tests { .err() // -> Option .context("client shouldn't be able to connect")?; - proxy + assert!(client_err.to_string().contains(ERR_INSECURE_CONNECTION)); + + let server_err = proxy .await? .err() // -> Option .context("server shouldn't accept client")?; + assert!(client_err.to_string().contains(&server_err.to_string())); + Ok(()) } @@ -329,4 +341,30 @@ mod tests { proxy.await? } + + #[tokio::test] + async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> { + let (client, server) = tokio::io::duplex(1024); + + let proxy = tokio::spawn(dummy_proxy(client, None)); + + let client_err = tokio_postgres::Config::new() + .ssl_mode(SslMode::Disable) + .connect_raw(server, NoTls) + .await + .err() // -> Option + .context("client shouldn't be able to connect")?; + + // TODO: this is ugly, but `format!` won't allow us to extract fmt string + assert!(client_err.to_string().contains("missing in startup packet")); + + let server_err = proxy + .await? + .err() // -> Option + .context("server shouldn't accept client")?; + + assert!(client_err.to_string().contains(&server_err.to_string())); + + Ok(()) + } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 8fd5bef388..fb0be84584 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -1,10 +1,12 @@ -use anyhow::Context; +use crate::error::UserFacingError; +use anyhow::bail; use bytes::BytesMut; use pin_project_lite::pin_project; use rustls::ServerConfig; use std::pin::Pin; use std::sync::Arc; use std::{io, task}; +use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio_rustls::server::TlsStream; use zenith_utils::pq_proto::{BeMessage, FeMessage, FeStartupPacket}; @@ -35,38 +37,63 @@ impl PqStream { self.stream } - /// Get a reference to the underlying stream. + /// Get a shared reference to the underlying stream. pub fn get_ref(&self) -> &S { &self.stream } } +fn err_connection() -> io::Error { + io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost") +} + +// TODO: change error type of `FeMessage::read_fut` +fn from_anyhow(e: anyhow::Error) -> io::Error { + io::Error::new(io::ErrorKind::Other, e.to_string()) +} + impl PqStream { /// Receive [`FeStartupPacket`], which is a first packet sent by a client. - pub async fn read_startup_packet(&mut self) -> anyhow::Result { - match FeStartupPacket::read_fut(&mut self.stream).await? { - Some(FeMessage::StartupPacket(packet)) => Ok(packet), - None => anyhow::bail!("connection is lost"), - other => anyhow::bail!("bad message type: {:?}", other), + pub async fn read_startup_packet(&mut self) -> io::Result { + // TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket` + let msg = FeStartupPacket::read_fut(&mut self.stream) + .await + .map_err(from_anyhow)? + .ok_or_else(err_connection)?; + + match msg { + FeMessage::StartupPacket(packet) => Ok(packet), + _ => panic!("unreachable state"), } } - pub async fn read_message(&mut self) -> anyhow::Result { + pub async fn read_password_message(&mut self) -> io::Result { + match self.read_message().await? { + FeMessage::PasswordMessage(msg) => Ok(msg), + bad => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("unexpected message type: {:?}", bad), + )), + } + } + + async fn read_message(&mut self) -> io::Result { FeMessage::read_fut(&mut self.stream) - .await? - .context("connection is lost") + .await + .map_err(from_anyhow)? + .ok_or_else(err_connection) } } impl PqStream { /// Write the message into an internal buffer, but don't flush the underlying stream. - pub fn write_message_noflush<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> { + pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { BeMessage::write(&mut self.buffer, message)?; Ok(self) } /// Write the message into an internal buffer and flush it. - pub async fn write_message<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> { + pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { self.write_message_noflush(message)?; self.flush().await?; Ok(self) @@ -79,6 +106,25 @@ impl PqStream { self.stream.flush().await?; Ok(self) } + + /// Write the error message using [`Self::write_message`], then re-throw it. + /// Allowing string literals is safe under the assumption they might not contain any runtime info. + pub async fn throw_error_str(&mut self, error: &'static str) -> anyhow::Result { + // This method exists due to `&str` not implementing `Into` + self.write_message(&BeMessage::ErrorResponse(error)).await?; + bail!(error) + } + + /// Write the error message using [`Self::write_message`], then re-throw it. + /// Trait [`UserFacingError`] acts as an allowlist for error types. + pub async fn throw_error(&mut self, error: E) -> anyhow::Result + where + E: UserFacingError + Into, + { + let msg = error.to_string_client(); + self.write_message(&BeMessage::ErrorResponse(&msg)).await?; + bail!(error) + } } pin_project! { @@ -101,15 +147,25 @@ impl Stream { } } +#[derive(Debug, Error)] +#[error("Can't upgrade TLS stream")] +pub enum StreamUpgradeError { + #[error("Bad state reached: can't upgrade TLS stream")] + AlreadyTls, + + #[error("Can't upgrade stream: IO error: {0}")] + Io(#[from] io::Error), +} + impl Stream { /// If possible, upgrade raw stream into a secure TLS-based stream. - pub async fn upgrade(self, cfg: Arc) -> anyhow::Result { + pub async fn upgrade(self, cfg: Arc) -> Result { match self { Stream::Raw { raw } => { let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?); Ok(Stream::Tls { tls }) } - Stream::Tls { .. } => anyhow::bail!("can't upgrade TLS stream"), + Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls), } } } diff --git a/proxy/src/waiters.rs b/proxy/src/waiters.rs index 9fda3ed94f..799d45a165 100644 --- a/proxy/src/waiters.rs +++ b/proxy/src/waiters.rs @@ -1,11 +1,32 @@ -use anyhow::{anyhow, Context}; use hashbrown::HashMap; use parking_lot::Mutex; use pin_project_lite::pin_project; use std::pin::Pin; use std::task; +use thiserror::Error; use tokio::sync::oneshot; +#[derive(Debug, Error)] +pub enum RegisterError { + #[error("Waiter `{0}` already registered")] + Occupied(String), +} + +#[derive(Debug, Error)] +pub enum NotifyError { + #[error("Notify failed: waiter `{0}` not registered")] + NotFound(String), + + #[error("Notify failed: channel hangup")] + Hangup, +} + +#[derive(Debug, Error)] +pub enum WaitError { + #[error("Wait failed: channel hangup")] + Hangup, +} + pub struct Waiters(pub(self) Mutex>>); impl Default for Waiters { @@ -15,13 +36,13 @@ impl Default for Waiters { } impl Waiters { - pub fn register(&self, key: String) -> anyhow::Result> { + pub fn register(&self, key: String) -> Result, RegisterError> { let (tx, rx) = oneshot::channel(); self.0 .lock() .try_insert(key.clone(), tx) - .map_err(|_| anyhow!("waiter already registered"))?; + .map_err(|e| RegisterError::Occupied(e.entry.key().clone()))?; Ok(Waiter { receiver: rx, @@ -32,7 +53,7 @@ impl Waiters { }) } - pub fn notify(&self, key: &str, value: T) -> anyhow::Result<()> + pub fn notify(&self, key: &str, value: T) -> Result<(), NotifyError> where T: Send + Sync, { @@ -40,9 +61,9 @@ impl Waiters { .0 .lock() .remove(key) - .with_context(|| format!("key {} not found", key))?; + .ok_or_else(|| NotifyError::NotFound(key.to_string()))?; - tx.send(value).map_err(|_| anyhow!("waiter channel hangup")) + tx.send(value).map_err(|_| NotifyError::Hangup) } } @@ -66,13 +87,13 @@ pin_project! { } impl std::future::Future for Waiter<'_, T> { - type Output = anyhow::Result; + type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { self.project() .receiver .poll(cx) - .map_err(|_| anyhow!("channel hangup")) + .map_err(|_| WaitError::Hangup) } }