Mirror of https://github.com/neondatabase/neon.git, synced 2026-01-31 17:20:37 +00:00
Compare commits: diko/baseb... → conrad/pro... (16 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 631139ceeb |  |
|  | fd43058bd7 |  |
|  | cf07c5b5f9 |  |
|  | 11bb84c38d |  |
|  | 219c72c24c |  |
|  | 0633cd6385 |  |
|  | 0cdb0c5704 |  |
|  | eefac5d78b |  |
|  | 7d1c908b1b |  |
|  | cfa2813446 |  |
|  | 034bdb1552 |  |
|  | 8b1ffa1718 |  |
|  | 2d3ea77953 |  |
|  | 3124729f53 |  |
|  | 6463eb38be |  |
|  | ae506fd791 |  |
@@ -17,9 +17,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
 use tracing::{debug, info, warn};
 
-use crate::auth::credentials::check_peer_addr_is_in_list;
-use crate::auth::{
-    self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
-};
+use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
 use crate::cache::Cached;
 use crate::config::AuthenticationConfig;
 use crate::context::RequestContext;
@@ -137,16 +135,6 @@ impl<'a, T> Backend<'a, T> {
         }
     }
 }
-impl<'a, T, E> Backend<'a, Result<T, E>> {
-    /// Very similar to [`std::option::Option::transpose`].
-    /// This is most useful for error handling.
-    pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
-        match self {
-            Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
-            Self::Local(l) => Ok(Backend::Local(l)),
-        }
-    }
-}
 
 pub(crate) struct ComputeCredentials {
     pub(crate) info: ComputeUserInfo,
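For readers unfamiliar with the deleted `transpose` helper above: it mirrored `Option::transpose`, turning a container-of-`Result` inside out. A standalone, runnable sketch of the same pattern on a simplified two-variant enum (names here are illustrative, not from the proxy codebase):

```rust
// Hypothetical two-variant container, analogous in shape to the proxy's `Backend`.
enum Wrapper<T> {
    Value(T),
    Empty,
}

impl<T, E> Wrapper<Result<T, E>> {
    // Turn a `Wrapper<Result<T, E>>` inside out into `Result<Wrapper<T>, E>`,
    // mirroring `Option<Result<T, E>>::transpose`.
    fn transpose(self) -> Result<Wrapper<T>, E> {
        match self {
            Wrapper::Value(r) => r.map(Wrapper::Value),
            Wrapper::Empty => Ok(Wrapper::Empty),
        }
    }
}

fn main() {
    let w: Wrapper<Result<i32, String>> = Wrapper::Value(Ok(5));
    assert!(matches!(w.transpose(), Ok(Wrapper::Value(5))));
}
```

The helper becomes unnecessary once `authenticate` stops returning a `Backend`-wrapped result, which is exactly what this diff does.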
@@ -284,7 +272,7 @@ async fn auth_quirks(
     allow_cleartext: bool,
     config: &'static AuthenticationConfig,
     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
+) -> auth::Result<ComputeCredentials> {
     // If there's no project so far, that entails that client doesn't
     // support SNI or other means of passing the endpoint (project) name.
     // We now expect to see a very specific payload in the place of password.
@@ -301,15 +289,12 @@ async fn auth_quirks(
     debug!("fetching authentication info and allowlists");
 
     // check allowed list
-    let allowed_ips = if config.ip_allowlist_check_enabled {
+    if config.ip_allowlist_check_enabled {
         let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
         if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
            return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
         }
-        allowed_ips
-    } else {
-        Cached::new_uncached(Arc::new(vec![]))
-    };
+    }
 
     // check if a VPC endpoint ID is coming in and if yes, if it's allowed
     let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
@@ -368,7 +353,7 @@ async fn auth_quirks(
         )
         .await
         {
-            Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
+            Ok(keys) => Ok(keys),
             Err(e) => {
                 if e.is_password_failed() {
                     // The password could have been changed, so we invalidate the cache.
@@ -420,53 +405,39 @@ async fn authenticate_with_secret(
     classic::authenticate(ctx, info, client, config, secret).await
 }
 
-impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
-    /// Get username from the credentials.
-    pub(crate) fn get_user(&self) -> &str {
-        match self {
-            Self::ControlPlane(_, user_info) => &user_info.user,
-            Self::Local(_) => "local",
-        }
-    }
-
+impl ControlPlaneClient {
     /// Authenticate the client via the requested backend, possibly using credentials.
     #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
     pub(crate) async fn authenticate(
-        self,
+        &self,
         ctx: &RequestContext,
         client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
+        user_info: ComputeUserInfoMaybeEndpoint,
         allow_cleartext: bool,
         config: &'static AuthenticationConfig,
         endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-    ) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
-        let res = match self {
-            Self::ControlPlane(api, user_info) => {
-                debug!(
-                    user = &*user_info.user,
-                    project = user_info.endpoint(),
-                    "performing authentication using the console"
-                );
+    ) -> auth::Result<ComputeCredentials> {
+        debug!(
+            user = &*user_info.user,
+            project = user_info.endpoint(),
+            "performing authentication using the console"
+        );
 
-                let (credentials, ip_allowlist) = auth_quirks(
-                    ctx,
-                    &*api,
-                    user_info,
-                    client,
-                    allow_cleartext,
-                    config,
-                    endpoint_rate_limiter,
-                )
-                .await?;
-                Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
-            }
-            Self::Local(_) => {
-                return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
-            }
-        };
+        let credentials = auth_quirks(
+            ctx,
+            self,
+            user_info,
+            client,
+            allow_cleartext,
+            config,
+            endpoint_rate_limiter,
+        )
+        .await?;
 
         // TODO: replace with some metric
         info!("user successfully authenticated");
-        res
+
+        Ok(credentials)
     }
 }
 
@@ -536,6 +507,25 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
     }
 }
 
+pub struct ControlPlaneWakeCompute<'a> {
+    pub cplane: &'a ControlPlaneClient,
+    pub creds: ComputeCredentials,
+}
+
+#[async_trait::async_trait]
+impl ComputeConnectBackend for ControlPlaneWakeCompute<'_> {
+    async fn wake_compute(
+        &self,
+        ctx: &RequestContext,
+    ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
+        self.cplane.wake_compute(ctx, &self.creds.info).await
+    }
+
+    fn get_keys(&self) -> &ComputeCredentialKeys {
+        &self.creds.keys
+    }
+}
+
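The new `ControlPlaneWakeCompute` adapter pairs a borrowed control-plane client with owned credentials so `connect_to_compute` can wake computes without threading a whole `Backend` value through. A self-contained sketch of this delegation pattern, with simplified stand-in types rather than the proxy's real signatures:

```rust
// Invented stand-ins for ControlPlaneClient / ComputeCredentials.
struct Client;
struct Creds {
    keys: u32,
}

impl Client {
    fn wake(&self, endpoint: &str) -> String {
        format!("woke compute for {endpoint}")
    }
}

// Stand-in for the `ComputeConnectBackend` trait (sync here for brevity).
trait ConnectBackend {
    fn wake_compute(&self) -> String;
    fn get_keys(&self) -> &u32;
}

// The adapter: borrows the client, owns the credentials, delegates both calls.
struct WakeCompute<'a> {
    cplane: &'a Client,
    creds: Creds,
}

impl ConnectBackend for WakeCompute<'_> {
    fn wake_compute(&self) -> String {
        self.cplane.wake("endpoint")
    }
    fn get_keys(&self) -> &u32 {
        &self.creds.keys
    }
}

fn main() {
    let client = Client;
    let backend = WakeCompute { cplane: &client, creds: Creds { keys: 7 } };
    println!("{}", backend.wake_compute());
}
```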
 #[cfg(test)]
 mod tests {
     #![allow(clippy::unimplemented, clippy::unwrap_used)]
@@ -552,6 +542,7 @@ mod tests {
     use postgres_protocol::message::backend::Message as PgMessage;
     use postgres_protocol::message::frontend;
     use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
+    use tokio_util::task::TaskTracker;
 
     use super::jwt::JwkCache;
     use super::{AuthRateLimiter, auth_quirks};
@@ -702,7 +693,7 @@ mod tests {
     #[tokio::test]
     async fn auth_quirks_scram() {
         let (mut client, server) = tokio::io::duplex(1024);
-        let mut stream = PqStream::new(Stream::from_raw(server));
+        let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
 
         let ctx = RequestContext::test();
         let api = Auth {
@@ -784,7 +775,7 @@ mod tests {
     #[tokio::test]
     async fn auth_quirks_cleartext() {
         let (mut client, server) = tokio::io::duplex(1024);
-        let mut stream = PqStream::new(Stream::from_raw(server));
+        let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
 
         let ctx = RequestContext::test();
         let api = Auth {
@@ -838,7 +829,7 @@ mod tests {
     #[tokio::test]
     async fn auth_quirks_password_hack() {
         let (mut client, server) = tokio::io::duplex(1024);
-        let mut stream = PqStream::new(Stream::from_raw(server));
+        let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
 
         let ctx = RequestContext::test();
         let api = Auth {
@@ -887,7 +878,7 @@ mod tests {
         .await
         .unwrap();
 
-        assert_eq!(creds.0.info.endpoint, "my-endpoint");
+        assert_eq!(creds.info.endpoint, "my-endpoint");
 
         handle.await.unwrap();
     }

@@ -1,7 +1,7 @@
 //! Client authentication mechanisms.
 
 pub mod backend;
-pub use backend::Backend;
+pub use backend::{Backend, ControlPlaneWakeCompute};
 
 mod credentials;
 pub(crate) use credentials::{

@@ -18,6 +18,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 use tokio::net::TcpListener;
 use tokio_rustls::TlsConnector;
 use tokio_util::sync::CancellationToken;
+use tokio_util::task::task_tracker::TaskTrackerToken;
 use tracing::{Instrument, error, info};
 use utils::project_git_version;
 use utils::sentry_init::init_sentry;
@@ -226,7 +227,8 @@ pub(super) async fn task_main(
         let dest_suffix = Arc::clone(&dest_suffix);
         let compute_tls_config = compute_tls_config.clone();
 
-        connections.spawn(
+        let tracker = connections.token();
+        tokio::spawn(
             async move {
                 socket
                     .set_nodelay(true)
@@ -249,6 +251,7 @@ pub(super) async fn task_main(
                     compute_tls_config,
                     tls_server_end_point,
                     socket,
+                    tracker,
                 )
                 .await
             }
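The recurring `connections.spawn(...)` to `connections.token()` + `tokio::spawn(...)` change switches from tracking the task itself to tracking a token that the task owns: `TaskTracker::wait()` stays pending until every outstanding token is dropped, wherever the token ends up. A minimal runnable sketch of that semantics, using the real `tokio_util::task::TaskTracker` API (the sleep is just a stand-in for connection work):

```rust
use std::time::Duration;
use tokio_util::task::TaskTracker;

#[tokio::main]
async fn main() {
    let tracker = TaskTracker::new();

    // Holding a token keeps `tracker.wait()` pending even though the task is
    // spawned with plain `tokio::spawn` rather than `tracker.spawn`.
    let token = tracker.token();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(100)).await;
        drop(token); // released when the connection finishes
    });

    tracker.close();
    tracker.wait().await; // completes only after the token is dropped
    println!("all connections drained");
}
```

The advantage over tracking the task is that the token can be handed to whichever piece of state must outlive the accept loop (here, the `PqStream`), so graceful shutdown follows the connection rather than the original task.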
@@ -274,10 +277,11 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
 async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
     ctx: &RequestContext,
     raw_stream: S,
+    tracker: TaskTrackerToken,
     tls_config: Arc<rustls::ServerConfig>,
     tls_server_end_point: TlsServerEndPoint,
-) -> anyhow::Result<Stream<S>> {
-    let mut stream = PqStream::new(Stream::from_raw(raw_stream));
+) -> anyhow::Result<(Stream<S>, TaskTrackerToken)> {
+    let mut stream = PqStream::new(Stream::from_raw(raw_stream), tracker);
 
     let msg = stream.read_startup_packet().await?;
     use pq_proto::FeStartupPacket::SslRequest;
@@ -291,7 +295,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
             // Upgrade raw stream into a secure TLS-backed stream.
             // NOTE: We've consumed `tls`; this fact will be used later.
 
-            let (raw, read_buf) = stream.into_inner();
+            let (raw, read_buf, tracker) = stream.into_inner();
             // TODO: Normally, client doesn't send any data before
             // server says TLS handshake is ok and read_buf is empty.
             // However, you could imagine pipelining of postgres
@@ -302,13 +306,16 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
                 bail!("data is sent before server replied with EncryptionResponse");
             }
 
-            Ok(Stream::Tls {
-                tls: Box::new(
-                    raw.upgrade(tls_config, !ctx.has_private_peer_addr())
-                        .await?,
-                ),
-                tls_server_end_point,
-            })
+            Ok((
+                Stream::Tls {
+                    tls: Box::new(
+                        raw.upgrade(tls_config, !ctx.has_private_peer_addr())
+                            .await?,
+                    ),
+                    tls_server_end_point,
+                },
+                tracker,
+            ))
         }
         unexpected => {
             info!(
@@ -329,8 +336,10 @@ async fn handle_client(
     compute_tls_config: Option<Arc<rustls::ClientConfig>>,
     tls_server_end_point: TlsServerEndPoint,
     stream: impl AsyncRead + AsyncWrite + Unpin,
+    tracker: TaskTrackerToken,
 ) -> anyhow::Result<()> {
-    let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
+    let (mut tls_stream, _tracker) =
+        ssl_handshake(&ctx, stream, tracker, tls_config, tls_server_end_point).await?;
 
     // Cut off first part of the SNI domain
     // We receive required destination details in the format of
@@ -323,7 +323,7 @@ impl CancellationHandler {
         }
     }
 
-    pub(crate) fn get_key(self: &Arc<Self>) -> Session {
+    pub(crate) fn get_key(self: Arc<Self>) -> Session {
         // we intentionally generate a random "backend pid" and "secret key" here.
         // we use the corresponding u64 as an identifier for the
         // actual endpoint+pid+secret for postgres/pgbouncer.
@@ -340,7 +340,7 @@ impl CancellationHandler {
         Session {
             key,
             redis_key,
-            cancellation_handler: Arc::clone(self),
+            cancellation_handler: self,
        }
    }
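Changing the receiver from `self: &Arc<Self>` to `self: Arc<Self>` lets `get_key` consume the caller's reference count instead of cloning it. A small runnable sketch of the idea with placeholder types:

```rust
use std::sync::Arc;

struct Handler;

struct Session {
    _handler: Arc<Handler>,
}

impl Handler {
    // Taking `self: Arc<Self>` moves the caller's refcount into the session
    // directly, avoiding the extra increment that `Arc::clone(self)` performed
    // when the receiver was `&Arc<Self>`.
    fn get_key(self: Arc<Self>) -> Session {
        Session { _handler: self }
    }
}

fn main() {
    let handler = Arc::new(Handler);
    let _session = handler.get_key(); // the Arc moves into the session
}
```

Callers that still need their own handle simply clone before calling, which makes the refcount traffic explicit at the call site.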

@@ -1,8 +1,9 @@
 use std::sync::Arc;
 
-use futures::{FutureExt, TryFutureExt};
+use futures::TryFutureExt;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
 use tokio_util::sync::CancellationToken;
+use tokio_util::task::task_tracker::TaskTrackerToken;
 use tracing::{Instrument, debug, error, info};
 
 use crate::auth::backend::ConsoleRedirectBackend;
@@ -14,10 +15,8 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
 use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
 use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
 use crate::proxy::handshake::{HandshakeData, handshake};
-use crate::proxy::passthrough::ProxyPassthrough;
-use crate::proxy::{
-    ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
-};
+use crate::proxy::passthrough::passthrough;
+use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled};
 
 pub async fn task_main(
     config: &'static ProxyConfig,
@@ -35,7 +34,6 @@ pub async fn task_main(
     socket2::SockRef::from(&listener).set_keepalive(true)?;
 
     let connections = tokio_util::task::task_tracker::TaskTracker::new();
-    let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
 
     while let Some(accept_result) =
         run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -49,11 +47,11 @@ pub async fn task_main(
 
         let session_id = uuid::Uuid::new_v4();
         let cancellation_handler = Arc::clone(&cancellation_handler);
-        let cancellations = cancellations.clone();
 
         debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
 
-        connections.spawn(async move {
+        let tracker = connections.token();
+        tokio::spawn(async move {
            let (socket, peer_addr) = match read_proxy_protocol(socket).await {
                Err(e) => {
                    error!("per-client task finished with an error: {e:#}");
@@ -103,99 +101,80 @@ pub async fn task_main(
                 &config.region,
             );
 
+            let span = ctx.span();
+            let mut slot = Some(ctx);
             let res = handle_client(
                 config,
                 backend,
-                &ctx,
+                &mut slot,
                 cancellation_handler,
                 socket,
                 conn_gauge,
-                cancellations,
+                tracker,
             )
-            .instrument(ctx.span())
-            .boxed()
+            .instrument(span)
             .await;
 
-            match res {
-                Err(e) => {
-                    ctx.set_error_kind(e.get_error_kind());
-                    tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
-                }
-                Ok(None) => {
-                    ctx.set_success();
-                }
-                Ok(Some(p)) => {
-                    ctx.set_success();
-                    let _disconnect = ctx.log_connect();
-                    match p.proxy_pass(&config.connect_to_compute).await {
-                        Ok(()) => {}
-                        Err(ErrorSource::Client(e)) => {
-                            error!(
-                                ?session_id,
-                                "per-client task finished with an IO error from the client: {e:#}"
-                            );
-                        }
-                        Err(ErrorSource::Compute(e)) => {
-                            error!(
-                                ?session_id,
-                                "per-client task finished with an IO error from the compute: {e:#}"
-                            );
-                        }
-                    }
-                }
-            }
+            match (slot, res) {
+                (None, _) => {}
+                (Some(ctx), Ok(())) => {
+                    ctx.success();
+                }
+                (Some(ctx), Err(e)) => {
+                    ctx.set_error_kind(e.get_error_kind());
+                    error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
+                }
+            }
         });
     }
 
     connections.close();
-    cancellations.close();
     drop(listener);
 
     // Drain connections
     connections.wait().await;
-    cancellations.wait().await;
 
     Ok(())
 }
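The accept loop now owns the `RequestContext` through an `Option` slot: the handler may `take()` it, for example to move it into the spawned passthrough or cancellation task, and the caller only records success or error when the context is still in the slot. A compact runnable sketch of this hand-off pattern (all names invented):

```rust
// Stand-in for RequestContext.
struct Ctx(&'static str);

// The callee may take ownership out of the caller's Option.
fn handle(slot: &mut Option<Ctx>, hand_off: bool) -> Result<(), String> {
    if hand_off {
        let ctx = slot.take().expect("context must be set");
        // ...ownership would move into a background task here...
        drop(ctx);
    }
    Ok(())
}

fn main() {
    let mut slot = Some(Ctx("session"));
    let res = handle(&mut slot, false);
    match (slot, res) {
        (None, _) => {}            // callee finished the context itself
        (Some(_ctx), Ok(())) => {} // caller records success
        (Some(_ctx), Err(_e)) => {} // caller records the error kind
    }
}
```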
 
 #[allow(clippy::too_many_arguments)]
-pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
+pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
     config: &'static ProxyConfig,
     backend: &'static ConsoleRedirectBackend,
-    ctx: &RequestContext,
+    ctx_slot: &mut Option<RequestContext>,
     cancellation_handler: Arc<CancellationHandler>,
     stream: S,
     conn_gauge: NumClientConnectionsGuard<'static>,
-    cancellations: tokio_util::task::task_tracker::TaskTracker,
-) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
-    debug!(
-        protocol = %ctx.protocol(),
-        "handling interactive connection from client"
-    );
+    tracker: TaskTrackerToken,
+) -> Result<(), ClientRequestError> {
+    let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
+    debug!(%protocol, "handling interactive connection from client");
 
     let metrics = &Metrics::get().proxy;
-    let proto = ctx.protocol();
-    let request_gauge = metrics.connection_requests.guard(proto);
+    let request_gauge = metrics.connection_requests.guard(protocol);
 
     let tls = config.tls_config.load();
     let tls = tls.as_deref();
 
-    let record_handshake_error = !ctx.has_private_peer_addr();
-    let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
-    let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
+    let data = {
+        let ctx = ctx_slot.as_ref().expect("context must be set");
+        let record_handshake_error = !ctx.has_private_peer_addr();
+        let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
+        let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
+        tokio::time::timeout(config.handshake_timeout, do_handshake).await??
+    };
 
-    let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
-        .await??
-    {
+    let (mut stream, params) = match data {
         HandshakeData::Startup(stream, params) => (stream, params),
-        HandshakeData::Cancel(cancel_key_data) => {
+        HandshakeData::Cancel(cancel_key_data, tracker) => {
             // spawn a task to cancel the session, but don't wait for it
-            cancellations.spawn({
+            tokio::spawn({
                 let cancellation_handler_clone = Arc::clone(&cancellation_handler);
-                let ctx = ctx.clone();
+                let ctx = ctx_slot.take().expect("context must be set");
                 let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
                 cancel_span.follows_from(tracing::Span::current());
                 async move {
+                    let _tracker = tracker;
                     cancellation_handler_clone
                         .cancel_session(
                             cancel_key_data,
@@ -205,15 +184,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
                             backend.get_api(),
                         )
                         .await
-                        .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
-                }.instrument(cancel_span)
+                        .inspect_err(|e| debug!(error = ?e, "cancel_session failed"))
+                        .ok();
+                }
+                .instrument(cancel_span)
             });
 
-            return Ok(None);
+            return Ok(());
         }
     };
-    drop(pause);
 
+    let ctx = ctx_slot.as_ref().expect("context must be set");
     ctx.set_db_options(params.clone());
 
     let (node_info, user_info, _ip_allowlist) = match backend
@@ -228,13 +209,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
 
     let mut node = connect_to_compute(
         ctx,
-        &TcpMechanism {
+        TcpMechanism {
             user_info,
             params_compat: true,
             params: &params,
             locks: &config.connect_compute_locks,
         },
-        &node_info,
+        node_info,
         config.wake_compute_retry_config,
         &config.connect_to_compute,
     )
@@ -252,17 +233,22 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
     // PqStream input buffer. Normally there is none, but our serverless npm
     // driver in pipeline mode sends startup, password and first query
     // immediately after opening the connection.
-    let (stream, read_buf) = stream.into_inner();
+    let (stream, read_buf, tracker) = stream.into_inner();
     node.stream.write_all(&read_buf).await?;
 
-    Ok(Some(ProxyPassthrough {
-        client: stream,
-        aux: node.aux.clone(),
-        private_link_id: None,
-        compute: node,
-        session_id: ctx.session_id(),
-        cancel: session,
-        _req: request_gauge,
-        _conn: conn_gauge,
-    }))
+    let ctx = ctx_slot.take().expect("context must be set");
+    ctx.set_success();
+
+    tokio::spawn(passthrough(
+        ctx,
+        &config.connect_to_compute,
+        stream,
+        node,
+        session,
+        request_gauge,
+        conn_gauge,
+        tracker,
+    ));
+
+    Ok(())
 }

@@ -38,7 +38,7 @@ pub struct RequestContext(
     /// I would typically use a RefCell but that would break the `Send` requirements
     /// so we need something with thread-safety. `TryLock` is a cheap alternative
     /// that offers similar semantics to a `RefCell` but with synchronisation.
-    TryLock<RequestContextInner>,
+    TryLock<Box<RequestContextInner>>,
 );
 
 struct RequestContextInner {
@@ -89,7 +89,7 @@ pub(crate) enum AuthMethod {
 impl Clone for RequestContext {
     fn clone(&self) -> Self {
         let inner = self.0.try_lock().expect("should not deadlock");
-        let new = RequestContextInner {
+        let new = Box::new(RequestContextInner {
             conn_info: inner.conn_info.clone(),
             session_id: inner.session_id,
             protocol: inner.protocol,
@@ -117,7 +117,7 @@ impl Clone for RequestContext {
             disconnect_sender: None,
             latency_timer: LatencyTimer::noop(inner.protocol),
             disconnect_timestamp: inner.disconnect_timestamp,
-        };
+        });
 
         Self(TryLock::new(new))
     }
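Boxing `RequestContextInner` keeps the frequently moved `RequestContext` handle pointer-sized while the large inner state lives on the heap, which matters now that the context is passed by value into spawned tasks. A tiny runnable illustration of the size effect (sizes shown assume a 64-bit target):

```rust
// A deliberately large inner struct, standing in for RequestContextInner.
struct BigInner {
    _payload: [u64; 64],
}

// Boxing shrinks the outer handle to a single pointer.
struct Handle(Box<BigInner>);

fn main() {
    println!("inner:  {} bytes", std::mem::size_of::<BigInner>()); // 512
    println!("handle: {} bytes", std::mem::size_of::<Handle>()); // 8 (one pointer)
}
```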
@@ -140,7 +140,7 @@ impl RequestContext {
             role = tracing::field::Empty,
         );
 
-        let inner = RequestContextInner {
+        let inner = Box::new(RequestContextInner {
             conn_info,
             session_id,
             protocol,
@@ -168,7 +168,7 @@ impl RequestContext {
             disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
             latency_timer: LatencyTimer::new(protocol),
             disconnect_timestamp: None,
-        };
+        });
 
         Self(TryLock::new(inner))
     }
@@ -522,7 +522,7 @@ impl Drop for RequestContextInner {
     }
 }
 
-pub struct DisconnectLogger(RequestContextInner);
+pub struct DisconnectLogger(Box<RequestContextInner>);
 
 impl Drop for DisconnectLogger {
     fn drop(&mut self) {

@@ -53,6 +53,25 @@ pub(crate) trait ConnectMechanism {
     fn update_connect_config(&self, conf: &mut compute::ConnCfg);
 }
 
+#[async_trait]
+impl<T: ConnectMechanism + Sync> ConnectMechanism for &T {
+    type Connection = T::Connection;
+    type ConnectError = T::ConnectError;
+    type Error = T::Error;
+    async fn connect_once(
+        &self,
+        ctx: &RequestContext,
+        node_info: &control_plane::CachedNodeInfo,
+        config: &ComputeConfig,
+    ) -> Result<Self::Connection, Self::ConnectError> {
+        T::connect_once(self, ctx, node_info, config).await
+    }
+
+    fn update_connect_config(&self, conf: &mut compute::ConnCfg) {
+        T::update_connect_config(self, conf);
+    }
+}
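This blanket impl makes `&T` a `ConnectMechanism` whenever `T` is one, which is what lets the refactored `connect_to_compute` take the mechanism by value while existing tests keep passing `&mechanism`. A simplified, runnable sketch of the same trick (synchronous and with invented names, to stay self-contained):

```rust
trait Mechanism {
    fn connect(&self) -> String;
}

// Blanket impl: a reference to any Mechanism is itself a Mechanism.
impl<T: Mechanism> Mechanism for &T {
    fn connect(&self) -> String {
        // `self: &&T` deref-coerces to `&T`, delegating to the underlying impl.
        T::connect(self)
    }
}

struct Tcp;
impl Mechanism for Tcp {
    fn connect(&self) -> String {
        "tcp".into()
    }
}

// Generic call site accepting the mechanism by value.
fn run(m: impl Mechanism) -> String {
    m.connect()
}

fn main() {
    let tcp = Tcp;
    assert_eq!(run(&tcp), "tcp"); // by reference, via the blanket impl
    assert_eq!(run(tcp), "tcp"); // by value
}
```

Delegating through `T::connect(...)` explicitly (rather than a method call on `&T`) is the usual way to avoid accidentally recursing into the blanket impl itself.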

 #[async_trait]
 pub(crate) trait ComputeConnectBackend {
     async fn wake_compute(
@@ -105,8 +124,8 @@ impl ConnectMechanism for TcpMechanism<'_> {
 #[tracing::instrument(skip_all)]
 pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
     ctx: &RequestContext,
-    mechanism: &M,
-    user_info: &B,
+    mechanism: M,
+    backend: B,
     wake_compute_retry_config: RetryConfig,
     compute: &ComputeConfig,
 ) -> Result<M::Connection, M::Error>
@@ -116,9 +135,9 @@ where
 {
     let mut num_retries = 0;
     let mut node_info =
-        wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
+        wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
 
-    node_info.set_keys(user_info.get_keys());
+    node_info.set_keys(backend.get_keys());
     mechanism.update_connect_config(&mut node_info.config);
 
     // try once
@@ -159,7 +178,7 @@ where
     let old_node_info = invalidate_cache(node_info);
     // TODO: increment num_retries?
     let mut node_info =
-        wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
+        wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
     node_info.reuse_settings(old_node_info);
 
     mechanism.update_connect_config(&mut node_info.config);

@@ -67,7 +67,6 @@ where
     }
 }
 
-#[tracing::instrument(skip_all)]
 pub async fn copy_bidirectional_client_compute<Client, Compute>(
     client: &mut Client,
     compute: &mut Compute,

@@ -5,6 +5,7 @@ use pq_proto::{
 };
 use thiserror::Error;
 use tokio::io::{AsyncRead, AsyncWrite};
+use tokio_util::task::task_tracker::TaskTrackerToken;
 use tracing::{debug, info, warn};
 
 use crate::auth::endpoint_sni;
@@ -51,7 +52,7 @@ impl ReportableError for HandshakeError {
 
 pub(crate) enum HandshakeData<S> {
     Startup(PqStream<Stream<S>>, StartupMessageParams),
-    Cancel(CancelKeyData),
+    Cancel(CancelKeyData, TaskTrackerToken),
 }
 
 /// Establish a (most probably, secure) connection with the client.
@@ -62,6 +63,7 @@ pub(crate) enum HandshakeData<S> {
 pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
     ctx: &RequestContext,
     stream: S,
+    tracker: TaskTrackerToken,
     mut tls: Option<&TlsConfig>,
     record_handshake_error: bool,
 ) -> Result<HandshakeData<S>, HandshakeError> {
@@ -71,7 +73,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
     const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
     const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
 
-    let mut stream = PqStream::new(Stream::from_raw(stream));
+    let mut stream = PqStream::new(Stream::from_raw(stream), tracker);
     loop {
         let msg = stream.read_startup_packet().await?;
         match msg {
@@ -157,15 +159,13 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
                     let (_, tls_server_end_point) =
                         tls.cert_resolver.resolve(conn_info.server_name());
 
-                    stream = PqStream {
-                        framed: Framed {
-                            stream: Stream::Tls {
-                                tls: Box::new(tls_stream),
-                                tls_server_end_point,
-                            },
-                            read_buf,
-                            write_buf,
+                    stream.framed = Framed {
+                        stream: Stream::Tls {
+                            tls: Box::new(tls_stream),
+                            tls_server_end_point,
                         },
+                        read_buf,
+                        write_buf,
                     };
                 }
             }
@@ -248,7 +248,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
             }
             FeStartupPacket::CancelRequest(cancel_key_data) => {
                 info!(session_type = "cancellation", "successful handshake");
-                break Ok(HandshakeData::Cancel(cancel_key_data));
+                break Ok(HandshakeData::Cancel(cancel_key_data, stream.tracker));
             }
         }
     }

@@ -10,26 +10,27 @@ pub(crate) mod wake_compute;
 use std::sync::Arc;
 
 pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
-use futures::{FutureExt, TryFutureExt};
+use futures::TryFutureExt;
 use itertools::Itertools;
 use once_cell::sync::OnceCell;
+use passthrough::passthrough;
 use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
 use regex::Regex;
 use serde::{Deserialize, Serialize};
-use smol_str::{SmolStr, ToSmolStr, format_smolstr};
+use smol_str::{SmolStr, format_smolstr};
 use thiserror::Error;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
 use tokio_util::sync::CancellationToken;
+use tokio_util::task::task_tracker::TaskTrackerToken;
 use tracing::{Instrument, debug, error, info, warn};
 
 use self::connect_compute::{TcpMechanism, connect_to_compute};
-use self::passthrough::ProxyPassthrough;
 use crate::cancellation::{self, CancellationHandler};
 use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
 use crate::context::RequestContext;
 use crate::error::ReportableError;
 use crate::metrics::{Metrics, NumClientConnectionsGuard};
-use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
+use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
 use crate::proxy::handshake::{HandshakeData, handshake};
 use crate::rate_limiter::EndpointRateLimiter;
 use crate::stream::{PqStream, Stream};
@@ -70,7 +71,6 @@ pub async fn task_main(
     socket2::SockRef::from(&listener).set_keepalive(true)?;
 
     let connections = tokio_util::task::task_tracker::TaskTracker::new();
-    let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
 
     while let Some(accept_result) =
         run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -84,12 +84,12 @@ pub async fn task_main(
 
         let session_id = uuid::Uuid::new_v4();
         let cancellation_handler = Arc::clone(&cancellation_handler);
-        let cancellations = cancellations.clone();
 
         debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
         let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
 
-        connections.spawn(async move {
+        let tracker = connections.token();
+        tokio::spawn(async move {
             let (socket, conn_info) = match read_proxy_protocol(socket).await {
                 Err(e) => {
                     warn!("per-client task finished with an error: {e:#}");
@@ -138,60 +138,41 @@
                 crate::metrics::Protocol::Tcp,
                 &config.region,
             );
+            let span = ctx.span();
+            let mut ctx = Some(ctx);
 
             let res = handle_client(
                 config,
                 auth_backend,
-                &ctx,
+                &mut ctx,
                 cancellation_handler,
                 socket,
                 ClientMode::Tcp,
                 endpoint_rate_limiter2,
                 conn_gauge,
-                cancellations,
+                tracker,
             )
-            .instrument(ctx.span())
-            .boxed()
+            .instrument(span)
             .await;
 
-            match res {
-                Err(e) => {
-                    ctx.set_error_kind(e.get_error_kind());
-                    warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
-                }
-                Ok(None) => {
-                    ctx.set_success();
-                }
-                Ok(Some(p)) => {
-                    ctx.set_success();
-                    let _disconnect = ctx.log_connect();
-                    match p.proxy_pass(&config.connect_to_compute).await {
-                        Ok(()) => {}
-                        Err(ErrorSource::Client(e)) => {
-                            warn!(
-                                ?session_id,
-                                "per-client task finished with an IO error from the client: {e:#}"
-                            );
-                        }
-                        Err(ErrorSource::Compute(e)) => {
-                            error!(
-                                ?session_id,
-                                "per-client task finished with an IO error from the compute: {e:#}"
-                            );
-                        }
-                    }
-                }
-            }
+            match (ctx, res) {
+                (None, _) => {}
+                (Some(ctx), Ok(())) => {
+                    ctx.success();
+                }
+                (Some(ctx), Err(e)) => {
+                    ctx.set_error_kind(e.get_error_kind());
+                    warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
+                }
+            }
         });
     }
 
     connections.close();
-    cancellations.close();
     drop(listener);
 
     // Drain connections
     connections.wait().await;
-    cancellations.wait().await;
 
     Ok(())
 }
@@ -258,46 +239,79 @@ impl ReportableError for ClientRequestError {
 }
 
 #[allow(clippy::too_many_arguments)]
-pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
+pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
     config: &'static ProxyConfig,
     auth_backend: &'static auth::Backend<'static, ()>,
-    ctx: &RequestContext,
+    ctx_slot: &mut Option<RequestContext>,
     cancellation_handler: Arc<CancellationHandler>,
     stream: S,
     mode: ClientMode,
     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     conn_gauge: NumClientConnectionsGuard<'static>,
-    cancellations: tokio_util::task::task_tracker::TaskTracker,
-) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
-    debug!(
-        protocol = %ctx.protocol(),
-        "handling interactive connection from client"
-    );
+    tracker: TaskTrackerToken,
+) -> Result<(), ClientRequestError> {
+    let cplane = match auth_backend {
+        auth::Backend::ControlPlane(cplane, ()) => &**cplane,
+        auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
+    };
+
+    let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
+    debug!(%protocol, "handling interactive connection from client");
 
     let metrics = &Metrics::get().proxy;
-    let proto = ctx.protocol();
-    let request_gauge = metrics.connection_requests.guard(proto);
+    let request_gauge = metrics.connection_requests.guard(protocol);
 
-    let tls = config.tls_config.load();
-    let tls = tls.as_deref();
+    let handshake_result: Result<_, ClientRequestError> = async {
+        let tls = config.tls_config.load();
+        let tls = tls.as_deref();
 
-    let record_handshake_error = !ctx.has_private_peer_addr();
-    let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
-    let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
+        let ctx = ctx_slot.as_ref().expect("context must be set");
+        let record_handshake_error = !ctx.has_private_peer_addr();
+        let data = {
+            let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
+            tokio::time::timeout(
+                config.handshake_timeout,
+                handshake(
+                    ctx,
+                    stream,
+                    tracker,
+                    mode.handshake_tls(tls),
+                    record_handshake_error,
+                ),
+            )
+            .await??
+        };
 
-    let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
-        .await??
-    {
-        HandshakeData::Startup(stream, params) => (stream, params),
-        HandshakeData::Cancel(cancel_key_data) => {
-            // spawn a task to cancel the session, but don't wait for it
-            cancellations.spawn({
-                let cancellation_handler_clone = Arc::clone(&cancellation_handler);
-                let ctx = ctx.clone();
-                let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
+        match data {
+            HandshakeData::Startup(mut stream, params) => {
+                ctx.set_db_options(params.clone());
+
+                let host = mode.hostname(stream.get_ref());
+                let cn = tls.map(|tls| &tls.common_names);
+
+                // Extract credentials which we're going to use for auth.
+                let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, host, cn);
+                let user_info = match result {
+                    Ok(user_info) => user_info,
+                    Err(e) => stream.throw_error(e, Some(ctx)).await?,
+                };
+
+                let session = cancellation_handler.get_key();
+                Ok(Some((stream, params, session, user_info)))
+            }
+            HandshakeData::Cancel(cancel_key_data, tracker) => {
+                let ctx = ctx_slot.take().expect("context must be set");
+                ctx.set_success();
+
+                let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id());
                 cancel_span.follows_from(tracing::Span::current());
-                async move {
-                    cancellation_handler_clone
+
+                // spawn a task to cancel the session, but don't wait for it
+                tokio::spawn(async move {
+                    // ensure the proxy doesn't shutdown until we complete this task.
+                    let _tracker = tracker;
+
+                    cancellation_handler
                         .cancel_session(
                             cancel_key_data,
                             ctx,
@@ -305,111 +319,108 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
                             config.authentication_config.is_vpc_acccess_proxy,
                             auth_backend.get_api(),
                         )
+                        .instrument(cancel_span)
                         .await
-                        .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
-                }.instrument(cancel_span)
-            });
+                        .unwrap_or_else(|e| debug!(error = ?e, "cancel_session failed"));
+                });
 
-            return Ok(None);
+                Ok(None)
+            }
         }
-    };
-    drop(pause);
+    }
+    .await;
 
-    ctx.set_db_options(params.clone());
+    let Some((mut stream, params, session, user_info)) = handshake_result? else {
+        return Ok(());
+    };
+    let ctx = ctx_slot.as_ref().expect("context must be set");
 
-    let hostname = mode.hostname(stream.get_ref());
-    let common_names = tls.map(|tls| &tls.common_names);
+    let auth_result: Result<_, ClientRequestError> = async {
+        let user = user_info.user.clone();
+        match cplane
+            .authenticate(
+                ctx,
+                &mut stream,
+                user_info,
+                mode.allow_cleartext(),
+                &config.authentication_config,
+                endpoint_rate_limiter,
+            )
+            .await
+        {
+            Ok(auth_result) => Ok(auth_result),
+            Err(e) => {
+                let db = params.get("database");
+                let app = params.get("application_name");
+                let params_span = tracing::info_span!("", ?user, ?db, ?app);
+                stream
+                    .throw_error(e, Some(ctx))
+                    .instrument(params_span)
+                    .await?
+            }
+        }
+    }
+    .await;
 
-    // Extract credentials which we're going to use for auth.
-    let result = auth_backend
-        .as_ref()
-        .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
-        .transpose();
-    let user_info = match result {
-        Ok(user_info) => user_info,
-        Err(e) => stream.throw_error(e, Some(ctx)).await?,
-    };
+    let compute_creds = auth_result?;
 
-    let user = user_info.get_user().to_owned();
-    let (user_info, _ip_allowlist) = match user_info
-        .authenticate(
-            ctx,
-            &mut stream,
-            mode.allow_cleartext(),
-            &config.authentication_config,
-            endpoint_rate_limiter,
-        )
-        .await
-    {
-        Ok(auth_result) => auth_result,
-        Err(e) => {
-            let db = params.get("database");
-            let app = params.get("application_name");
-            let params_span = tracing::info_span!("", ?user, ?db, ?app);
-
-            return stream
-                .throw_error(e, Some(ctx))
-                .instrument(params_span)
-                .await?;
-        }
-    };
-
-    let compute_user_info = match &user_info {
-        auth::Backend::ControlPlane(_, info) => &info.info,
-        auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
-    };
-    let params_compat = compute_user_info
-        .options
-        .get(NeonOptions::PARAMS_COMPAT)
-        .is_some();
+    let connect_result: Result<_, ClientRequestError> = async {
+        let compute_user_info = compute_creds.info.clone();
+        let params_compat = compute_user_info
+            .options
+            .get(NeonOptions::PARAMS_COMPAT)
+            .is_some();
 
-    let mut node = connect_to_compute(
-        ctx,
-        &TcpMechanism {
-            user_info: compute_user_info.clone(),
-            params_compat,
-            params: &params,
-            locks: &config.connect_compute_locks,
-        },
-        &user_info,
-        config.wake_compute_retry_config,
-        &config.connect_to_compute,
-    )
-    .or_else(|e| stream.throw_error(e, Some(ctx)))
-    .await?;
+        let mut node = connect_to_compute(
+            ctx,
+            TcpMechanism {
+                user_info: compute_user_info,
+                params_compat,
+                params: &params,
+                locks: &config.connect_compute_locks,
+            },
+            auth::ControlPlaneWakeCompute {
+                cplane,
+                creds: compute_creds,
+            },
+            config.wake_compute_retry_config,
+            &config.connect_to_compute,
+        )
+        .or_else(|e| stream.throw_error(e, Some(ctx)))
+        .await?;
 
-    let cancellation_handler_clone = Arc::clone(&cancellation_handler);
-    let session = cancellation_handler_clone.get_key();
+        session.write_cancel_key(node.cancel_closure.clone())?;
+        prepare_client_connection(&node, *session.key(), &mut stream).await?;
 
-    session.write_cancel_key(node.cancel_closure.clone())?;
+        // Before proxy passing, forward to compute whatever data is left in the
+        // PqStream input buffer. Normally there is none, but our serverless npm
+        // driver in pipeline mode sends startup, password and first query
+        // immediately after opening the connection.
+        let (stream, read_buf, tracker) = stream.into_inner();
+        node.stream.write_all(&read_buf).await?;
 
-    prepare_client_connection(&node, *session.key(), &mut stream).await?;
+        Ok((node, stream, tracker))
+    }
+    .await;
 
-    // Before proxy passing, forward to compute whatever data is left in the
-    // PqStream input buffer. Normally there is none, but our serverless npm
-    // driver in pipeline mode sends startup, password and first query
-    // immediately after opening the connection.
-    let (stream, read_buf) = stream.into_inner();
-    node.stream.write_all(&read_buf).await?;
+    let (node, stream, tracker) = connect_result?;
 
-    let private_link_id = match ctx.extra() {
-        Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
-        Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
-        None => None,
-    };
+    let ctx = ctx_slot.take().expect("context must be set");
+    ctx.set_success();
 
-    Ok(Some(ProxyPassthrough {
-        client: stream,
-        aux: node.aux.clone(),
-        private_link_id,
-        compute: node,
-        session_id: ctx.session_id(),
-        cancel: session,
-        _req: request_gauge,
-        _conn: conn_gauge,
-    }))
+    tokio::spawn(passthrough(
+        ctx,
+        &config.connect_to_compute,
+        stream,
+        node,
+        session,
+        request_gauge,
+        conn_gauge,
+        tracker,
+    ));
+
+    Ok(())
 }
 
 /// Finish client connection initialization: confirm auth success, send params, etc.

@@ -1,5 +1,6 @@
-use smol_str::SmolStr;
+use smol_str::{SmolStr, ToSmolStr};
 use tokio::io::{AsyncRead, AsyncWrite};
+use tokio_util::task::task_tracker::TaskTrackerToken;
 use tracing::debug;
 use utils::measured_stream::MeasuredStream;
 
@@ -7,13 +8,14 @@ use super::copy_bidirectional::ErrorSource;
 use crate::cancellation;
 use crate::compute::PostgresConnection;
 use crate::config::ComputeConfig;
+use crate::context::RequestContext;
 use crate::control_plane::messages::MetricsAuxInfo;
 use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
+use crate::protocol2::ConnectionInfoExtra;
 use crate::stream::Stream;
 use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
 
 /// Forward bytes in both directions (client <-> compute).
 #[tracing::instrument(skip_all)]
 pub(crate) async fn proxy_pass(
     client: impl AsyncRead + AsyncWrite + Unpin,
     compute: impl AsyncRead + AsyncWrite + Unpin,
@@ -61,41 +63,53 @@ pub(crate) async fn proxy_pass(
     Ok(())
 }
 
-pub(crate) struct ProxyPassthrough<S> {
-    pub(crate) client: Stream<S>,
-    pub(crate) compute: PostgresConnection,
-    pub(crate) aux: MetricsAuxInfo,
-    pub(crate) session_id: uuid::Uuid,
-    pub(crate) private_link_id: Option<SmolStr>,
-    pub(crate) cancel: cancellation::Session,
-
-    pub(crate) _req: NumConnectionRequestsGuard<'static>,
-    pub(crate) _conn: NumClientConnectionsGuard<'static>,
-}
+#[allow(clippy::too_many_arguments)]
+pub(crate) async fn passthrough<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
+    ctx: RequestContext,
+    compute_config: &'static ComputeConfig,
+    client: Stream<S>,
+    compute: PostgresConnection,
+    cancel: cancellation::Session,
+    _req: NumConnectionRequestsGuard<'static>,
+    _conn: NumClientConnectionsGuard<'static>,
+    _tracker: TaskTrackerToken,
+) {
+    let session_id = ctx.session_id();
+    let private_link_id = match ctx.extra() {
+        Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
+        Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
+        None => None,
+    };
 
-impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
-    pub(crate) async fn proxy_pass(
-        self,
-        compute_config: &ComputeConfig,
-    ) -> Result<(), ErrorSource> {
-        let res = proxy_pass(
-            self.client,
-            self.compute.stream,
-            self.aux,
-            self.private_link_id,
-        )
-        .await;
-        if let Err(err) = self
-            .compute
-            .cancel_closure
-            .try_cancel_query(compute_config)
-            .await
-        {
-            tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
-        }
+    let _disconnect = ctx.log_connect();
+    let res = proxy_pass(client, compute.stream, compute.aux, private_link_id).await;
 
-        drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
+    match res {
+        Ok(()) => {}
+        Err(ErrorSource::Client(e)) => {
+            tracing::warn!(
+                session_id = ?session_id,
+                "per-client task finished with an IO error from the client: {e:#}"
+            );
+        }
+        Err(ErrorSource::Compute(e)) => {
+            tracing::error!(
+                session_id = ?session_id,
+                "per-client task finished with an IO error from the compute: {e:#}"
+            );
+        }
+    }
 
-        res
-    }
-}
+    if let Err(err) = compute
+        .cancel_closure
+        .try_cancel_query(compute_config)
+        .await
+    {
+        tracing::warn!(session_id = ?session_id, ?err, "could not cancel the query in the database");
+    }
+
+    // we don't need a result. If the queue is full, we just log the error
+    drop(cancel.remove_cancel_key());
+}
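The `ProxyPassthrough::proxy_pass` method has become a free `passthrough` function that owns everything it needs (stream, compute connection, cancel session, metric guards, tracker token), so the caller can `tokio::spawn` it and return immediately. A minimal runnable sketch of that ownership shape (invented names; the sleep stands in for the copy loop):

```rust
use std::time::Duration;
use tokio_util::task::TaskTracker;
use tokio_util::task::task_tracker::TaskTrackerToken;

// The passthrough future owns every guard it needs, so the accept loop can
// return as soon as it spawns it.
struct Guards {
    _tracker: TaskTrackerToken,
}

async fn passthrough(id: u32, guards: Guards) {
    tokio::time::sleep(Duration::from_millis(10)).await; // stand-in for the copy loop
    println!("session {id} drained");
    drop(guards); // token released here, letting shutdown complete
}

#[tokio::main]
async fn main() {
    let tracker = TaskTracker::new();
    tokio::spawn(passthrough(1, Guards { _tracker: tracker.token() }));
    tracker.close();
    tracker.wait().await;
}
```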

@@ -38,6 +38,7 @@ async fn proxy_mitm(
     let (end_client, startup) = match handshake(
         &RequestContext::test(),
         client1,
+        TaskTracker::new().token(),
         Some(&server_config1),
         false,
     )
@@ -45,7 +46,7 @@
     .unwrap()
     {
         HandshakeData::Startup(stream, params) => (stream, params),
-        HandshakeData::Cancel(_) => panic!("cancellation not supported"),
+        HandshakeData::Cancel(_, _) => panic!("cancellation not supported"),
     };
 
     let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);

@@ -15,6 +15,7 @@ use rstest::rstest;
 use rustls::crypto::ring;
 use rustls::pki_types;
 use tokio::io::DuplexStream;
+use tokio_util::task::TaskTracker;
 use tracing_test::traced_test;
 
 use super::connect_compute::ConnectMechanism;
@@ -178,10 +179,12 @@ async fn dummy_proxy(
     auth: impl TestAuth + Send,
 ) -> anyhow::Result<()> {
     let (client, _) = read_proxy_protocol(client).await?;
-    let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
-        HandshakeData::Startup(stream, _) => stream,
-        HandshakeData::Cancel(_) => bail!("cancellation not supported"),
-    };
+    let t = TaskTracker::new().token();
+    let mut stream =
+        match handshake(&RequestContext::test(), client, t, tls.as_ref(), false).await? {
+            HandshakeData::Startup(stream, _) => stream,
+            HandshakeData::Cancel(_, _) => bail!("cancellation not supported"),
+        };
 
     auth.authenticate(&mut stream).await?;
 
@@ -622,7 +625,7 @@ async fn connect_to_compute_success() {
     let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
     let user_info = helper_create_connect_info(&mechanism);
     let config = config();
-    connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
+    connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
         .await
         .unwrap();
     mechanism.verify();
@@ -636,7 +639,7 @@ async fn connect_to_compute_retry() {
     let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
     let user_info = helper_create_connect_info(&mechanism);
     let config = config();
-    connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
+    connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
         .await
         .unwrap();
     mechanism.verify();
@@ -651,7 +654,7 @@ async fn connect_to_compute_non_retry_1() {
     let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
     let user_info = helper_create_connect_info(&mechanism);
     let config = config();
-    connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
+    connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
         .await
         .unwrap_err();
     mechanism.verify();
@@ -666,7 +669,7 @@ async fn connect_to_compute_non_retry_2() {
     let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
     let user_info = helper_create_connect_info(&mechanism);
     let config = config();
-    connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
+    connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
         .await
         .unwrap();
     mechanism.verify();
@@ -691,7 +694,7 @@ async fn connect_to_compute_non_retry_3() {
     connect_to_compute(
         &ctx,
         &mechanism,
-        &user_info,
+        user_info,
         wake_compute_retry_config,
         &config,
     )
@@ -709,7 +712,7 @@ async fn wake_retry() {
     let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
     let user_info = helper_create_connect_info(&mechanism);
     let config = config();
-    connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
+    connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
        .await
        .unwrap();
     mechanism.verify();
@@ -724,7 +727,7 @@ async fn wake_non_retry() {
     let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
     let user_info = helper_create_connect_info(&mechanism);
     let config = config();
-    connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
+    connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
         .await
         .unwrap_err();
     mechanism.verify();
@@ -743,7 +746,7 @@ async fn fail_but_wake_invalidates_cache() {
     let user = helper_create_connect_info(&mech);
     let cfg = config();
 
-    connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
+    connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
         .await
         .unwrap();
 
@@ -764,7 +767,7 @@ async fn fail_no_wake_skips_cache_invalidation() {
     let user = helper_create_connect_info(&mech);
     let cfg = config();
 
-    connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
+    connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
         .await
         .unwrap();
 
@@ -785,7 +788,7 @@ async fn retry_but_wake_invalidates_cache() {
     let user_info = helper_create_connect_info(&mechanism);
     let cfg = config();
 
-    connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
+    connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
         .await
         .unwrap();
     mechanism.verify();
@@ -808,7 +811,7 @@ async fn retry_no_wake_skips_invalidation() {
     let user_info = helper_create_connect_info(&mechanism);
     let cfg = config();
 
-    connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
+    connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
         .await
         .unwrap_err();
     mechanism.verify();

@@ -224,13 +224,13 @@ impl PoolingBackend {
         let backend = self.auth_backend.as_ref().map(|()| keys);
         crate::proxy::connect_compute::connect_to_compute(
             ctx,
-            &TokioMechanism {
+            TokioMechanism {
                 conn_id,
                 conn_info,
                 pool: self.pool.clone(),
                 locks: &self.config.connect_compute_locks,
             },
-            &backend,
+            backend,
             self.config.wake_compute_retry_config,
             &self.config.connect_to_compute,
         )
@@ -268,13 +268,13 @@ impl PoolingBackend {
         });
         crate::proxy::connect_compute::connect_to_compute(
             ctx,
-            &HyperMechanism {
+            HyperMechanism {
                 conn_id,
                 conn_info,
                 pool: self.http_conn_pool.clone(),
                 locks: &self.config.connect_compute_locks,
             },
-            &backend,
+            backend,
             self.config.wake_compute_retry_config,
             &self.config.connect_to_compute,
         )

@@ -41,7 +41,7 @@ use tokio::net::{TcpListener, TcpStream};
 use tokio::time::timeout;
 use tokio_rustls::TlsAcceptor;
 use tokio_util::sync::CancellationToken;
-use tokio_util::task::TaskTracker;
+use tokio_util::task::task_tracker::TaskTrackerToken;
 use tracing::{Instrument, info, warn};
 
 use crate::cancellation::CancellationHandler;
@@ -124,7 +124,6 @@ pub async fn task_main(
     let connections = tokio_util::task::task_tracker::TaskTracker::new();
     connections.close(); // allows `connections.wait to complete`
 
-    let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
     while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
         let (conn, peer_addr) = res.context("could not accept TCP stream")?;
         if let Err(e) = conn.set_nodelay(true) {
@@ -150,11 +149,11 @@ pub async fn task_main(
         let conn_token = cancellation_token.child_token();
         let tls_acceptor = tls_acceptor.clone();
         let backend = backend.clone();
-        let connections2 = connections.clone();
         let cancellation_handler = cancellation_handler.clone();
         let endpoint_rate_limiter = endpoint_rate_limiter.clone();
-        let cancellations = cancellations.clone();
-        connections.spawn(
+
+        let tracker = connections.token();
+        tokio::spawn(
             async move {
                 let conn_token2 = conn_token.clone();
                 let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
@@ -181,8 +180,7 @@ pub async fn task_main(
                 Box::pin(connection_handler(
                     config,
                     backend,
-                    connections2,
-                    cancellations,
+                    tracker,
                     cancellation_handler,
                     endpoint_rate_limiter,
                     conn_token,
@@ -305,8 +303,7 @@ async fn connection_startup(
 async fn connection_handler(
     config: &'static ProxyConfig,
     backend: Arc<PoolingBackend>,
-    connections: TaskTracker,
-    cancellations: TaskTracker,
+    tracker: TaskTrackerToken,
     cancellation_handler: Arc<CancellationHandler>,
     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     cancellation_token: CancellationToken,
@@ -347,19 +344,17 @@ async fn connection_handler(
 
         // `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
         // By spawning the future, we ensure it never gets cancelled until it decides to.
-        let cancellations = cancellations.clone();
-        let handler = connections.spawn(
+        let handler = tokio::spawn(
             request_handler(
                 req,
                 config,
                 backend.clone(),
-                connections.clone(),
+                tracker.clone(),
                 cancellation_handler.clone(),
                 session_id,
                 conn_info2.clone(),
                 http_request_token,
                 endpoint_rate_limiter.clone(),
-                cancellations,
             )
             .in_current_span()
            .map_ok_or_else(api_error_into_response, |r| r),
@@ -400,14 +395,13 @@ async fn request_handler(
     mut request: hyper::Request<Incoming>,
     config: &'static ProxyConfig,
     backend: Arc<PoolingBackend>,
-    ws_connections: TaskTracker,
+    tracker: TaskTrackerToken,
     cancellation_handler: Arc<CancellationHandler>,
     session_id: uuid::Uuid,
     conn_info: ConnectionInfo,
     // used to cancel in-flight HTTP requests. not used to cancel websockets
     http_cancellation_token: CancellationToken,
     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-    cancellations: TaskTracker,
 ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
     let host = request
         .headers()
@@ -441,10 +435,17 @@ async fn request_handler(
         let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
             .map_err(|e| ApiError::BadRequest(e.into()))?;
 
-        let cancellations = cancellations.clone();
-        ws_connections.spawn(
+        tokio::spawn(
             async move {
-                if let Err(e) = websocket::serve_websocket(
+                let websocket = match websocket.await {
+                    Err(e) => {
+                        warn!("could not upgrade websocket connection: {e:#}");
+                        return;
+                    }
+                    Ok(websocket) => websocket,
+                };
+
+                websocket::serve_websocket(
                     config,
                     backend.auth_backend,
                     ctx,
@@ -452,12 +453,9 @@ async fn request_handler(
                     cancellation_handler,
                     endpoint_rate_limiter,
                     host,
-                    cancellations,
+                    tracker,
                 )
-                .await
-                {
-                    warn!("error in websocket connection: {e:#}");
-                }
+                .await;
             }
             .instrument(span),
         );

@@ -2,14 +2,14 @@ use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll, ready};
 
-use anyhow::Context as _;
 use bytes::{Buf, BufMut, Bytes, BytesMut};
 use framed_websockets::{Frame, OpCode, WebSocketServer};
 use futures::{Sink, Stream};
-use hyper::upgrade::OnUpgrade;
+use hyper::upgrade::Upgraded;
 use hyper_util::rt::TokioIo;
 use pin_project_lite::pin_project;
 use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
+use tokio_util::task::task_tracker::TaskTrackerToken;
 use tracing::warn;
 
 use crate::cancellation::CancellationHandler;
@@ -17,7 +17,7 @@ use crate::config::ProxyConfig;
 use crate::context::RequestContext;
 use crate::error::ReportableError;
 use crate::metrics::Metrics;
-use crate::proxy::{ClientMode, ErrorSource, handle_client};
+use crate::proxy::{ClientMode, handle_client};
 use crate::rate_limiter::EndpointRateLimiter;
 
 pin_project! {
@@ -128,13 +128,12 @@ pub(crate) async fn serve_websocket(
     config: &'static ProxyConfig,
     auth_backend: &'static crate::auth::Backend<'static, ()>,
     ctx: RequestContext,
-    websocket: OnUpgrade,
+    websocket: Upgraded,
     cancellation_handler: Arc<CancellationHandler>,
     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
     hostname: Option<String>,
-    cancellations: tokio_util::task::task_tracker::TaskTracker,
-) -> anyhow::Result<()> {
-    let websocket = websocket.await?;
+    tracker: TaskTrackerToken,
+) {
     let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
 
     let conn_gauge = Metrics::get()
@@ -142,36 +141,28 @@ pub(crate) async fn serve_websocket(
         .client_connections
         .guard(crate::metrics::Protocol::Ws);
 
-    let res = Box::pin(handle_client(
+    let mut ctx_slot = Some(ctx);
+    let res = handle_client(
         config,
         auth_backend,
-        &ctx,
+        &mut ctx_slot,
         cancellation_handler,
         WebSocketRw::new(websocket),
         ClientMode::Websockets { hostname },
         endpoint_rate_limiter,
         conn_gauge,
-        cancellations,
-    ))
+        tracker,
+    )
     .await;
 
-    match res {
-        Err(e) => {
+    match (ctx_slot, res) {
+        (None, _) => {}
+        (Some(ctx), Err(e)) => {
             ctx.set_error_kind(e.get_error_kind());
-            Err(e.into())
+            tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
         }
-        Ok(None) => {
+        (Some(ctx), Ok(())) => {
             ctx.set_success();
-            Ok(())
         }
-        Ok(Some(p)) => {
-            ctx.set_success();
-            ctx.log_connect();
-            match p.proxy_pass(&config.connect_to_compute).await {
-                Ok(()) => Ok(()),
-                Err(ErrorSource::Client(err)) => Err(err).context("client"),
-                Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
-            }
-        }
     }
 }

@@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize};
 use thiserror::Error;
 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
 use tokio_rustls::server::TlsStream;
+use tokio_util::task::task_tracker::TaskTrackerToken;
 use tracing::debug;
 
 use crate::control_plane::messages::ColdStartInfo;
@@ -24,19 +25,22 @@ use crate::tls::TlsServerEndPoint;
 /// to pass random malformed bytes through the connection).
 pub struct PqStream<S> {
     pub(crate) framed: Framed<S>,
+    pub(crate) tracker: TaskTrackerToken,
 }
 
 impl<S> PqStream<S> {
     /// Construct a new libpq protocol wrapper.
-    pub fn new(stream: S) -> Self {
+    pub fn new(stream: S, tracker: TaskTrackerToken) -> Self {
         Self {
             framed: Framed::new(stream),
+            tracker,
         }
     }
 
     /// Extract the underlying stream and read buffer.
-    pub fn into_inner(self) -> (S, BytesMut) {
-        self.framed.into_inner()
+    pub fn into_inner(self) -> (S, BytesMut, TaskTrackerToken) {
+        let (stream, read) = self.framed.into_inner();
+        (stream, read, self.tracker)
    }
 
     /// Get a shared reference to the underlying stream.
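Since `PqStream` now owns the `TaskTrackerToken`, `into_inner` must hand it back so the token survives unwrap/re-wrap cycles such as the TLS upgrade in `handshake`. A simplified runnable sketch of a wrapper threading a token through `into_inner` (stand-in types, not the real `PqStream`):

```rust
use tokio_util::task::TaskTracker;
use tokio_util::task::task_tracker::TaskTrackerToken;

// A wrapper that carries a liveness token alongside the stream so the token
// follows the connection across `into_inner`/re-wrap cycles.
struct Wrapper<S> {
    stream: S,
    tracker: TaskTrackerToken,
}

impl<S> Wrapper<S> {
    fn new(stream: S, tracker: TaskTrackerToken) -> Self {
        Self { stream, tracker }
    }

    fn into_inner(self) -> (S, TaskTrackerToken) {
        (self.stream, self.tracker)
    }
}

fn main() {
    let tracker = TaskTracker::new();
    let wrapped = Wrapper::new("raw bytes", tracker.token());
    let (raw, token) = wrapped.into_inner();
    let _rewrapped = Wrapper::new((raw, "tls"), token); // token carried across the upgrade
}
```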