optimise proxy_pass memory size a little, also boxing requestcontext since it is large

This commit is contained in:
Conrad Ludgate
2025-05-29 17:52:26 +01:00
parent 0633cd6385
commit 219c72c24c
6 changed files with 52 additions and 76 deletions

View File

@@ -16,9 +16,7 @@ use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::proxy::passthrough::ProxyPassthrough;
use crate::proxy::{
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
};
use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled};
pub async fn task_main(
config: &'static ProxyConfig,
@@ -126,22 +124,7 @@ pub async fn task_main(
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute));
}
}
});

View File

@@ -38,7 +38,7 @@ pub struct RequestContext(
/// I would typically use a RefCell but that would break the `Send` requirements
/// so we need something with thread-safety. `TryLock` is a cheap alternative
/// that offers similar semantics to a `RefCell` but with synchronisation.
TryLock<RequestContextInner>,
TryLock<Box<RequestContextInner>>,
);
struct RequestContextInner {
@@ -89,7 +89,7 @@ pub(crate) enum AuthMethod {
impl Clone for RequestContext {
fn clone(&self) -> Self {
let inner = self.0.try_lock().expect("should not deadlock");
let new = RequestContextInner {
let new = Box::new(RequestContextInner {
conn_info: inner.conn_info.clone(),
session_id: inner.session_id,
protocol: inner.protocol,
@@ -117,7 +117,7 @@ impl Clone for RequestContext {
disconnect_sender: None,
latency_timer: LatencyTimer::noop(inner.protocol),
disconnect_timestamp: inner.disconnect_timestamp,
};
});
Self(TryLock::new(new))
}
@@ -140,7 +140,7 @@ impl RequestContext {
role = tracing::field::Empty,
);
let inner = RequestContextInner {
let inner = Box::new(RequestContextInner {
conn_info,
session_id,
protocol,
@@ -168,7 +168,7 @@ impl RequestContext {
disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
latency_timer: LatencyTimer::new(protocol),
disconnect_timestamp: None,
};
});
Self(TryLock::new(inner))
}
@@ -522,7 +522,7 @@ impl Drop for RequestContextInner {
}
}
pub struct DisconnectLogger(RequestContextInner);
pub struct DisconnectLogger(Box<RequestContextInner>);
impl Drop for DisconnectLogger {
fn drop(&mut self) {

View File

@@ -154,42 +154,19 @@ pub async fn task_main(
.boxed()
.await;
let passthrough = match res {
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
return;
}
Ok(None) => {
ctx.set_success();
return;
}
Ok(Some(p)) => {
ctx.set_success();
p
tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute));
}
};
// spawn passthrough as a new task.
// holds a task tracker token to prevent the proxy shutting down early.
tokio::spawn(async move {
let _disconnect = ctx.log_connect();
match passthrough.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
});
}
});
}

View File

@@ -8,6 +8,7 @@ use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::stream::Stream;
@@ -75,11 +76,13 @@ pub(crate) struct ProxyPassthrough<S> {
pub(crate) _tracker: TaskTrackerToken,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> ProxyPassthrough<S> {
pub(crate) async fn proxy_pass(
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
ctx: RequestContext,
compute_config: &'static ComputeConfig,
) {
let _disconnect = ctx.log_connect();
let res = proxy_pass(
self.client,
self.compute.stream,
@@ -87,6 +90,23 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
self.private_link_id,
)
.await;
match res {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
tracing::warn!(
session_id = ?self.session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
tracing::error!(
session_id = ?self.session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
if let Err(err) = self
.compute
.cancel_closure
@@ -96,8 +116,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
}
drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
res
// we don't need a result. If the queue is full, we just log the error
drop(self.cancel.remove_cancel_key());
}
}

View File

@@ -437,7 +437,15 @@ async fn request_handler(
tokio::spawn(
async move {
if let Err(e) = websocket::serve_websocket(
let websocket = match websocket.await {
Err(e) => {
warn!("could not upgrade websocket connection: {e:#}");
return;
}
Ok(websocket) => websocket,
};
websocket::serve_websocket(
config,
backend.auth_backend,
ctx,
@@ -447,10 +455,7 @@ async fn request_handler(
host,
tracker,
)
.await
{
warn!("error in websocket connection: {e:#}");
}
.await;
}
.instrument(span),
);

View File

@@ -2,11 +2,10 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};
use anyhow::Context as _;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
use hyper::upgrade::OnUpgrade;
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
@@ -18,7 +17,7 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::proxy::{ClientMode, handle_client};
use crate::rate_limiter::EndpointRateLimiter;
pin_project! {
@@ -129,13 +128,12 @@ pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
ctx: RequestContext,
websocket: OnUpgrade,
websocket: Upgraded,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
tracker: TaskTrackerToken,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
) {
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
let conn_gauge = Metrics::get()
@@ -159,20 +157,14 @@ pub(crate) async fn serve_websocket(
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
Err(e.into())
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
Ok(())
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
}
tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute));
}
}
}