#[cfg(test)]
mod tests;

use crate::{
    auth,
    cancellation::{self, CancelMap},
    compute::{self, PostgresConnection},
    config::{AuthenticationConfig, ProxyConfig, TlsConfig},
    console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api},
    http::StatusCode,
    protocol2::WithClientIp,
    stream::{PqStream, Stream},
    usage_metrics::{Ids, USAGE_METRICS},
};
use anyhow::{bail, Context};
use async_trait::async_trait;
use futures::TryFutureExt;
use itertools::Itertools;
use metrics::{exponential_buckets, register_int_counter_vec, IntCounterVec};
use once_cell::sync::{Lazy, OnceCell};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use prometheus::{
    register_histogram, register_histogram_vec, register_int_gauge_vec, Histogram, HistogramVec,
    IntGaugeVec,
};
use regex::Regex;
use std::{error::Error, io, net::IpAddr, ops::ControlFlow, sync::Arc, time::Instant};
use tokio::{
    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
    time,
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use utils::measured_stream::MeasuredStream;

/// Number of times we should retry the `/proxy_wake_compute` http request.
/// Retry duration is BASE_RETRY_WAIT_DURATION * RETRY_WAIT_EXPONENT_BASE ^ n, where n starts at 0
pub const NUM_RETRIES_CONNECT: u32 = 16;
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
const BASE_RETRY_WAIT_DURATION: time::Duration = time::Duration::from_millis(25);
const RETRY_WAIT_EXPONENT_BASE: f64 = std::f64::consts::SQRT_2;
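// Worked example (illustrative, not part of the original code): with the constants
// above, the n-th retry waits roughly 25ms * sqrt(2)^(n-1), i.e. 25ms, 35ms, 50ms,
// 71ms, ... growing up to a few seconds, for a total back-off budget of about 15s
// if all NUM_RETRIES_CONNECT attempts are exhausted (see `retry_after` below).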
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";

pub static NUM_DB_CONNECTIONS_OPENED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_opened_db_connections_total",
        "Number of opened connections to a database.",
        &["protocol"],
    )
    .unwrap()
});

pub static NUM_DB_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_closed_db_connections_total",
        "Number of closed connections to a database.",
        &["protocol"],
    )
    .unwrap()
});

pub static NUM_CLIENT_CONNECTION_OPENED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_opened_client_connections_total",
        "Number of opened connections from a client.",
        &["protocol"],
    )
    .unwrap()
});

pub static NUM_CLIENT_CONNECTION_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_closed_client_connections_total",
        "Number of closed connections from a client.",
        &["protocol"],
    )
    .unwrap()
});

pub static NUM_CONNECTIONS_ACCEPTED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_accepted_connections_total",
        "Number of client connections accepted.",
        &["protocol"],
    )
    .unwrap()
});

pub static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_closed_connections_total",
        "Number of client connections closed.",
        &["protocol"],
    )
    .unwrap()
});
static COMPUTE_CONNECTION_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
    register_histogram_vec!(
        "proxy_compute_connection_latency_seconds",
        "Time it took for proxy to establish a connection to the compute endpoint",
        // http/ws/tcp, true/false, true/false, success/failure
        // 3 * 2 * 2 * 2 = 24 counters
        &["protocol", "cache_miss", "pool_miss", "outcome"],
        // largest bucket = 0.5ms * 2^15 = ~16s
        exponential_buckets(0.0005, 2.0, 16).unwrap(),
    )
    .unwrap()
});

static CONTROL_PLANE_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
    register_histogram_vec!(
        "proxy_compute_connection_control_plane_latency_seconds",
        "Time proxy spent talking to control-plane/console while trying to establish a connection to the compute endpoint",
        // http/ws/tcp, true/false, true/false, success/failure
        // 3 * 2 * 2 * 2 = 24 counters
        &["protocol", "cache_miss", "pool_miss", "outcome"],
        // largest bucket = 0.5ms * 2^15 = ~16s
        exponential_buckets(0.0005, 2.0, 16).unwrap(),
    )
    .unwrap()
});
pub static CONSOLE_REQUEST_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
    register_histogram_vec!(
        "proxy_console_request_latency",
        "Time it took for proxy to make requests to the console/control plane",
        // proxy_wake_compute/proxy_get_role_info
        &["request"],
        // largest bucket = 0.2ms * 2^15 = ~6.6s
        exponential_buckets(0.0002, 2.0, 16).unwrap(),
    )
    .unwrap()
});

pub static ALLOWED_IPS_BY_CACHE_OUTCOME: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_allowed_ips_cache_misses",
        "Number of cache hits/misses for allowed ips",
        // hit/miss
        &["outcome"],
    )
    .unwrap()
});

pub static RATE_LIMITER_ACQUIRE_LATENCY: Lazy<Histogram> = Lazy::new(|| {
    register_histogram!(
        "proxy_control_plane_token_acquire_seconds",
        "Time it took for proxy to acquire a token from the control plane rate limiter",
        // largest bucket = 0.05ms * 3^15 = ~717s
        exponential_buckets(0.00005, 3.0, 16).unwrap(),
    )
    .unwrap()
});
pub static RATE_LIMITER_LIMIT: Lazy<IntGaugeVec> = Lazy::new(|| {
    register_int_gauge_vec!(
        "semaphore_control_plane_limit",
        "Current limit of the semaphore control plane",
        &["limit"], // 2 counters
    )
    .unwrap()
});

pub static NUM_CONNECTION_ACCEPTED_BY_SNI: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_accepted_connections_by_sni",
        "Number of connections (per sni).",
        &["kind"],
    )
    .unwrap()
});

pub static ALLOWED_IPS_NUMBER: Lazy<Histogram> = Lazy::new(|| {
    register_histogram!(
        "proxy_allowed_ips_number",
        "Number of allowed ips",
        vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0],
    )
    .unwrap()
});
pub struct LatencyTimer {
    // time since the stopwatch was started
    start: Option<Instant>,
    // accumulated time on the stopwatch
    accumulated: std::time::Duration,
    // time since the stopwatch was started while talking to control-plane
    start_cp: Option<Instant>,
    // accumulated time on the stopwatch while talking to control-plane
    accumulated_cp: std::time::Duration,
    // label data
    protocol: &'static str,
    cache_miss: bool,
    pool_miss: bool,
    outcome: &'static str,
}

pub struct LatencyTimerUserIO<'a> {
    timer: &'a mut LatencyTimer,
}

pub struct LatencyTimerControlPlane<'a> {
    timer: &'a mut LatencyTimer,
}

impl LatencyTimer {
    pub fn new(protocol: &'static str) -> Self {
        Self {
            start: Some(Instant::now()),
            accumulated: std::time::Duration::ZERO,
            start_cp: None,
            accumulated_cp: std::time::Duration::ZERO,
            protocol,
            cache_miss: false,
            // by default we don't do pooling
            pool_miss: true,
            // assume failed unless otherwise specified
            outcome: "failed",
        }
    }
    pub fn control_plane(&mut self) -> LatencyTimerControlPlane<'_> {
        // start the control-plane stopwatch
        self.start_cp = Some(Instant::now());
        LatencyTimerControlPlane { timer: self }
    }
    pub fn wait_for_user(&mut self) -> LatencyTimerUserIO<'_> {
        // stop the stopwatch and record the time that we have accumulated
        let start = self.start.take().expect("latency timer should be started");
        self.accumulated += start.elapsed();
        LatencyTimerUserIO { timer: self }
    }

    pub fn cache_miss(&mut self) {
        self.cache_miss = true;
    }

    pub fn pool_hit(&mut self) {
        self.pool_miss = false;
    }

    pub fn success(mut self) {
        self.outcome = "success";
    }
}

impl Drop for LatencyTimerUserIO<'_> {
    fn drop(&mut self) {
        // start the stopwatch again
        self.timer.start = Some(Instant::now());
    }
}

impl Drop for LatencyTimerControlPlane<'_> {
    fn drop(&mut self) {
        // stop the control-plane stopwatch and record the time that we have accumulated
        let start = self
            .timer
            .start_cp
            .take()
            .expect("latency timer should be started");
        self.timer.accumulated_cp += start.elapsed();
    }
}

impl Drop for LatencyTimer {
    fn drop(&mut self) {
        let duration =
            self.start.map(|start| start.elapsed()).unwrap_or_default() + self.accumulated;
        COMPUTE_CONNECTION_LATENCY
            .with_label_values(&[
                self.protocol,
                bool_to_str(self.cache_miss),
                bool_to_str(self.pool_miss),
                self.outcome,
            ])
            .observe(duration.as_secs_f64());

        let duration_cp = self
            .start_cp
            .map(|start| start.elapsed())
            .unwrap_or_default()
            + self.accumulated_cp;
        CONTROL_PLANE_LATENCY
            .with_label_values(&[
                self.protocol,
                bool_to_str(self.cache_miss),
                bool_to_str(self.pool_miss),
                self.outcome,
            ])
            .observe(duration_cp.as_secs_f64());
    }
}
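// Illustrative usage sketch (not part of the original code): the timer measures
// proxy-side connection latency, excluding time spent waiting on the user and
// separately accounting for time spent talking to the control plane.
//
//     let mut timer = LatencyTimer::new("tcp");
//     {
//         let _paused = timer.wait_for_user(); // main stopwatch paused while the client answers an auth prompt
//     } // resumed when the guard drops
//     {
//         let _cp = timer.control_plane(); // control-plane stopwatch runs until the guard drops
//     }
//     timer.success(); // mark the outcome; metrics are observed when the timer drops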
static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_connection_failures_total",
        "Number of connection failures (per kind).",
        &["kind"],
    )
    .unwrap()
});

static NUM_WAKEUP_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_connection_failures_breakdown",
        "Number of wake-up failures (per kind).",
        &["retry", "kind"],
    )
    .unwrap()
});

static NUM_BYTES_PROXIED_PER_CLIENT_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_io_bytes_per_client",
        "Number of bytes sent/received between client and backend.",
        crate::console::messages::MetricsAuxInfo::TRAFFIC_LABELS,
    )
    .unwrap()
});

static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_io_bytes",
        "Number of bytes sent/received between all clients and backends.",
        &["direction"],
    )
    .unwrap()
});
pub async fn task_main(
    config: &'static ProxyConfig,
    listener: tokio::net::TcpListener,
    cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("proxy has shut down");
    }

    // When set for the server socket, the keepalive setting
    // will be inherited by all accepted client sockets.
    socket2::SockRef::from(&listener).set_keepalive(true)?;

    let mut connections = tokio::task::JoinSet::new();
    let cancel_map = Arc::new(CancelMap::default());

    loop {
        tokio::select! {
            accept_result = listener.accept() => {
                let (socket, peer_addr) = accept_result?;

                let session_id = uuid::Uuid::new_v4();
                let cancel_map = Arc::clone(&cancel_map);
                connections.spawn(
                    async move {
                        info!("accepted postgres client connection");

                        let mut socket = WithClientIp::new(socket);
                        let mut peer_addr = peer_addr;
                        if let Some(ip) = socket.wait_for_addr().await? {
                            peer_addr = ip;
                            tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
                        } else if config.require_client_ip {
                            bail!("missing required client IP");
                        }

                        socket
                            .inner
                            .set_nodelay(true)
                            .context("failed to set socket option")?;

                        handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp, peer_addr.ip()).await
                    }
                    .instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty))
                    .unwrap_or_else(move |e| {
                        // Acknowledge that the task has finished with an error.
                        error!(?session_id, "per-client task finished with an error: {e:#}");
                    }),
                );
            }
            // Don't modify this unless you read https://docs.rs/tokio/latest/tokio/macro.select.html carefully.
            // If this future completes and the pattern doesn't match, this branch is disabled for this call to `select!`.
            // This only counts for this loop iteration and it will be enabled again on the next `select!`.
            //
            // Prior code had this as `Some(Err(e))` which _looks_ equivalent to the current setup, but it's not.
            // When `connections.join_next()` returned `Some(Ok(()))` (which we expect), it would disable the join_next and it would
            // not get called again, even if there are more connections to remove.
            Some(res) = connections.join_next() => {
                if let Err(e) = res {
                    if !e.is_panic() && !e.is_cancelled() {
                        warn!("unexpected error from joined connection task: {e:?}");
                    }
                }
            }
            _ = cancellation_token.cancelled() => {
                drop(listener);
                break;
            }
        }
    }
    // Drain connections
    while let Some(res) = connections.join_next().await {
        if let Err(e) = res {
            if !e.is_panic() && !e.is_cancelled() {
                warn!("unexpected error from joined connection task: {e:?}");
            }
        }
    }
    Ok(())
}
pub enum ClientMode {
    Tcp,
    Websockets { hostname: Option<String> },
}

/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
    fn protocol_label(&self) -> &'static str {
        match self {
            ClientMode::Tcp => "tcp",
            ClientMode::Websockets { .. } => "ws",
        }
    }

    fn allow_cleartext(&self) -> bool {
        match self {
            ClientMode::Tcp => false,
            ClientMode::Websockets { .. } => true,
        }
    }

    fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
        match self {
            ClientMode::Tcp => config.allow_self_signed_compute,
            ClientMode::Websockets { .. } => false,
        }
    }

    fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
        match self {
            ClientMode::Tcp => s.sni_hostname(),
            ClientMode::Websockets { hostname } => hostname.as_deref(),
        }
    }

    fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
        match self {
            ClientMode::Tcp => tls,
            // TLS is None here if using websockets, because the connection is already encrypted.
            ClientMode::Websockets { .. } => None,
        }
    }
}
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
    config: &'static ProxyConfig,
    cancel_map: &CancelMap,
    session_id: uuid::Uuid,
    stream: S,
    mode: ClientMode,
    peer_addr: IpAddr,
) -> anyhow::Result<()> {
    info!(
        protocol = mode.protocol_label(),
        "handling interactive connection from client"
    );

    let proto = mode.protocol_label();
    NUM_CLIENT_CONNECTION_OPENED_COUNTER
        .with_label_values(&[proto])
        .inc();
    NUM_CONNECTIONS_ACCEPTED_COUNTER
        .with_label_values(&[proto])
        .inc();
    scopeguard::defer! {
        NUM_CLIENT_CONNECTION_CLOSED_COUNTER.with_label_values(&[proto]).inc();
        NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[proto]).inc();
    }

    let tls = config.tls_config.as_ref();

    let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
    let (mut stream, params) = match do_handshake.await? {
        Some(x) => x,
        None => return Ok(()), // it's a cancellation request
    };

    // Extract credentials which we're going to use for auth.
    let creds = {
        let hostname = mode.hostname(stream.get_ref());
        let common_names = tls.and_then(|tls| tls.common_names.clone());
        let result = config
            .auth_backend
            .as_ref()
            .map(|_| auth::ClientCredentials::parse(&params, hostname, common_names, peer_addr))
            .transpose();

        match result {
            Ok(creds) => creds,
            Err(e) => stream.throw_error(e).await?,
        }
    };

    let client = Client::new(
        stream,
        creds,
        &params,
        session_id,
        mode.allow_self_signed_compute(config),
    );
    cancel_map
        .with_session(|session| client.connect_to_db(session, mode, &config.authentication_config))
        .await
}
/// Establish a (most probably, secure) connection with the client.
/// For a better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with an owned `stream` here as we need to upgrade it to TLS;
/// we also take extra care to propagate only select handshake errors to the client.
#[tracing::instrument(skip_all)]
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
    stream: S,
    mut tls: Option<&TlsConfig>,
    cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
    // Client may try upgrading to each protocol only once
    let (mut tried_ssl, mut tried_gss) = (false, false);

    let mut stream = PqStream::new(Stream::from_raw(stream));
    loop {
        let msg = stream.read_startup_packet().await?;
        info!("received {msg:?}");

        use FeStartupPacket::*;
        match msg {
            SslRequest => match stream.get_ref() {
                Stream::Raw { .. } if !tried_ssl => {
                    tried_ssl = true;

                    // We can't perform TLS handshake without a config
                    let enc = tls.is_some();
                    stream.write_message(&Be::EncryptionResponse(enc)).await?;
                    if let Some(tls) = tls.take() {
                        // Upgrade raw stream into a secure TLS-backed stream.
                        // NOTE: We've consumed `tls`; this fact will be used later.

                        let (raw, read_buf) = stream.into_inner();
                        // TODO: Normally, the client doesn't send any data before
                        // the server says the TLS handshake is ok, so read_buf is empty.
                        // However, you could imagine pipelining of postgres
                        // SSLRequest + TLS ClientHello in one chunk, similar to
                        // pipelining in our node.js driver. We should probably
                        // support that by chaining read_buf with the stream.
                        if !read_buf.is_empty() {
                            bail!("data is sent before server replied with EncryptionResponse");
                        }
                        let tls_stream = raw.upgrade(tls.to_server_config()).await?;

                        let (_, tls_server_end_point) = tls
                            .cert_resolver
                            .resolve(tls_stream.get_ref().1.server_name())
                            .context("missing certificate")?;

                        stream = PqStream::new(Stream::Tls {
                            tls: Box::new(tls_stream),
                            tls_server_end_point,
                        });
                    }
                }
                _ => bail!(ERR_PROTO_VIOLATION),
            },
            GssEncRequest => match stream.get_ref() {
                Stream::Raw { .. } if !tried_gss => {
                    tried_gss = true;

                    // Currently, we don't support GSSAPI
                    stream.write_message(&Be::EncryptionResponse(false)).await?;
                }
                _ => bail!(ERR_PROTO_VIOLATION),
            },
            StartupMessage { params, .. } => {
                // Check that the config has been consumed during upgrade
                // OR we didn't provide it at all (for dev purposes).
                if tls.is_some() {
                    stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
                }

                info!(session_type = "normal", "successful handshake");
                break Ok(Some((stream, params)));
            }
            CancelRequest(cancel_key_data) => {
                cancel_map.cancel_session(cancel_key_data).await?;

                info!(session_type = "cancellation", "successful handshake");
                break Ok(None);
            }
        }
    }
}
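// For orientation (a summary of the handshake above, no additional logic):
//   1. SslRequest      -> answer EncryptionResponse and, if TLS is configured, upgrade the stream;
//   2. GssEncRequest   -> answer EncryptionResponse(false), since GSSAPI is not supported;
//   3. StartupMessage  -> handshake is complete; the startup params are returned to the caller;
//   4. CancelRequest   -> the target session is cancelled and no stream is returned.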
/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
    let is_cached = node_info.cached();
    if is_cached {
        warn!("invalidating stale compute node info cache entry");
    }
    let label = match is_cached {
        true => "compute_cached",
        false => "compute_uncached",
    };
    NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();

    node_info.invalidate().config
}

/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", fields(pid = tracing::field::Empty), skip_all)]
async fn connect_to_compute_once(
    node_info: &console::CachedNodeInfo,
    timeout: time::Duration,
) -> Result<PostgresConnection, compute::ConnectionError> {
    let allow_self_signed_compute = node_info.allow_self_signed_compute;

    node_info
        .config
        .connect(allow_self_signed_compute, timeout)
        .await
}
#[async_trait]
pub trait ConnectMechanism {
    type Connection;
    type ConnectError;
    type Error: From<Self::ConnectError>;
    async fn connect_once(
        &self,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError>;

    fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}

pub struct TcpMechanism<'a> {
    /// KV-dictionary with PostgreSQL connection params.
    pub params: &'a StartupMessageParams,
}

#[async_trait]
impl ConnectMechanism for TcpMechanism<'_> {
    type Connection = PostgresConnection;
    type ConnectError = compute::ConnectionError;
    type Error = compute::ConnectionError;

    async fn connect_once(
        &self,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<PostgresConnection, Self::Error> {
        connect_to_compute_once(node_info, timeout).await
    }

    fn update_connect_config(&self, config: &mut compute::ConnCfg) {
        config.set_startup_params(self.params);
    }
}
const fn bool_to_str(x: bool) -> &'static str {
    if x {
        "true"
    } else {
        "false"
    }
}

fn report_error(e: &WakeComputeError, retry: bool) {
    use crate::console::errors::ApiError;
    let retry = bool_to_str(retry);
    let kind = match e {
        WakeComputeError::BadComputeAddress(_) => "bad_compute_address",
        WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error",
        WakeComputeError::ApiError(ApiError::Console {
            status: StatusCode::LOCKED,
            ref text,
        }) if text.contains("written data quota exceeded")
            || text.contains("the limit for current plan reached") =>
        {
            "quota_exceeded"
        }
        WakeComputeError::ApiError(ApiError::Console {
            status: StatusCode::LOCKED,
            ..
        }) => "api_console_locked",
        WakeComputeError::ApiError(ApiError::Console {
            status: StatusCode::BAD_REQUEST,
            ..
        }) => "api_console_bad_request",
        WakeComputeError::ApiError(ApiError::Console { status, .. })
            if status.is_server_error() =>
        {
            "api_console_other_server_error"
        }
        WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error",
        WakeComputeError::TimeoutError => "timeout_error",
    };
    NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc();
}
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
pub async fn connect_to_compute<M: ConnectMechanism>(
    mechanism: &M,
    mut node_info: console::CachedNodeInfo,
    extra: &console::ConsoleReqExtra<'_>,
    creds: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
    mut latency_timer: LatencyTimer,
) -> Result<M::Connection, M::Error>
where
    M::ConnectError: ShouldRetry + std::fmt::Debug,
    M::Error: From<WakeComputeError>,
{
    mechanism.update_connect_config(&mut node_info.config);

    // try once
    let (config, err) = match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
        Ok(res) => {
            latency_timer.success();
            return Ok(res);
        }
        Err(e) => {
            error!(error = ?e, "could not connect to compute node");
            (invalidate_cache(node_info), e)
        }
    };

    latency_timer.cache_miss();

    let mut num_retries = 1;

    // if we failed to connect, it's likely that the compute node was suspended; wake up a new compute node
    info!("compute node's state has likely changed; requesting a wake-up");
    let node_info = loop {
        let wake_res = match creds {
            auth::BackendType::Console(api, creds) => {
                api.wake_compute(extra, creds, &mut latency_timer).await
            }
            #[cfg(feature = "testing")]
            auth::BackendType::Postgres(api, creds) => {
                api.wake_compute(extra, creds, &mut latency_timer).await
            }
            // nothing to do?
            auth::BackendType::Link(_) => return Err(err.into()),
            // test backend
            #[cfg(test)]
            auth::BackendType::Test(x) => x.wake_compute(),
        };

        match handle_try_wake(wake_res, num_retries) {
            Err(e) => {
                error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
                report_error(&e, false);
                return Err(e.into());
            }
            // failed to wake up but we can continue to retry
            Ok(ControlFlow::Continue(e)) => {
                report_error(&e, true);
                warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
            }
            // successfully woke up a compute node and can break the wakeup loop
            Ok(ControlFlow::Break(mut node_info)) => {
                node_info.config.reuse_password(&config);
                mechanism.update_connect_config(&mut node_info.config);
                break node_info;
            }
        }

        let wait_duration = retry_after(num_retries);
        num_retries += 1;

        time::sleep(wait_duration).await;
    };

    // now that we have a new node, try to connect to it repeatedly.
    // this can error for a few reasons, for instance:
    // * DNS connection settings haven't quite propagated yet
    info!("wake_compute success. attempting to connect");
    loop {
        match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
            Ok(res) => {
                latency_timer.success();
                return Ok(res);
            }
            Err(e) => {
                let retriable = e.should_retry(num_retries);
                if !retriable {
                    error!(error = ?e, num_retries, retriable, "couldn't connect to compute node");
                    return Err(e.into());
                }
                warn!(error = ?e, num_retries, retriable, "couldn't connect to compute node");
            }
        }

        let wait_duration = retry_after(num_retries);
        num_retries += 1;

        time::sleep(wait_duration).await;
    }
}
/// Attempts to wake up the compute node.
/// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable
/// * Returns Ok(Break(node)) if the wakeup succeeded
/// * Returns Err(e) if there was an error
pub fn handle_try_wake(
    result: Result<console::CachedNodeInfo, WakeComputeError>,
    num_retries: u32,
) -> Result<ControlFlow<console::CachedNodeInfo, WakeComputeError>, WakeComputeError> {
    match result {
        Err(err) => match &err {
            WakeComputeError::ApiError(api) if api.should_retry(num_retries) => {
                Ok(ControlFlow::Continue(err))
            }
            _ => Err(err),
        },
        // Ready to try again.
        Ok(new) => Ok(ControlFlow::Break(new)),
    }
}
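// Hedged illustration of how the three outcomes above are typically consumed
// (this mirrors the wake-up loop in `connect_to_compute`):
//
//     match handle_try_wake(wake_res, num_retries) {
//         Err(e) => return Err(e.into()),                        // non-retriable: give up
//         Ok(ControlFlow::Continue(e)) => { /* log, back off, retry the wake-up */ }
//         Ok(ControlFlow::Break(node_info)) => break node_info,  // woke up: proceed to connect
//     }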
pub trait ShouldRetry {
    fn could_retry(&self) -> bool;
    fn should_retry(&self, num_retries: u32) -> bool {
        match self {
            _ if num_retries >= NUM_RETRIES_CONNECT => false,
            err => err.could_retry(),
        }
    }
}

impl ShouldRetry for io::Error {
    fn could_retry(&self) -> bool {
        use std::io::ErrorKind;
        matches!(
            self.kind(),
            ErrorKind::ConnectionRefused | ErrorKind::AddrNotAvailable | ErrorKind::TimedOut
        )
    }
}

impl ShouldRetry for tokio_postgres::error::DbError {
    fn could_retry(&self) -> bool {
        use tokio_postgres::error::SqlState;
        matches!(
            self.code(),
            &SqlState::CONNECTION_FAILURE
                | &SqlState::CONNECTION_EXCEPTION
                | &SqlState::CONNECTION_DOES_NOT_EXIST
                | &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
        )
    }
}

impl ShouldRetry for tokio_postgres::Error {
    fn could_retry(&self) -> bool {
        if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
            io::Error::could_retry(io_err)
        } else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
            tokio_postgres::error::DbError::could_retry(db_err)
        } else {
            false
        }
    }
}

impl ShouldRetry for compute::ConnectionError {
    fn could_retry(&self) -> bool {
        match self {
            compute::ConnectionError::Postgres(err) => err.could_retry(),
            compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
            _ => false,
        }
    }
}

pub fn retry_after(num_retries: u32) -> time::Duration {
    BASE_RETRY_WAIT_DURATION.mul_f64(RETRY_WAIT_EXPONENT_BASE.powi((num_retries as i32) - 1))
}
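// A minimal, hedged example of the retry policy above (not part of the original
// test suite): transient I/O errors are retried until the NUM_RETRIES_CONNECT
// budget is spent, and the back-off grows geometrically via `retry_after`.
#[cfg(test)]
mod retry_policy_examples {
    use super::*;

    #[test]
    fn transient_io_errors_respect_the_retry_budget() {
        let err = io::Error::new(io::ErrorKind::ConnectionRefused, "connection refused");
        assert!(err.could_retry());
        assert!(err.should_retry(1));
        assert!(!err.should_retry(NUM_RETRIES_CONNECT));
    }

    #[test]
    fn backoff_grows_geometrically() {
        // First wait is the base duration (~25ms); each retry multiplies it by sqrt(2).
        assert!(retry_after(1) >= time::Duration::from_millis(24));
        assert!(retry_after(2) > retry_after(1));
        assert!(retry_after(16) < time::Duration::from_secs(5));
    }
}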
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
    node: &compute::PostgresConnection,
    session: cancellation::Session<'_>,
    stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<()> {
    // Register compute's query cancellation token and produce a new, unique one.
    // The new token (cancel_key_data) will be sent to the client.
    let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());

    // Forward all postgres connection params to the client.
    // Right now the implementation is very hacky and inefficient (ideally,
    // we don't need an intermediate hashmap), but at least it should be correct.
    for (name, value) in &node.params {
        // TODO: Theoretically, this could result in a big pile of params...
        stream.write_message_noflush(&Be::ParameterStatus {
            name: name.as_bytes(),
            value: value.as_bytes(),
        })?;
    }

    stream
        .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
        .write_message(&Be::ReadyForQuery)
        .await?;

    Ok(())
}
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
    client: impl AsyncRead + AsyncWrite + Unpin,
    compute: impl AsyncRead + AsyncWrite + Unpin,
    aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
    let usage = USAGE_METRICS.register(Ids {
        endpoint_id: aux.endpoint_id.clone(),
        branch_id: aux.branch_id.clone(),
    });

    let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
    let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
    let mut client = MeasuredStream::new(
        client,
        |_| {},
        |cnt| {
            // Number of bytes we sent to the client (outbound).
            m_sent.inc_by(cnt as u64);
            m_sent2.inc_by(cnt as u64);
            usage.record_egress(cnt as u64);
        },
    );

    let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
    let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
    let mut compute = MeasuredStream::new(
        compute,
        |_| {},
        |cnt| {
            // Number of bytes the client sent to the compute node (inbound).
            m_recv.inc_by(cnt as u64);
            m_recv2.inc_by(cnt as u64);
        },
    );

    // Starting from here we only proxy the client's traffic.
    info!("performing the proxy pass...");
    let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;

    Ok(())
}
/// Thin connection context.
struct Client<'a, S> {
    /// The underlying libpq protocol stream.
    stream: PqStream<Stream<S>>,
    /// Client credentials that we care about.
    creds: auth::BackendType<'a, auth::ClientCredentials>,
    /// KV-dictionary with PostgreSQL connection params.
    params: &'a StartupMessageParams,
    /// Unique connection ID.
    session_id: uuid::Uuid,
    /// Allow self-signed certificates (for testing).
    allow_self_signed_compute: bool,
}

impl<'a, S> Client<'a, S> {
    /// Construct a new connection context.
    fn new(
        stream: PqStream<Stream<S>>,
        creds: auth::BackendType<'a, auth::ClientCredentials>,
        params: &'a StartupMessageParams,
        session_id: uuid::Uuid,
        allow_self_signed_compute: bool,
    ) -> Self {
        Self {
            stream,
            creds,
            params,
            session_id,
            allow_self_signed_compute,
        }
    }
}
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
    /// Let the client authenticate and connect to the designated compute node.
    // Instrumentation logs the endpoint name everywhere. This doesn't work for link
    // auth, since strictly speaking we don't know the endpoint name in that case.
    #[tracing::instrument(name = "", fields(ep = %self.creds.get_endpoint().unwrap_or_default()), skip_all)]
    async fn connect_to_db(
        self,
        session: cancellation::Session<'_>,
        mode: ClientMode,
        config: &'static AuthenticationConfig,
    ) -> anyhow::Result<()> {
        let Self {
            mut stream,
            creds,
            params,
            session_id,
            allow_self_signed_compute,
        } = self;

        let console_options = neon_options(params);

        let extra = console::ConsoleReqExtra {
            session_id, // aka this connection's id
            application_name: params.get("application_name"),
            options: console_options.as_deref(),
        };

        let mut latency_timer = LatencyTimer::new(mode.protocol_label());

        let user = creds.get_user().to_owned();
        let auth_result = match creds
            .authenticate(
                &extra,
                &mut stream,
                mode.allow_cleartext(),
                config,
                &mut latency_timer,
            )
            .await
        {
            Ok(auth_result) => auth_result,
            Err(e) => {
                let db = params.get("database");
                let app = params.get("application_name");
                let params_span = tracing::info_span!("", ?user, ?db, ?app);

                return stream.throw_error(e).instrument(params_span).await;
            }
        };

        let (mut node_info, creds) = auth_result;

        node_info.allow_self_signed_compute = allow_self_signed_compute;

        let aux = node_info.aux.clone();
        let mut node = connect_to_compute(
            &TcpMechanism { params },
            node_info,
            &extra,
            &creds,
            latency_timer,
        )
        .or_else(|e| stream.throw_error(e))
        .await?;

        let proto = mode.protocol_label();
        NUM_DB_CONNECTIONS_OPENED_COUNTER
            .with_label_values(&[proto])
            .inc();
        scopeguard::defer! {
            NUM_DB_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[proto]).inc();
        }

        prepare_client_connection(&node, session, &mut stream).await?;
        // Before proxy passing, forward to compute whatever data is left in the
        // PqStream input buffer. Normally there is none, but our serverless npm
        // driver in pipeline mode sends startup, password and first query
        // immediately after opening the connection.
        let (stream, read_buf) = stream.into_inner();
        node.stream.write_all(&read_buf).await?;
        proxy_pass(stream, node.stream, aux).await
    }
}
pub fn neon_options(params: &StartupMessageParams) -> Option<String> {
    #[allow(unstable_name_collisions)]
    let options: String = params
        .options_raw()?
        .filter(|opt| is_neon_param(opt))
        .sorted() // we sort it to use as cache key
        .intersperse(" ") // TODO: use impl from std once it's stabilized
        .collect();

    // Don't even bother with empty options.
    if options.is_empty() {
        return None;
    }

    Some(options)
}

pub fn is_neon_param(bytes: &str) -> bool {
    static RE: OnceCell<Regex> = OnceCell::new();
    RE.get_or_init(|| Regex::new(r"^neon_\w+:").unwrap());

    RE.get().unwrap().is_match(bytes)
}
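// A hedged usage example (not part of the original test suite); the option names
// below are made up purely to illustrate the `neon_<key>:<value>` shape that
// `is_neon_param` matches.
#[cfg(test)]
mod neon_param_examples {
    use super::*;

    #[test]
    fn only_neon_prefixed_key_value_options_match() {
        assert!(is_neon_param("neon_endpoint_type:read_write"));
        assert!(!is_neon_param("application_name:myapp"));
        assert!(!is_neon_param("neon_without_a_value"));
    }
}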