mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-31 12:00:42 +00:00
Proxy error reworking (#6453)
## Problem Taking my ideas from https://github.com/neondatabase/neon/pull/6283 and doing a bit less radical changes. smaller commits. We currently don't report error classifications in proxy as the current error handling made it hard to do so. ## Summary of changes 1. Add a `ReportableError` trait that all errors will implement. This provides the error classification functionality. 2. Handle Client requests a strongly typed error * this error is a `ReportableError` and is logged appropriately 3. The handle client error only has a few possible error types, to account for the fact that at this point errors should be returned to the user.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use anyhow::Context;
|
||||
use futures::pin_mut;
|
||||
use futures::StreamExt;
|
||||
use hyper::body::HttpBody;
|
||||
@@ -29,9 +28,11 @@ use utils::http::json::json_response;
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::endpoint_sni;
|
||||
use crate::auth::ComputeUserInfoParseError;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::HTTP_CONTENT_LENGTH;
|
||||
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
|
||||
use crate::proxy::NeonOptions;
|
||||
@@ -41,7 +42,6 @@ use super::backend::PoolingBackend;
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::json::json_to_pg_text;
|
||||
use super::json::pg_text_row_to_json;
|
||||
use super::SERVERLESS_DRIVER_SNI;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -86,67 +86,70 @@ where
|
||||
Ok(json_to_pg_text(json))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnInfoError {
|
||||
#[error("invalid header: {0}")]
|
||||
InvalidHeader(&'static str),
|
||||
#[error("invalid connection string: {0}")]
|
||||
UrlParseError(#[from] url::ParseError),
|
||||
#[error("incorrect scheme")]
|
||||
IncorrectScheme,
|
||||
#[error("missing database name")]
|
||||
MissingDbName,
|
||||
#[error("invalid database name")]
|
||||
InvalidDbName,
|
||||
#[error("missing username")]
|
||||
MissingUsername,
|
||||
#[error("missing password")]
|
||||
MissingPassword,
|
||||
#[error("missing hostname")]
|
||||
MissingHostname,
|
||||
#[error("invalid hostname: {0}")]
|
||||
InvalidEndpoint(#[from] ComputeUserInfoParseError),
|
||||
#[error("malformed endpoint")]
|
||||
MalformedEndpoint,
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
ctx: &mut RequestMonitoring,
|
||||
headers: &HeaderMap,
|
||||
sni_hostname: Option<String>,
|
||||
tls: &TlsConfig,
|
||||
) -> Result<ConnInfo, anyhow::Error> {
|
||||
) -> Result<ConnInfo, ConnInfoError> {
|
||||
let connection_string = headers
|
||||
.get("Neon-Connection-String")
|
||||
.ok_or(anyhow::anyhow!("missing connection string"))?
|
||||
.to_str()?;
|
||||
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
|
||||
|
||||
let connection_url = Url::parse(connection_string)?;
|
||||
|
||||
let protocol = connection_url.scheme();
|
||||
if protocol != "postgres" && protocol != "postgresql" {
|
||||
return Err(anyhow::anyhow!(
|
||||
"connection string must start with postgres: or postgresql:"
|
||||
));
|
||||
return Err(ConnInfoError::IncorrectScheme);
|
||||
}
|
||||
|
||||
let mut url_path = connection_url
|
||||
.path_segments()
|
||||
.ok_or(anyhow::anyhow!("missing database name"))?;
|
||||
.ok_or(ConnInfoError::MissingDbName)?;
|
||||
|
||||
let dbname = url_path
|
||||
.next()
|
||||
.ok_or(anyhow::anyhow!("invalid database name"))?;
|
||||
let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?;
|
||||
|
||||
let username = RoleName::from(connection_url.username());
|
||||
if username.is_empty() {
|
||||
return Err(anyhow::anyhow!("missing username"));
|
||||
return Err(ConnInfoError::MissingUsername);
|
||||
}
|
||||
ctx.set_user(username.clone());
|
||||
|
||||
let password = connection_url
|
||||
.password()
|
||||
.ok_or(anyhow::anyhow!("no password"))?;
|
||||
|
||||
// TLS certificate selector now based on SNI hostname, so if we are running here
|
||||
// we are sure that SNI hostname is set to one of the configured domain names.
|
||||
let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
|
||||
.ok_or(ConnInfoError::MissingPassword)?;
|
||||
|
||||
let hostname = connection_url
|
||||
.host_str()
|
||||
.ok_or(anyhow::anyhow!("no host"))?;
|
||||
.ok_or(ConnInfoError::MissingHostname)?;
|
||||
|
||||
let host_header = headers
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.split(':').next());
|
||||
|
||||
// sni_hostname has to be either the same as hostname or the one used in serverless driver.
|
||||
if !check_matches(&sni_hostname, hostname)? {
|
||||
return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
|
||||
} else if let Some(h) = host_header {
|
||||
if h != sni_hostname {
|
||||
return Err(anyhow::anyhow!("mismatched host header and hostname"));
|
||||
}
|
||||
}
|
||||
|
||||
let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
|
||||
let endpoint =
|
||||
endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?;
|
||||
ctx.set_endpoint_id(endpoint.clone());
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
@@ -173,36 +176,27 @@ fn get_conn_info(
|
||||
})
|
||||
}
|
||||
|
||||
fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Error> {
|
||||
if sni_hostname == hostname {
|
||||
return Ok(true);
|
||||
}
|
||||
let (sni_hostname_first, sni_hostname_rest) = sni_hostname
|
||||
.split_once('.')
|
||||
.ok_or_else(|| anyhow::anyhow!("Unexpected sni format."))?;
|
||||
let (_, hostname_rest) = hostname
|
||||
.split_once('.')
|
||||
.ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
|
||||
Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
|
||||
}
|
||||
|
||||
// TODO: return different http error codes
|
||||
pub async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
mut ctx: RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
sni_hostname: Option<String>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let result = tokio::time::timeout(
|
||||
config.http_config.request_timeout,
|
||||
handle_inner(config, ctx, request, sni_hostname, backend),
|
||||
handle_inner(config, &mut ctx, request, backend),
|
||||
)
|
||||
.await;
|
||||
let mut response = match result {
|
||||
Ok(r) => match r {
|
||||
Ok(r) => r,
|
||||
Ok(r) => {
|
||||
ctx.set_success();
|
||||
r
|
||||
}
|
||||
Err(e) => {
|
||||
// TODO: ctx.set_error_kind(e.get_error_type());
|
||||
|
||||
let mut message = format!("{:?}", e);
|
||||
let db_error = e
|
||||
.downcast_ref::<tokio_postgres::Error>()
|
||||
@@ -278,7 +272,9 @@ pub async fn handle(
|
||||
)?
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
|
||||
let message = format!(
|
||||
"HTTP-Connection timed out, execution time exeeded {} seconds",
|
||||
config.http_config.request_timeout.as_secs()
|
||||
@@ -290,6 +286,7 @@ pub async fn handle(
|
||||
)?
|
||||
}
|
||||
};
|
||||
|
||||
response.headers_mut().insert(
|
||||
"Access-Control-Allow-Origin",
|
||||
hyper::http::HeaderValue::from_static("*"),
|
||||
@@ -302,7 +299,6 @@ async fn handle_inner(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
sni_hostname: Option<String>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> anyhow::Result<Response<Body>> {
|
||||
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
|
||||
@@ -318,12 +314,7 @@ async fn handle_inner(
|
||||
//
|
||||
let headers = request.headers();
|
||||
// TLS config should be there.
|
||||
let conn_info = get_conn_info(
|
||||
ctx,
|
||||
headers,
|
||||
sni_hostname,
|
||||
config.tls_config.as_ref().unwrap(),
|
||||
)?;
|
||||
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?;
|
||||
info!(
|
||||
user = conn_info.user_info.user.as_str(),
|
||||
project = conn_info.user_info.endpoint.as_str(),
|
||||
@@ -487,8 +478,6 @@ async fn handle_inner(
|
||||
}
|
||||
};
|
||||
|
||||
ctx.set_success();
|
||||
ctx.log();
|
||||
let metrics = client.metrics();
|
||||
|
||||
// how could this possibly fail
|
||||
|
||||
Reference in New Issue
Block a user