mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-26 09:30:37 +00:00
optimise the passthrough calling convention to further reduce memory usage
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use futures::TryFutureExt;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
@@ -15,7 +15,7 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::proxy::handshake::{HandshakeData, handshake};
|
||||
use crate::proxy::passthrough::ProxyPassthrough;
|
||||
use crate::proxy::passthrough::passthrough;
|
||||
use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled};
|
||||
|
||||
pub async fn task_main(
|
||||
@@ -101,30 +101,28 @@ pub async fn task_main(
|
||||
&config.region,
|
||||
);
|
||||
|
||||
let span = ctx.span();
|
||||
let mut slot = Some(ctx);
|
||||
let res = handle_client(
|
||||
config,
|
||||
backend,
|
||||
&ctx,
|
||||
&mut slot,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
.instrument(span)
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Err(e) => {
|
||||
match (slot, res) {
|
||||
(None, _) => {}
|
||||
(Some(ctx), Ok(())) => {
|
||||
ctx.success();
|
||||
}
|
||||
(Some(ctx), Err(e)) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute));
|
||||
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -140,40 +138,39 @@ pub async fn task_main(
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
config: &'static ProxyConfig,
|
||||
backend: &'static ConsoleRedirectBackend,
|
||||
ctx: &RequestContext,
|
||||
ctx_slot: &mut Option<RequestContext>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
tracker: TaskTrackerToken,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
) -> Result<(), ClientRequestError> {
|
||||
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
|
||||
debug!(%protocol, "handling interactive connection from client");
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let proto = ctx.protocol();
|
||||
let request_gauge = metrics.connection_requests.guard(proto);
|
||||
let request_gauge = metrics.connection_requests.guard(protocol);
|
||||
|
||||
let tls = config.tls_config.load();
|
||||
let tls = tls.as_deref();
|
||||
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
|
||||
let data = {
|
||||
let ctx = ctx_slot.as_ref().expect("context must be set");
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
|
||||
tokio::time::timeout(config.handshake_timeout, do_handshake).await??
|
||||
};
|
||||
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
let (mut stream, params) = match data {
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data, tracker) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
tokio::spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let ctx = ctx_slot.take().expect("context must be set");
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
async move {
|
||||
@@ -193,11 +190,11 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
.instrument(cancel_span)
|
||||
});
|
||||
|
||||
return Ok(None);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
let ctx = ctx_slot.as_ref().expect("context must be set");
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let (node_info, user_info, _ip_allowlist) = match backend
|
||||
@@ -239,14 +236,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let (stream, read_buf, tracker) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
private_link_id: None,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_tracker: tracker,
|
||||
}))
|
||||
let ctx = ctx_slot.take().expect("context must be set");
|
||||
ctx.set_success();
|
||||
|
||||
tokio::spawn(passthrough(
|
||||
ctx,
|
||||
&config.connect_to_compute,
|
||||
stream,
|
||||
node,
|
||||
session,
|
||||
request_gauge,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -13,10 +13,11 @@ pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use passthrough::passthrough;
|
||||
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
||||
use smol_str::{SmolStr, format_smolstr};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -24,13 +25,12 @@ use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
|
||||
use self::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use self::passthrough::ProxyPassthrough;
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::handshake::{HandshakeData, handshake};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
@@ -412,23 +412,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
|
||||
let ctx = ctx_slot.take().expect("context must be set");
|
||||
ctx.set_success();
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let p = ProxyPassthrough {
|
||||
client: stream,
|
||||
private_link_id,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_tracker: tracker,
|
||||
};
|
||||
tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute));
|
||||
tokio::spawn(passthrough(
|
||||
ctx,
|
||||
&config.connect_to_compute,
|
||||
stream,
|
||||
node,
|
||||
session,
|
||||
request_gauge,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use smol_str::SmolStr;
|
||||
use smol_str::{SmolStr, ToSmolStr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::debug;
|
||||
@@ -11,6 +11,7 @@ use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::stream::Stream;
|
||||
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
|
||||
|
||||
@@ -62,60 +63,53 @@ pub(crate) async fn proxy_pass(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) struct ProxyPassthrough<S> {
|
||||
pub(crate) client: Stream<S>,
|
||||
pub(crate) compute: PostgresConnection,
|
||||
pub(crate) session_id: uuid::Uuid,
|
||||
pub(crate) private_link_id: Option<SmolStr>,
|
||||
pub(crate) cancel: cancellation::Session,
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn passthrough<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
ctx: RequestContext,
|
||||
compute_config: &'static ComputeConfig,
|
||||
|
||||
pub(crate) _req: NumConnectionRequestsGuard<'static>,
|
||||
pub(crate) _conn: NumClientConnectionsGuard<'static>,
|
||||
/// ensures proxy stays online while this is set.
|
||||
pub(crate) _tracker: TaskTrackerToken,
|
||||
}
|
||||
client: Stream<S>,
|
||||
compute: PostgresConnection,
|
||||
cancel: cancellation::Session,
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> ProxyPassthrough<S> {
|
||||
pub(crate) async fn proxy_pass(
|
||||
self,
|
||||
ctx: RequestContext,
|
||||
compute_config: &'static ComputeConfig,
|
||||
) {
|
||||
let _disconnect = ctx.log_connect();
|
||||
let res = proxy_pass(
|
||||
self.client,
|
||||
self.compute.stream,
|
||||
self.compute.aux,
|
||||
self.private_link_id,
|
||||
)
|
||||
.await;
|
||||
_req: NumConnectionRequestsGuard<'static>,
|
||||
_conn: NumClientConnectionsGuard<'static>,
|
||||
_tracker: TaskTrackerToken,
|
||||
) {
|
||||
let session_id = ctx.session_id();
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
tracing::warn!(
|
||||
session_id = ?self.session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
tracing::error!(
|
||||
session_id = ?self.session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
let _disconnect = ctx.log_connect();
|
||||
let res = proxy_pass(client, compute.stream, compute.aux, private_link_id).await;
|
||||
|
||||
match res {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
tracing::warn!(
|
||||
session_id = ?session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
|
||||
if let Err(err) = self
|
||||
.compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
tracing::error!(
|
||||
session_id = ?session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
|
||||
// we don't need a result. If the queue is full, we just log the error
|
||||
drop(self.cancel.remove_cancel_key());
|
||||
}
|
||||
|
||||
if let Err(err) = compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?session_id, ?err, "could not cancel the query in the database");
|
||||
}
|
||||
|
||||
// we don't need a result. If the queue is full, we just log the error
|
||||
drop(cancel.remove_cancel_key());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user