mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-23 05:12:56 +00:00
Compare commits
3 Commits
erik/commu
...
cloneable/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c90b082222 | ||
|
|
0957c8ea69 | ||
|
|
6cea6560e9 |
@@ -26,9 +26,10 @@ use utils::sentry_init::init_sentry;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use crate::pglb::TlsRequired;
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
|
||||
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
|
||||
@@ -410,7 +410,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
match auth_backend {
|
||||
Either::Left(auth_backend) => {
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
client_tasks.spawn(crate::proxy::task_main(
|
||||
client_tasks.spawn(crate::pglb::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
proxy_listener,
|
||||
|
||||
@@ -103,6 +103,8 @@ pub enum Auth {
|
||||
}
|
||||
|
||||
/// A config for authenticating to the compute node.
|
||||
// TODO: avoid Clone
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct AuthInfo {
|
||||
/// None for local-proxy, as we use trust-based localhost auth.
|
||||
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
|
||||
|
||||
@@ -13,9 +13,10 @@ use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::pglb::handshake::{HandshakeData, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::pglb::{ClientRequestError, ErrorSource};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection};
|
||||
use crate::proxy::prepare_client_connection;
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
pub async fn task_main(
|
||||
|
||||
@@ -8,10 +8,10 @@ use crate::config::TlsConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::pglb::TlsRequired;
|
||||
use crate::pqproto::{
|
||||
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
|
||||
};
|
||||
use crate::proxy::TlsRequired;
|
||||
use crate::stream::{PqStream, Stream, StreamUpgradeError};
|
||||
use crate::tls::PG_ALPN_PROTOCOL;
|
||||
|
||||
|
||||
@@ -2,3 +2,342 @@ pub mod copy_bidirectional;
|
||||
pub mod handshake;
|
||||
pub mod inprocess;
|
||||
pub mod passthrough;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::FutureExt;
|
||||
use smol_str::ToSmolStr;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
|
||||
use crate::auth;
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
pub use crate::pglb::copy_bidirectional::ErrorSource;
|
||||
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{ConnectMechanism, TcpMechanism};
|
||||
use crate::proxy::handle_connect_request;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::Stream;
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("{ERR_INSECURE_CONNECTION}")]
|
||||
pub struct TlsRequired;
|
||||
|
||||
impl ReportableError for TlsRequired {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for TlsRequired {}
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
}
|
||||
|
||||
// When set for the server socket, the keepalive setting
|
||||
// will be inherited by all accepted client sockets.
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
{
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Tcp);
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(async move {
|
||||
let (socket, conn_info) = match config.proxy_protocol_v2 {
|
||||
ProxyProtocolV2::Required => {
|
||||
match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
}
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
Ok((_socket, ConnectHeader::Local)) => {
|
||||
debug!("healthcheck received");
|
||||
return;
|
||||
}
|
||||
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
||||
}
|
||||
}
|
||||
// ignore the header - it cannot be confused for a postgres or http connection so will
|
||||
// error later.
|
||||
ProxyProtocolV2::Rejected => (
|
||||
socket,
|
||||
ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
extra: None,
|
||||
},
|
||||
),
|
||||
};
|
||||
|
||||
match socket.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"per-client task finished with an error: failed to set socket option: {e:#}"
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let ctx = RequestContext::new(
|
||||
session_id,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
|
||||
let res = handle_client(
|
||||
config,
|
||||
auth_backend,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) enum ClientMode {
|
||||
Tcp,
|
||||
Websockets { hostname: Option<String> },
|
||||
}
|
||||
|
||||
/// Abstracts the logic of handling TCP vs WS clients
|
||||
impl ClientMode {
|
||||
pub(crate) fn allow_cleartext(&self) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => false,
|
||||
ClientMode::Websockets { .. } => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
|
||||
match self {
|
||||
ClientMode::Tcp => s.sni_hostname(),
|
||||
ClientMode::Websockets { hostname } => hostname.as_deref(),
|
||||
}
|
||||
}
|
||||
|
||||
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
|
||||
match self {
|
||||
ClientMode::Tcp => tls,
|
||||
// TLS is None here if using websockets, because the connection is already encrypted.
|
||||
ClientMode::Websockets { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
// almost all errors should be reported to the user, but there's a few cases where we cannot
|
||||
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
|
||||
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
|
||||
// we cannot be sure the client even understands our error message
|
||||
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
|
||||
pub(crate) enum ClientRequestError {
|
||||
#[error("{0}")]
|
||||
Cancellation(#[from] cancellation::CancelError),
|
||||
#[error("{0}")]
|
||||
Handshake(#[from] HandshakeError),
|
||||
#[error("{0}")]
|
||||
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
|
||||
#[error("{0}")]
|
||||
PrepareClient(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
ReportedError(#[from] crate::stream::ReportedError),
|
||||
}
|
||||
|
||||
impl ReportableError for ClientRequestError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ClientRequestError::Cancellation(e) => e.get_error_kind(),
|
||||
ClientRequestError::Handshake(e) => e.get_error_kind(),
|
||||
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
|
||||
ClientRequestError::ReportedError(e) => e.get_error_kind(),
|
||||
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
client: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let proto = ctx.protocol();
|
||||
let request_gauge = metrics.connection_requests.guard(proto);
|
||||
|
||||
let tls = config.tls_config.load();
|
||||
let tls = tls.as_deref();
|
||||
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
|
||||
|
||||
let (client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(client, params) => (client, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
async move {
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
ctx,
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
config.authentication_config.is_vpc_acccess_proxy,
|
||||
auth_backend.get_api(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
}.instrument(cancel_span)
|
||||
});
|
||||
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let (node, client, session) = handle_connect_request(
|
||||
config,
|
||||
auth_backend,
|
||||
ctx,
|
||||
cancellation_handler,
|
||||
client,
|
||||
&mode,
|
||||
endpoint_rate_limiter,
|
||||
¶ms,
|
||||
common_names,
|
||||
async |config, ctx, node_info, auth_info, creds, compute_config| {
|
||||
TcpMechanism {
|
||||
auth: auth_info.clone(),
|
||||
locks: &config.connect_compute_locks,
|
||||
user_info: creds.info.clone(),
|
||||
}
|
||||
.connect_once(ctx, node_info, compute_config)
|
||||
.await
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client,
|
||||
aux: node.aux.clone(),
|
||||
private_link_id,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -2,13 +2,13 @@ use async_trait::async_trait;
|
||||
use tokio::time;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
|
||||
use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::config::{ComputeConfig, ProxyConfig, RetryConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::{self, NodeInfo};
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
|
||||
@@ -183,3 +183,114 @@ where
|
||||
drop(pause);
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn connect_to_compute_pglb<
|
||||
F: AsyncFn(
|
||||
&'static ProxyConfig,
|
||||
&RequestContext,
|
||||
&CachedNodeInfo,
|
||||
&AuthInfo,
|
||||
&ComputeCredentials,
|
||||
&ComputeConfig,
|
||||
) -> Result<PostgresConnection, compute::ConnectionError>,
|
||||
B: WakeComputeBackend,
|
||||
>(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestContext,
|
||||
connect_compute_fn: F,
|
||||
user_info: &B,
|
||||
auth_info: &AuthInfo,
|
||||
creds: &ComputeCredentials,
|
||||
wake_compute_retry_config: RetryConfig,
|
||||
compute: &ComputeConfig,
|
||||
) -> Result<PostgresConnection, compute::ConnectionError> {
|
||||
let mut num_retries = 0;
|
||||
let node_info =
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
|
||||
// try once
|
||||
let err = match connect_compute_fn(config, ctx, &node_info, &auth_info, &creds, compute).await {
|
||||
Ok(res) => {
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Success,
|
||||
retry_type: RetryType::ConnectToCompute,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
return Ok(res);
|
||||
}
|
||||
Err(e) => e,
|
||||
};
|
||||
|
||||
debug!(error = ?err, COULD_NOT_CONNECT);
|
||||
|
||||
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
|
||||
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
|
||||
// Do not need to retrieve a new node_info, just return the old one.
|
||||
if should_retry(&err, num_retries, compute.retry) {
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Failed,
|
||||
retry_type: RetryType::ConnectToCompute,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
return Err(err.into());
|
||||
}
|
||||
node_info
|
||||
} else {
|
||||
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||
debug!("compute node's state has likely changed; requesting a wake-up");
|
||||
invalidate_cache(node_info);
|
||||
// TODO: increment num_retries?
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
|
||||
};
|
||||
|
||||
// now that we have a new node, try connect to it repeatedly.
|
||||
// this can error for a few reasons, for instance:
|
||||
// * DNS connection settings haven't quite propagated yet
|
||||
debug!("wake_compute success. attempting to connect");
|
||||
num_retries = 1;
|
||||
loop {
|
||||
match connect_compute_fn(config, ctx, &node_info, &auth_info, &creds, compute).await {
|
||||
Ok(res) => {
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Success,
|
||||
retry_type: RetryType::ConnectToCompute,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
// TODO: is this necessary? We have a metric.
|
||||
info!(?num_retries, "connected to compute node after");
|
||||
return Ok(res);
|
||||
}
|
||||
Err(e) => {
|
||||
if !should_retry(&e, num_retries, compute.retry) {
|
||||
// Don't log an error here, caller will print the error
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Failed,
|
||||
retry_type: RetryType::ConnectToCompute,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT);
|
||||
}
|
||||
}
|
||||
|
||||
let wait_duration = retry_after(num_retries, compute.retry);
|
||||
num_retries += 1;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
|
||||
time::sleep(wait_duration).await;
|
||||
drop(pause);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,311 +5,57 @@ pub(crate) mod connect_compute;
|
||||
pub(crate) mod retry;
|
||||
pub(crate) mod wake_compute;
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::FutureExt;
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
||||
use thiserror::Error;
|
||||
use smol_str::{SmolStr, format_smolstr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::auth::backend::ComputeCredentials;
|
||||
use crate::cancellation::{CancellationHandler, Session};
|
||||
use crate::compute::{AuthInfo, PostgresConnection};
|
||||
use crate::config::{ComputeConfig, ProxyConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::control_plane::CachedNodeInfo;
|
||||
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::pglb::{ClientMode, ClientRequestError};
|
||||
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::proxy::connect_compute::connect_to_compute_pglb;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::util::run_until_cancelled;
|
||||
use crate::{auth, compute};
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("{ERR_INSECURE_CONNECTION}")]
|
||||
pub struct TlsRequired;
|
||||
|
||||
impl ReportableError for TlsRequired {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for TlsRequired {}
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
}
|
||||
|
||||
// When set for the server socket, the keepalive setting
|
||||
// will be inherited by all accepted client sockets.
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
{
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Tcp);
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(async move {
|
||||
let (socket, conn_info) = match config.proxy_protocol_v2 {
|
||||
ProxyProtocolV2::Required => {
|
||||
match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
}
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
Ok((_socket, ConnectHeader::Local)) => {
|
||||
debug!("healthcheck received");
|
||||
return;
|
||||
}
|
||||
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
||||
}
|
||||
}
|
||||
// ignore the header - it cannot be confused for a postgres or http connection so will
|
||||
// error later.
|
||||
ProxyProtocolV2::Rejected => (
|
||||
socket,
|
||||
ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
extra: None,
|
||||
},
|
||||
),
|
||||
};
|
||||
|
||||
match socket.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"per-client task finished with an error: failed to set socket option: {e:#}"
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let ctx = RequestContext::new(
|
||||
session_id,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
|
||||
let res = handle_client(
|
||||
config,
|
||||
auth_backend,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) enum ClientMode {
|
||||
Tcp,
|
||||
Websockets { hostname: Option<String> },
|
||||
}
|
||||
|
||||
/// Abstracts the logic of handling TCP vs WS clients
|
||||
impl ClientMode {
|
||||
pub(crate) fn allow_cleartext(&self) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => false,
|
||||
ClientMode::Websockets { .. } => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
|
||||
match self {
|
||||
ClientMode::Tcp => s.sni_hostname(),
|
||||
ClientMode::Websockets { hostname } => hostname.as_deref(),
|
||||
}
|
||||
}
|
||||
|
||||
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
|
||||
match self {
|
||||
ClientMode::Tcp => tls,
|
||||
// TLS is None here if using websockets, because the connection is already encrypted.
|
||||
ClientMode::Websockets { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
// almost all errors should be reported to the user, but there's a few cases where we cannot
|
||||
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
|
||||
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
|
||||
// we cannot be sure the client even understands our error message
|
||||
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
|
||||
pub(crate) enum ClientRequestError {
|
||||
#[error("{0}")]
|
||||
Cancellation(#[from] cancellation::CancelError),
|
||||
#[error("{0}")]
|
||||
Handshake(#[from] HandshakeError),
|
||||
#[error("{0}")]
|
||||
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
|
||||
#[error("{0}")]
|
||||
PrepareClient(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
ReportedError(#[from] crate::stream::ReportedError),
|
||||
}
|
||||
|
||||
impl ReportableError for ClientRequestError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ClientRequestError::Cancellation(e) => e.get_error_kind(),
|
||||
ClientRequestError::Handshake(e) => e.get_error_kind(),
|
||||
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
|
||||
ClientRequestError::ReportedError(e) => e.get_error_kind(),
|
||||
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
pub(crate) async fn handle_connect_request<
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
F: AsyncFn(
|
||||
&'static ProxyConfig,
|
||||
&RequestContext,
|
||||
&CachedNodeInfo,
|
||||
&AuthInfo,
|
||||
&ComputeCredentials,
|
||||
&ComputeConfig,
|
||||
) -> Result<PostgresConnection, compute::ConnectionError>,
|
||||
>(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
mut client: PqStream<Stream<S>>,
|
||||
mode: &ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let proto = ctx.protocol();
|
||||
let request_gauge = metrics.connection_requests.guard(proto);
|
||||
|
||||
let tls = config.tls_config.load();
|
||||
let tls = tls.as_deref();
|
||||
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
|
||||
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
async move {
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
ctx,
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
config.authentication_config.is_vpc_acccess_proxy,
|
||||
auth_backend.get_api(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
}.instrument(cancel_span)
|
||||
});
|
||||
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let hostname = mode.hostname(stream.get_ref());
|
||||
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
params: &StartupMessageParams,
|
||||
common_names: Option<&HashSet<String>>,
|
||||
connect_compute_fn: F,
|
||||
) -> Result<(PostgresConnection, Stream<S>, Session), ClientRequestError> {
|
||||
// TODO: to pglb
|
||||
let hostname = mode.hostname(client.get_ref());
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let result = auth_backend
|
||||
@@ -319,14 +65,14 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
let user_info = match user_info
|
||||
.authenticate(
|
||||
ctx,
|
||||
&mut stream,
|
||||
&mut client,
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
endpoint_rate_limiter,
|
||||
@@ -339,7 +85,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
|
||||
return Err(stream
|
||||
return Err(client
|
||||
.throw_error(e, Some(ctx))
|
||||
.instrument(params_span)
|
||||
.await)?;
|
||||
@@ -354,14 +100,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
|
||||
auth_info.set_startup_params(¶ms, params_compat);
|
||||
|
||||
let res = connect_to_compute(
|
||||
let res = connect_to_compute_pglb(
|
||||
config,
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info: creds.info.clone(),
|
||||
auth: auth_info,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
connect_compute_fn,
|
||||
&auth::Backend::ControlPlane(cplane, creds.info),
|
||||
&auth_info,
|
||||
&creds,
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
@@ -369,32 +114,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
let node = match res {
|
||||
Ok(node) => node,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
prepare_client_connection(&node, *session.key(), &mut client);
|
||||
let client = client.flush_and_into_inner().await?;
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
private_link_id,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
}))
|
||||
Ok((node, client, session))
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
@@ -429,7 +159,7 @@ impl NeonOptions {
|
||||
// proxy options:
|
||||
|
||||
/// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
|
||||
const PARAMS_COMPAT: &str = "proxy_params_compat";
|
||||
pub const PARAMS_COMPAT: &str = "proxy_params_compat";
|
||||
|
||||
// cplane options:
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
mod mitm;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, bail};
|
||||
@@ -20,16 +21,20 @@ use tracing_test::traced_test;
|
||||
use super::retry::CouldRetry;
|
||||
use super::*;
|
||||
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
|
||||
use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::proxy::connect_compute::ConnectMechanism;
|
||||
use crate::error::{ErrorKind, ReportableError};
|
||||
use crate::pglb::ERR_INSECURE_CONNECTION;
|
||||
use crate::pglb::handshake::{HandshakeData, handshake};
|
||||
use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute};
|
||||
use crate::stream::Stream;
|
||||
use crate::tls::client_config::compute_client_config_with_certs;
|
||||
use crate::tls::server_config::CertResolver;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId};
|
||||
use crate::{sasl, scram};
|
||||
use crate::{auth, sasl, scram};
|
||||
|
||||
/// Generate a set of TLS certificates: CA + server.
|
||||
fn generate_certs(
|
||||
|
||||
@@ -17,7 +17,8 @@ use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::{ClientMode, ErrorSource, handle_client};
|
||||
use crate::pglb::{ClientMode, handle_client};
|
||||
use crate::proxy::ErrorSource;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
|
||||
pin_project! {
|
||||
|
||||
Reference in New Issue
Block a user