Merge pull request #8117 from neondatabase/rc/proxy/2024-06-20

Proxy release 2024-06-20
This commit is contained in:
Conrad Ludgate
2024-06-20 11:42:35 +01:00
committed by GitHub
195 changed files with 6122 additions and 1926 deletions

View File

@@ -1,16 +1,183 @@
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::fmt::{self, Display};
use crate::auth::IpPattern;
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
use crate::proxy::retry::ShouldRetry;
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
#[derive(Debug, Deserialize)]
pub struct ConsoleError {
pub error: Box<str>,
#[serde(skip)]
pub http_status_code: http::StatusCode,
pub status: Option<Status>,
}
impl ConsoleError {
pub fn get_reason(&self) -> Reason {
self.status
.as_ref()
.and_then(|s| s.details.error_info.as_ref())
.map(|e| e.reason)
.unwrap_or(Reason::Unknown)
}
pub fn get_user_facing_message(&self) -> String {
use super::provider::errors::REQUEST_FAILED;
self.status
.as_ref()
.and_then(|s| s.details.user_facing_message.as_ref())
.map(|m| m.message.clone().into())
.unwrap_or_else(|| {
// Ask @neondatabase/control-plane for review before adding more.
match self.http_status_code {
http::StatusCode::NOT_FOUND => {
// Status 404: failed to get a project-related resource.
format!("{REQUEST_FAILED}: endpoint cannot be found")
}
http::StatusCode::NOT_ACCEPTABLE => {
// Status 406: endpoint is disabled (we don't allow connections).
format!("{REQUEST_FAILED}: endpoint is disabled")
}
http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
// Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support.")
}
_ => REQUEST_FAILED.to_owned(),
}
})
}
}
impl Display for ConsoleError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let msg = self
.status
.as_ref()
.and_then(|s| s.details.user_facing_message.as_ref())
.map(|m| m.message.as_ref())
.unwrap_or_else(|| &self.error);
write!(f, "{}", msg)
}
}
impl ShouldRetry for ConsoleError {
fn could_retry(&self) -> bool {
if self.status.is_none() || self.status.as_ref().unwrap().details.retry_info.is_none() {
// retry some temporary failures because the compute was in a bad state
// (bad request can be returned when the endpoint was in transition)
return match &self {
ConsoleError {
http_status_code: http::StatusCode::BAD_REQUEST,
..
} => true,
// don't retry when quotas are exceeded
ConsoleError {
http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
ref error,
..
} => !error.contains("compute time quota of non-primary branches is exceeded"),
// locked can be returned when the endpoint was in transition
// or when quotas are exceeded. don't retry when quotas are exceeded
ConsoleError {
http_status_code: http::StatusCode::LOCKED,
ref error,
..
} => {
!error.contains("quota exceeded")
&& !error.contains("the limit for current plan reached")
}
_ => false,
};
}
// retry if the response has a retry delay
if let Some(retry_info) = self
.status
.as_ref()
.and_then(|s| s.details.retry_info.as_ref())
{
retry_info.retry_delay_ms > 0
} else {
false
}
}
}
#[derive(Debug, Deserialize)]
pub struct Status {
pub code: Box<str>,
pub message: Box<str>,
pub details: Details,
}
#[derive(Debug, Deserialize)]
pub struct Details {
pub error_info: Option<ErrorInfo>,
pub retry_info: Option<RetryInfo>,
pub user_facing_message: Option<UserFacingMessage>,
}
#[derive(Debug, Deserialize)]
pub struct ErrorInfo {
pub reason: Reason,
// Schema could also have `metadata` field, but it's not structured. Skip it for now.
}
#[derive(Clone, Copy, Debug, Deserialize, Default)]
pub enum Reason {
#[serde(rename = "ROLE_PROTECTED")]
RoleProtected,
#[serde(rename = "RESOURCE_NOT_FOUND")]
ResourceNotFound,
#[serde(rename = "PROJECT_NOT_FOUND")]
ProjectNotFound,
#[serde(rename = "ENDPOINT_NOT_FOUND")]
EndpointNotFound,
#[serde(rename = "BRANCH_NOT_FOUND")]
BranchNotFound,
#[serde(rename = "RATE_LIMIT_EXCEEDED")]
RateLimitExceeded,
#[serde(rename = "NON_PRIMARY_BRANCH_COMPUTE_TIME_EXCEEDED")]
NonPrimaryBranchComputeTimeExceeded,
#[serde(rename = "ACTIVE_TIME_QUOTA_EXCEEDED")]
ActiveTimeQuotaExceeded,
#[serde(rename = "COMPUTE_TIME_QUOTA_EXCEEDED")]
ComputeTimeQuotaExceeded,
#[serde(rename = "WRITTEN_DATA_QUOTA_EXCEEDED")]
WrittenDataQuotaExceeded,
#[serde(rename = "DATA_TRANSFER_QUOTA_EXCEEDED")]
DataTransferQuotaExceeded,
#[serde(rename = "LOGICAL_SIZE_QUOTA_EXCEEDED")]
LogicalSizeQuotaExceeded,
#[default]
#[serde(other)]
Unknown,
}
impl Reason {
pub fn is_not_found(&self) -> bool {
matches!(
self,
Reason::ResourceNotFound
| Reason::ProjectNotFound
| Reason::EndpointNotFound
| Reason::BranchNotFound
)
}
}
#[derive(Debug, Deserialize)]
pub struct RetryInfo {
pub retry_delay_ms: u64,
}
#[derive(Debug, Deserialize)]
pub struct UserFacingMessage {
pub message: Box<str>,
}
/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].

View File

@@ -25,8 +25,8 @@ use tracing::info;
pub mod errors {
use crate::{
console::messages::{self, ConsoleError},
error::{io_error, ReportableError, UserFacingError},
http,
proxy::retry::ShouldRetry,
};
use thiserror::Error;
@@ -34,17 +34,14 @@ pub mod errors {
use super::ApiLockError;
/// A go-to error message which doesn't leak any detail.
const REQUEST_FAILED: &str = "Console request failed";
pub const REQUEST_FAILED: &str = "Console request failed";
/// Common console API error.
#[derive(Debug, Error)]
pub enum ApiError {
/// Error returned by the console itself.
#[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
Console {
status: http::StatusCode,
text: Box<str>,
},
#[error("{REQUEST_FAILED} with {0}")]
Console(ConsoleError),
/// Various IO errors like broken pipe or malformed payload.
#[error("{REQUEST_FAILED}: {0}")]
@@ -53,11 +50,11 @@ pub mod errors {
impl ApiError {
/// Returns HTTP status code if it's the reason for failure.
pub fn http_status_code(&self) -> Option<http::StatusCode> {
pub fn get_reason(&self) -> messages::Reason {
use ApiError::*;
match self {
Console { status, .. } => Some(*status),
_ => None,
Console(e) => e.get_reason(),
_ => messages::Reason::Unknown,
}
}
}
@@ -67,22 +64,7 @@ pub mod errors {
use ApiError::*;
match self {
// To minimize risks, only select errors are forwarded to users.
// Ask @neondatabase/control-plane for review before adding more.
Console { status, .. } => match *status {
http::StatusCode::NOT_FOUND => {
// Status 404: failed to get a project-related resource.
format!("{REQUEST_FAILED}: endpoint cannot be found")
}
http::StatusCode::NOT_ACCEPTABLE => {
// Status 406: endpoint is disabled (we don't allow connections).
format!("{REQUEST_FAILED}: endpoint is disabled")
}
http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
// Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support.")
}
_ => REQUEST_FAILED.to_owned(),
},
Console(c) => c.get_user_facing_message(),
_ => REQUEST_FAILED.to_owned(),
}
}
@@ -91,29 +73,56 @@ pub mod errors {
impl ReportableError for ApiError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ApiError::Console {
status: http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
..
} => crate::error::ErrorKind::User,
ApiError::Console {
status: http::StatusCode::UNPROCESSABLE_ENTITY,
text,
} if text.contains("compute time quota of non-primary branches is exceeded") => {
crate::error::ErrorKind::User
ApiError::Console(e) => {
use crate::error::ErrorKind::*;
match e.get_reason() {
crate::console::messages::Reason::RoleProtected => User,
crate::console::messages::Reason::ResourceNotFound => User,
crate::console::messages::Reason::ProjectNotFound => User,
crate::console::messages::Reason::EndpointNotFound => User,
crate::console::messages::Reason::BranchNotFound => User,
crate::console::messages::Reason::RateLimitExceeded => ServiceRateLimit,
crate::console::messages::Reason::NonPrimaryBranchComputeTimeExceeded => {
User
}
crate::console::messages::Reason::ActiveTimeQuotaExceeded => User,
crate::console::messages::Reason::ComputeTimeQuotaExceeded => User,
crate::console::messages::Reason::WrittenDataQuotaExceeded => User,
crate::console::messages::Reason::DataTransferQuotaExceeded => User,
crate::console::messages::Reason::LogicalSizeQuotaExceeded => User,
crate::console::messages::Reason::Unknown => match &e {
ConsoleError {
http_status_code:
http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
..
} => crate::error::ErrorKind::User,
ConsoleError {
http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
error,
..
} if error.contains(
"compute time quota of non-primary branches is exceeded",
) =>
{
crate::error::ErrorKind::User
}
ConsoleError {
http_status_code: http::StatusCode::LOCKED,
error,
..
} if error.contains("quota exceeded")
|| error.contains("the limit for current plan reached") =>
{
crate::error::ErrorKind::User
}
ConsoleError {
http_status_code: http::StatusCode::TOO_MANY_REQUESTS,
..
} => crate::error::ErrorKind::ServiceRateLimit,
ConsoleError { .. } => crate::error::ErrorKind::ControlPlane,
},
}
}
ApiError::Console {
status: http::StatusCode::LOCKED,
text,
} if text.contains("quota exceeded")
|| text.contains("the limit for current plan reached") =>
{
crate::error::ErrorKind::User
}
ApiError::Console {
status: http::StatusCode::TOO_MANY_REQUESTS,
..
} => crate::error::ErrorKind::ServiceRateLimit,
ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
}
}
@@ -124,31 +133,7 @@ pub mod errors {
match self {
// retry some transport errors
Self::Transport(io) => io.could_retry(),
// retry some temporary failures because the compute was in a bad state
// (bad request can be returned when the endpoint was in transition)
Self::Console {
status: http::StatusCode::BAD_REQUEST,
..
} => true,
// don't retry when quotas are exceeded
Self::Console {
status: http::StatusCode::UNPROCESSABLE_ENTITY,
ref text,
} => !text.contains("compute time quota of non-primary branches is exceeded"),
// locked can be returned when the endpoint was in transition
// or when quotas are exceeded. don't retry when quotas are exceeded
Self::Console {
status: http::StatusCode::LOCKED,
ref text,
} => {
// written data quota exceeded
// data transfer quota exceeded
// compute time quota exceeded
// logical size quota exceeded
!text.contains("quota exceeded")
&& !text.contains("the limit for current plan reached")
}
_ => false,
Self::Console(e) => e.could_retry(),
}
}
}
@@ -509,7 +494,7 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
self.metrics
.semaphore_acquire_seconds
.observe(now.elapsed().as_secs_f64());
info!("acquired permit {:?}", now.elapsed().as_secs_f64());
Ok(WakeComputePermit { permit: permit? })
}

View File

@@ -94,12 +94,14 @@ impl Api {
let body = match parse_body::<GetRoleSecret>(response).await {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
Err(e) => match e.http_status_code() {
Some(http::StatusCode::NOT_FOUND) => {
// TODO(anna): retry
Err(e) => {
if e.get_reason().is_not_found() {
return Ok(AuthInfo::default());
} else {
return Err(e.into());
}
_otherwise => return Err(e.into()),
},
}
};
let secret = if body.role_secret.is_empty() {
@@ -328,19 +330,24 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
info!("request succeeded, processing the body");
return Ok(response.json().await?);
}
let s = response.bytes().await?;
// Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
info!("response_error plaintext: {:?}", s);
// Don't throw an error here because it's not as important
// as the fact that the request itself has failed.
let body = response.json().await.unwrap_or_else(|e| {
let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
warn!("failed to parse error body: {e}");
ConsoleError {
error: "reason unclear (malformed error message)".into(),
http_status_code: status,
status: None,
}
});
body.http_status_code = status;
let text = body.error;
error!("console responded with an error ({status}): {text}");
Err(ApiError::Console { status, text })
error!("console responded with an error ({status}): {body:?}");
Err(ApiError::Console(body))
}
fn parse_host_port(input: &str) -> Option<(&str, u16)> {

View File

@@ -91,7 +91,7 @@ pub async fn task_main(
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await{
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Ok((socket, Some(addr))) => (socket, addr.ip()),
Err(e) => {
error!("per-client task finished with an error: {e:#}");
@@ -101,36 +101,38 @@ pub async fn task_main(
error!("missing required client IP");
return;
}
Ok((socket, None)) => (socket, peer_addr.ip())
Ok((socket, None)) => (socket, peer_addr.ip()),
};
match socket.inner.set_nodelay(true) {
Ok(()) => {},
Ok(()) => {}
Err(e) => {
error!("per-client task finished with an error: failed to set socket option: {e:#}");
return;
},
}
};
let mut ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Tcp,
&config.region,
);
session_id,
peer_addr,
crate::metrics::Protocol::Tcp,
&config.region,
);
let span = ctx.span.clone();
let res = handle_client(
config,
&mut ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
)
.instrument(span.clone())
.await;
let startup = Box::pin(
handle_client(
config,
&mut ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
)
.instrument(span.clone()),
);
let res = startup.await;
match res {
Err(e) => {

View File

@@ -98,7 +98,7 @@ pub(super) struct CopyBuffer {
amt: u64,
buf: Box<[u8]>,
}
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
const DEFAULT_BUF_SIZE: usize = 1024;
impl CopyBuffer {
pub(super) fn new() -> Self {

View File

@@ -12,7 +12,7 @@ use crate::auth::backend::{
};
use crate::config::{CertResolver, RetryConfig};
use crate::console::caches::NodeInfoCache;
use crate::console::messages::MetricsAuxInfo;
use crate::console::messages::{ConsoleError, MetricsAuxInfo};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
@@ -484,18 +484,20 @@ impl TestBackend for TestConnectMechanism {
match action {
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
ConnectAction::WakeFail => {
let err = console::errors::ApiError::Console {
status: http::StatusCode::FORBIDDEN,
text: "TEST".into(),
};
let err = console::errors::ApiError::Console(ConsoleError {
http_status_code: http::StatusCode::FORBIDDEN,
error: "TEST".into(),
status: None,
});
assert!(!err.could_retry());
Err(console::errors::WakeComputeError::ApiError(err))
}
ConnectAction::WakeRetry => {
let err = console::errors::ApiError::Console {
status: http::StatusCode::BAD_REQUEST,
text: "TEST".into(),
};
let err = console::errors::ApiError::Console(ConsoleError {
http_status_code: http::StatusCode::BAD_REQUEST,
error: "TEST".into(),
status: None,
});
assert!(err.could_retry());
Err(console::errors::WakeComputeError::ApiError(err))
}

View File

@@ -1,4 +1,5 @@
use crate::config::RetryConfig;
use crate::console::messages::ConsoleError;
use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo};
use crate::context::RequestMonitoring;
use crate::metrics::{
@@ -88,36 +89,76 @@ fn report_error(e: &WakeComputeError, retry: bool) {
let kind = match e {
WakeComputeError::BadComputeAddress(_) => WakeupFailureKind::BadComputeAddress,
WakeComputeError::ApiError(ApiError::Transport(_)) => WakeupFailureKind::ApiTransportError,
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::LOCKED,
ref text,
}) if text.contains("written data quota exceeded")
|| text.contains("the limit for current plan reached") =>
{
WakeupFailureKind::QuotaExceeded
}
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::UNPROCESSABLE_ENTITY,
ref text,
}) if text.contains("compute time quota of non-primary branches is exceeded") => {
WakeupFailureKind::QuotaExceeded
}
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::LOCKED,
..
}) => WakeupFailureKind::ApiConsoleLocked,
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::BAD_REQUEST,
..
}) => WakeupFailureKind::ApiConsoleBadRequest,
WakeComputeError::ApiError(ApiError::Console { status, .. })
if status.is_server_error() =>
{
WakeupFailureKind::ApiConsoleOtherServerError
}
WakeComputeError::ApiError(ApiError::Console { .. }) => {
WakeupFailureKind::ApiConsoleOtherError
}
WakeComputeError::ApiError(ApiError::Console(e)) => match e.get_reason() {
crate::console::messages::Reason::RoleProtected => {
WakeupFailureKind::ApiConsoleBadRequest
}
crate::console::messages::Reason::ResourceNotFound => {
WakeupFailureKind::ApiConsoleBadRequest
}
crate::console::messages::Reason::ProjectNotFound => {
WakeupFailureKind::ApiConsoleBadRequest
}
crate::console::messages::Reason::EndpointNotFound => {
WakeupFailureKind::ApiConsoleBadRequest
}
crate::console::messages::Reason::BranchNotFound => {
WakeupFailureKind::ApiConsoleBadRequest
}
crate::console::messages::Reason::RateLimitExceeded => {
WakeupFailureKind::ApiConsoleLocked
}
crate::console::messages::Reason::NonPrimaryBranchComputeTimeExceeded => {
WakeupFailureKind::QuotaExceeded
}
crate::console::messages::Reason::ActiveTimeQuotaExceeded => {
WakeupFailureKind::QuotaExceeded
}
crate::console::messages::Reason::ComputeTimeQuotaExceeded => {
WakeupFailureKind::QuotaExceeded
}
crate::console::messages::Reason::WrittenDataQuotaExceeded => {
WakeupFailureKind::QuotaExceeded
}
crate::console::messages::Reason::DataTransferQuotaExceeded => {
WakeupFailureKind::QuotaExceeded
}
crate::console::messages::Reason::LogicalSizeQuotaExceeded => {
WakeupFailureKind::QuotaExceeded
}
crate::console::messages::Reason::Unknown => match e {
ConsoleError {
http_status_code: StatusCode::LOCKED,
ref error,
..
} if error.contains("written data quota exceeded")
|| error.contains("the limit for current plan reached") =>
{
WakeupFailureKind::QuotaExceeded
}
ConsoleError {
http_status_code: StatusCode::UNPROCESSABLE_ENTITY,
ref error,
..
} if error.contains("compute time quota of non-primary branches is exceeded") => {
WakeupFailureKind::QuotaExceeded
}
ConsoleError {
http_status_code: StatusCode::LOCKED,
..
} => WakeupFailureKind::ApiConsoleLocked,
ConsoleError {
http_status_code: StatusCode::BAD_REQUEST,
..
} => WakeupFailureKind::ApiConsoleBadRequest,
ConsoleError {
http_status_code, ..
} if http_status_code.is_server_error() => {
WakeupFailureKind::ApiConsoleOtherServerError
}
ConsoleError { .. } => WakeupFailureKind::ApiConsoleOtherError,
},
},
WakeComputeError::TooManyConnections => WakeupFailureKind::ApiConsoleLocked,
WakeComputeError::TooManyConnectionAttempts(_) => WakeupFailureKind::TimeoutError,
};

View File

@@ -1,5 +1,3 @@
use std::usize;
use super::{LimitAlgorithm, Outcome, Sample};
/// Loss-based congestion avoidance.

View File

@@ -32,8 +32,6 @@ pub struct ClientFirstMessage<'a> {
pub bare: &'a str,
/// Channel binding mode.
pub cbind_flag: ChannelBinding<&'a str>,
/// (Client username)[<https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf/src/backend/libpq/auth-scram.c#L13>].
pub username: &'a str,
/// Client nonce.
pub nonce: &'a str,
}
@@ -58,6 +56,14 @@ impl<'a> ClientFirstMessage<'a> {
// In theory, these might be preceded by "reserved-mext" (i.e. "m=")
let username = parts.next()?.strip_prefix("n=")?;
// https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
if !username.is_empty() {
tracing::warn!(username, "scram username provided, but is not expected")
// TODO(conrad):
// return None;
}
let nonce = parts.next()?.strip_prefix("r=")?;
// Validate but ignore auth extensions
@@ -66,7 +72,6 @@ impl<'a> ClientFirstMessage<'a> {
Some(Self {
bare,
cbind_flag,
username,
nonce,
})
}
@@ -188,19 +193,18 @@ mod tests {
// (Almost) real strings captured during debug sessions
let cases = [
(NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
(NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
(NotSupportedClient, "n,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
(NotSupportedServer, "y,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
(
Required("tls-server-end-point"),
"p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju",
"p=tls-server-end-point,,n=,r=t8JwklwKecDLwSsA72rHmVju",
),
];
for (cb, input) in cases {
let msg = ClientFirstMessage::parse(input).unwrap();
assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju");
assert_eq!(msg.username, "pepe");
assert_eq!(msg.bare, "n=,r=t8JwklwKecDLwSsA72rHmVju");
assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
assert_eq!(msg.cbind_flag, cb);
}
@@ -208,14 +212,13 @@ mod tests {
#[test]
fn parse_client_first_message_with_invalid_gs2_authz() {
assert!(ClientFirstMessage::parse("n,authzid,n=user,r=nonce").is_none())
assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none())
}
#[test]
fn parse_client_first_message_with_extra_params() {
let msg = ClientFirstMessage::parse("n,,n=user,r=nonce,a=foo,b=bar,c=baz").unwrap();
assert_eq!(msg.bare, "n=user,r=nonce,a=foo,b=bar,c=baz");
assert_eq!(msg.username, "user");
let msg = ClientFirstMessage::parse("n,,n=,r=nonce,a=foo,b=bar,c=baz").unwrap();
assert_eq!(msg.bare, "n=,r=nonce,a=foo,b=bar,c=baz");
assert_eq!(msg.nonce, "nonce");
assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
}
@@ -223,9 +226,9 @@ mod tests {
#[test]
fn parse_client_first_message_with_extra_params_invalid() {
// must be of the form `<ascii letter>=<...>`
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,abc=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,1=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,a").is_none());
assert!(ClientFirstMessage::parse("n,,n=,r=nonce,abc=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=,r=nonce,1=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=,r=nonce,a").is_none());
}
#[test]

View File

@@ -27,14 +27,14 @@ use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use tokio_util::task::TaskTracker;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::metrics::Metrics;
use crate::protocol2::read_proxy_protocol;
use crate::protocol2::{read_proxy_protocol, ChainRW};
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
@@ -102,8 +102,6 @@ pub async fn task_main(
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
let server = Builder::new(TokioExecutor::new());
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
@@ -127,24 +125,50 @@ pub async fn task_main(
}
let conn_token = cancellation_token.child_token();
let conn = connection_handler(
config,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
conn_token.clone(),
server.clone(),
tls_acceptor.clone(),
conn,
peer_addr,
)
.instrument(http_conn_span);
let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
connections.spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
connections.spawn(async move {
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token);
conn.await
});
let session_id = uuid::Uuid::new_v4();
let _gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Http);
let startup_result = Box::pin(connection_startup(
config,
tls_acceptor,
session_id,
conn,
peer_addr,
))
.await;
let Some((conn, peer_addr)) = startup_result else {
return;
};
Box::pin(connection_handler(
config,
backend,
connections2,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
conn,
peer_addr,
session_id,
))
.await;
}
.instrument(http_conn_span),
);
}
connections.wait().await;
@@ -152,40 +176,22 @@ pub async fn task_main(
Ok(())
}
/// Handles the TCP lifecycle.
///
/// Handles the TCP startup lifecycle.
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
/// 3. Handles HTTP connection
/// 1. With graceful shutdowns
/// 2. With graceful request cancellation with connection failure
/// 3. With websocket upgrade support.
#[allow(clippy::too_many_arguments)]
async fn connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
server: Builder<TokioExecutor>,
async fn connection_startup(
config: &ProxyConfig,
tls_acceptor: TlsAcceptor,
session_id: uuid::Uuid,
conn: TcpStream,
peer_addr: SocketAddr,
) {
let session_id = uuid::Uuid::new_v4();
let _gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Http);
) -> Option<(TlsStream<ChainRW<TcpStream>>, IpAddr)> {
// handle PROXY protocol
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
Err(e) => {
tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
return;
return None;
}
};
@@ -208,7 +214,7 @@ async fn connection_handler(
Metrics::get().proxy.tls_handshake_failures.inc();
}
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
return;
return None;
}
// The handshake timed out
Err(e) => {
@@ -216,16 +222,36 @@ async fn connection_handler(
Metrics::get().proxy.tls_handshake_failures.inc();
}
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
return;
return None;
}
};
Some((conn, peer_addr))
}
/// Handles HTTP connection
/// 1. With graceful shutdowns
/// 2. With graceful request cancellation with connection failure
/// 3. With websocket upgrade support.
#[allow(clippy::too_many_arguments)]
async fn connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
conn: TlsStream<ChainRW<TcpStream>>,
peer_addr: IpAddr,
session_id: uuid::Uuid,
) {
let session_id = AtomicTake::new(session_id);
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard();
let server = Builder::new(TokioExecutor::new());
let conn = server.serve_connection_with_upgrades(
hyper_util::rt::TokioIo::new(conn),
hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {

View File

@@ -104,7 +104,7 @@ impl PoolingBackend {
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
let maybe_client = if !force_new {
info!("pool: looking for an existing connection");
self.pool.get(ctx, &conn_info).await?
self.pool.get(ctx, &conn_info)?
} else {
info!("pool: pool is disabled");
None

View File

@@ -375,7 +375,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
}
}
pub async fn get(
pub fn get(
self: &Arc<Self>,
ctx: &mut RequestMonitoring,
conn_info: &ConnInfo,

View File

@@ -45,6 +45,10 @@ pub fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
err.to_string(),
StatusCode::REQUEST_TIMEOUT,
),
ApiError::Cancelled => HttpErrorBody::response_from_msg_and_status(
this.to_string(),
StatusCode::INTERNAL_SERVER_ERROR,
),
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::INTERNAL_SERVER_ERROR,

View File

@@ -533,27 +533,31 @@ async fn handle_inner(
return Err(SqlOverHttpError::RequestTooLarge);
}
let fetch_and_process_request = async {
let body = request.into_body().collect().await?.to_bytes();
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
}
.map_err(SqlOverHttpError::from);
let fetch_and_process_request = Box::pin(
async {
let body = request.into_body().collect().await?.to_bytes();
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
}
.map_err(SqlOverHttpError::from),
);
let authenticate_and_connect = async {
let keys = backend
.authenticate(ctx, &config.authentication_config, &conn_info)
.await?;
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
ctx.latency_timer.success();
Ok::<_, HttpConnError>(client)
}
.map_err(SqlOverHttpError::from);
let authenticate_and_connect = Box::pin(
async {
let keys = backend
.authenticate(ctx, &config.authentication_config, &conn_info)
.await?;
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
ctx.latency_timer.success();
Ok::<_, HttpConnError>(client)
}
.map_err(SqlOverHttpError::from),
);
let (payload, mut client) = match run_until_cancelled(
// Run both operations in parallel

View File

@@ -141,7 +141,7 @@ pub async fn serve_websocket(
.client_connections
.guard(crate::metrics::Protocol::Ws);
let res = handle_client(
let res = Box::pin(handle_client(
config,
&mut ctx,
cancellation_handler,
@@ -149,7 +149,7 @@ pub async fn serve_websocket(
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
)
))
.await;
match res {