Proxy error reworking (#6453)

## Problem

Taking my ideas from https://github.com/neondatabase/neon/pull/6283 and
doing a bit less radical changes. smaller commits.

We currently don't report error classifications in proxy as the current
error handling made it hard to do so.

## Summary of changes

1. Add a `ReportableError` trait that all errors will implement. This
provides the error classification functionality.
2. Handle Client requests a strongly typed error
    * this error is a `ReportableError` and is logged appropriately
3. The handle client error only has a few possible error types, to
account for the fact that at this point errors should be returned to the
user.
This commit is contained in:
Conrad Ludgate
2024-02-09 15:50:51 +00:00
committed by GitHub
parent 89a5c654bf
commit 96d89cde51
25 changed files with 588 additions and 186 deletions

View File

@@ -1,7 +1,6 @@
use std::sync::Arc;
use anyhow::bail;
use anyhow::Context;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
@@ -29,9 +28,11 @@ use utils::http::json::json_response;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::HTTP_CONTENT_LENGTH;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
@@ -41,7 +42,6 @@ use super::backend::PoolingBackend;
use super::conn_pool::ConnInfo;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::SERVERLESS_DRIVER_SNI;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -86,67 +86,70 @@ where
Ok(json_to_pg_text(json))
}
#[derive(Debug, thiserror::Error)]
pub enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static str),
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
#[error("incorrect scheme")]
IncorrectScheme,
#[error("missing database name")]
MissingDbName,
#[error("invalid database name")]
InvalidDbName,
#[error("missing username")]
MissingUsername,
#[error("missing password")]
MissingPassword,
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
#[error("malformed endpoint")]
MalformedEndpoint,
}
fn get_conn_info(
ctx: &mut RequestMonitoring,
headers: &HeaderMap,
sni_hostname: Option<String>,
tls: &TlsConfig,
) -> Result<ConnInfo, anyhow::Error> {
) -> Result<ConnInfo, ConnInfoError> {
let connection_string = headers
.get("Neon-Connection-String")
.ok_or(anyhow::anyhow!("missing connection string"))?
.to_str()?;
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
let connection_url = Url::parse(connection_string)?;
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(anyhow::anyhow!(
"connection string must start with postgres: or postgresql:"
));
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(anyhow::anyhow!("missing database name"))?;
.ok_or(ConnInfoError::MissingDbName)?;
let dbname = url_path
.next()
.ok_or(anyhow::anyhow!("invalid database name"))?;
let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?;
let username = RoleName::from(connection_url.username());
if username.is_empty() {
return Err(anyhow::anyhow!("missing username"));
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
let password = connection_url
.password()
.ok_or(anyhow::anyhow!("no password"))?;
// TLS certificate selector now based on SNI hostname, so if we are running here
// we are sure that SNI hostname is set to one of the configured domain names.
let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
.ok_or(ConnInfoError::MissingPassword)?;
let hostname = connection_url
.host_str()
.ok_or(anyhow::anyhow!("no host"))?;
.ok_or(ConnInfoError::MissingHostname)?;
let host_header = headers
.get("host")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(':').next());
// sni_hostname has to be either the same as hostname or the one used in serverless driver.
if !check_matches(&sni_hostname, hostname)? {
return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
} else if let Some(h) = host_header {
if h != sni_hostname {
return Err(anyhow::anyhow!("mismatched host header and hostname"));
}
}
let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
let endpoint =
endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?;
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
@@ -173,36 +176,27 @@ fn get_conn_info(
})
}
fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Error> {
if sni_hostname == hostname {
return Ok(true);
}
let (sni_hostname_first, sni_hostname_rest) = sni_hostname
.split_once('.')
.ok_or_else(|| anyhow::anyhow!("Unexpected sni format."))?;
let (_, hostname_rest) = hostname
.split_once('.')
.ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
}
// TODO: return different http error codes
pub async fn handle(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
mut ctx: RequestMonitoring,
request: Request<Body>,
sni_hostname: Option<String>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Body>, ApiError> {
let result = tokio::time::timeout(
config.http_config.request_timeout,
handle_inner(config, ctx, request, sni_hostname, backend),
handle_inner(config, &mut ctx, request, backend),
)
.await;
let mut response = match result {
Ok(r) => match r {
Ok(r) => r,
Ok(r) => {
ctx.set_success();
r
}
Err(e) => {
// TODO: ctx.set_error_kind(e.get_error_type());
let mut message = format!("{:?}", e);
let db_error = e
.downcast_ref::<tokio_postgres::Error>()
@@ -278,7 +272,9 @@ pub async fn handle(
)?
}
},
Err(_) => {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
let message = format!(
"HTTP-Connection timed out, execution time exeeded {} seconds",
config.http_config.request_timeout.as_secs()
@@ -290,6 +286,7 @@ pub async fn handle(
)?
}
};
response.headers_mut().insert(
"Access-Control-Allow-Origin",
hyper::http::HeaderValue::from_static("*"),
@@ -302,7 +299,6 @@ async fn handle_inner(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
request: Request<Body>,
sni_hostname: Option<String>,
backend: Arc<PoolingBackend>,
) -> anyhow::Result<Response<Body>> {
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
@@ -318,12 +314,7 @@ async fn handle_inner(
//
let headers = request.headers();
// TLS config should be there.
let conn_info = get_conn_info(
ctx,
headers,
sni_hostname,
config.tls_config.as_ref().unwrap(),
)?;
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?;
info!(
user = conn_info.user_info.user.as_str(),
project = conn_info.user_info.endpoint.as_str(),
@@ -487,8 +478,6 @@ async fn handle_inner(
}
};
ctx.set_success();
ctx.log();
let metrics = client.metrics();
// how could this possibly fail