mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-16 01:42:55 +00:00
manually handle task tracker tokens
This commit is contained in:
@@ -547,6 +547,7 @@ mod tests {
|
||||
use postgres_protocol::message::backend::Message as PgMessage;
|
||||
use postgres_protocol::message::frontend;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use super::jwt::JwkCache;
|
||||
use super::{AuthRateLimiter, auth_quirks};
|
||||
@@ -697,7 +698,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_scram() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -779,7 +780,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_cleartext() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -833,7 +834,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_password_hack() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
|
||||
@@ -18,6 +18,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{Instrument, error, info};
|
||||
use utils::project_git_version;
|
||||
use utils::sentry_init::init_sentry;
|
||||
@@ -226,7 +227,8 @@ pub(super) async fn task_main(
|
||||
let dest_suffix = Arc::clone(&dest_suffix);
|
||||
let compute_tls_config = compute_tls_config.clone();
|
||||
|
||||
connections.spawn(
|
||||
let tracker = connections.token();
|
||||
tokio::spawn(
|
||||
async move {
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
@@ -249,6 +251,7 @@ pub(super) async fn task_main(
|
||||
compute_tls_config,
|
||||
tls_server_end_point,
|
||||
socket,
|
||||
tracker,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -274,10 +277,11 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestContext,
|
||||
raw_stream: S,
|
||||
tracker: TaskTrackerToken,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
) -> anyhow::Result<Stream<S>> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
|
||||
) -> anyhow::Result<(Stream<S>, TaskTrackerToken)> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream), tracker);
|
||||
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
use pq_proto::FeStartupPacket::SslRequest;
|
||||
@@ -291,7 +295,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let (raw, read_buf) = stream.into_inner();
|
||||
let (raw, read_buf, tracker) = stream.into_inner();
|
||||
// TODO: Normally, client doesn't send any data before
|
||||
// server says TLS handshake is ok and read_buf is empty.
|
||||
// However, you could imagine pipelining of postgres
|
||||
@@ -302,13 +306,16 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
|
||||
Ok(Stream::Tls {
|
||||
tls: Box::new(
|
||||
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?,
|
||||
),
|
||||
tls_server_end_point,
|
||||
})
|
||||
Ok((
|
||||
Stream::Tls {
|
||||
tls: Box::new(
|
||||
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?,
|
||||
),
|
||||
tls_server_end_point,
|
||||
},
|
||||
tracker,
|
||||
))
|
||||
}
|
||||
unexpected => {
|
||||
info!(
|
||||
@@ -329,8 +336,10 @@ async fn handle_client(
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
tracker: TaskTrackerToken,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
|
||||
let (mut tls_stream, _tracker) =
|
||||
ssl_handshake(&ctx, stream, tracker, tls_config, tls_server_end_point).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::sync::Arc;
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{Instrument, debug, error, info};
|
||||
|
||||
use crate::auth::backend::ConsoleRedirectBackend;
|
||||
@@ -35,7 +36,6 @@ pub async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -49,11 +49,11 @@ pub async fn task_main(
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
|
||||
connections.spawn(async move {
|
||||
let tracker = connections.token();
|
||||
tokio::spawn(async move {
|
||||
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
error!("per-client task finished with an error: {e:#}");
|
||||
@@ -110,7 +110,7 @@ pub async fn task_main(
|
||||
cancellation_handler,
|
||||
socket,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
tracker,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
@@ -148,12 +148,10 @@ pub async fn task_main(
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -166,7 +164,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
tracker: TaskTrackerToken,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
@@ -182,20 +180,21 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
|
||||
let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
|
||||
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
HandshakeData::Cancel(cancel_key_data, tracker) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
tokio::spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
async move {
|
||||
let _tracker = tracker;
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
@@ -205,8 +204,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
backend.get_api(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
}.instrument(cancel_span)
|
||||
.inspect_err(|e| debug!(error = ?e, "cancel_session failed"))
|
||||
.ok();
|
||||
}
|
||||
.instrument(cancel_span)
|
||||
});
|
||||
|
||||
return Ok(None);
|
||||
@@ -252,7 +253,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
let (stream, read_buf, tracker) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
@@ -264,5 +265,6 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_tracker: tracker,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ use pq_proto::{
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::endpoint_sni;
|
||||
@@ -51,7 +52,7 @@ impl ReportableError for HandshakeError {
|
||||
|
||||
pub(crate) enum HandshakeData<S> {
|
||||
Startup(PqStream<Stream<S>>, StartupMessageParams),
|
||||
Cancel(CancelKeyData),
|
||||
Cancel(CancelKeyData, TaskTrackerToken),
|
||||
}
|
||||
|
||||
/// Establish a (most probably, secure) connection with the client.
|
||||
@@ -62,6 +63,7 @@ pub(crate) enum HandshakeData<S> {
|
||||
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestContext,
|
||||
stream: S,
|
||||
tracker: TaskTrackerToken,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
record_handshake_error: bool,
|
||||
) -> Result<HandshakeData<S>, HandshakeError> {
|
||||
@@ -71,7 +73,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream), tracker);
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
match msg {
|
||||
@@ -157,15 +159,13 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let (_, tls_server_end_point) =
|
||||
tls.cert_resolver.resolve(conn_info.server_name());
|
||||
|
||||
stream = PqStream {
|
||||
framed: Framed {
|
||||
stream: Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
},
|
||||
read_buf,
|
||||
write_buf,
|
||||
stream.framed = Framed {
|
||||
stream: Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
},
|
||||
read_buf,
|
||||
write_buf,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -248,7 +248,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
FeStartupPacket::CancelRequest(cancel_key_data) => {
|
||||
info!(session_type = "cancellation", "successful handshake");
|
||||
break Ok(HandshakeData::Cancel(cancel_key_data));
|
||||
break Ok(HandshakeData::Cancel(cancel_key_data, stream.tracker));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
|
||||
use self::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
@@ -70,7 +71,6 @@ pub async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -84,12 +84,12 @@ pub async fn task_main(
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(async move {
|
||||
let tracker = connections.token();
|
||||
tokio::spawn(async move {
|
||||
let (socket, conn_info) = match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
@@ -148,7 +148,7 @@ pub async fn task_main(
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
tracker,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
@@ -186,12 +186,10 @@ pub async fn task_main(
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -267,7 +265,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
tracker: TaskTrackerToken,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
@@ -283,20 +281,29 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
|
||||
let do_handshake = handshake(
|
||||
ctx,
|
||||
stream,
|
||||
tracker,
|
||||
mode.handshake_tls(tls),
|
||||
record_handshake_error,
|
||||
);
|
||||
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
HandshakeData::Cancel(cancel_key_data, tracker) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
tokio::spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
async move {
|
||||
// ensure the proxy doesn't shutdown until we complete this task.
|
||||
let _tracker = tracker;
|
||||
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
@@ -306,8 +313,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
auth_backend.get_api(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
}.instrument(cancel_span)
|
||||
.inspect_err(|e| debug!(error = ?e, "cancel_session failed"))
|
||||
.ok();
|
||||
}
|
||||
.instrument(cancel_span)
|
||||
});
|
||||
|
||||
return Ok(None);
|
||||
@@ -391,7 +400,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
let (stream, read_buf, tracker) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
@@ -409,6 +418,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_tracker: tracker,
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::debug;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
@@ -71,6 +72,8 @@ pub(crate) struct ProxyPassthrough<S> {
|
||||
|
||||
pub(crate) _req: NumConnectionRequestsGuard<'static>,
|
||||
pub(crate) _conn: NumClientConnectionsGuard<'static>,
|
||||
/// ensures proxy stays online while this is set.
|
||||
pub(crate) _tracker: TaskTrackerToken,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
|
||||
@@ -38,6 +38,7 @@ async fn proxy_mitm(
|
||||
let (end_client, startup) = match handshake(
|
||||
&RequestContext::test(),
|
||||
client1,
|
||||
TaskTracker::new().token(),
|
||||
Some(&server_config1),
|
||||
false,
|
||||
)
|
||||
@@ -45,7 +46,7 @@ async fn proxy_mitm(
|
||||
.unwrap()
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
|
||||
HandshakeData::Cancel(_, _) => panic!("cancellation not supported"),
|
||||
};
|
||||
|
||||
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
|
||||
|
||||
@@ -15,6 +15,7 @@ use rstest::rstest;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::pki_types;
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio_util::task::TaskTracker;
|
||||
use tracing_test::traced_test;
|
||||
|
||||
use super::connect_compute::ConnectMechanism;
|
||||
@@ -178,10 +179,12 @@ async fn dummy_proxy(
|
||||
auth: impl TestAuth + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
let (client, _) = read_proxy_protocol(client).await?;
|
||||
let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
|
||||
};
|
||||
let t = TaskTracker::new().token();
|
||||
let mut stream =
|
||||
match handshake(&RequestContext::test(), client, t, tls.as_ref(), false).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_, _) => bail!("cancellation not supported"),
|
||||
};
|
||||
|
||||
auth.authenticate(&mut stream).await?;
|
||||
|
||||
|
||||
@@ -124,7 +124,6 @@ pub async fn task_main(
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
connections.close(); // allows `connections.wait to complete`
|
||||
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
|
||||
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
|
||||
if let Err(e) = conn.set_nodelay(true) {
|
||||
@@ -153,7 +152,6 @@ pub async fn task_main(
|
||||
let connections2 = connections.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellations = cancellations.clone();
|
||||
connections.spawn(
|
||||
async move {
|
||||
let conn_token2 = conn_token.clone();
|
||||
@@ -182,7 +180,6 @@ pub async fn task_main(
|
||||
config,
|
||||
backend,
|
||||
connections2,
|
||||
cancellations,
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
conn_token,
|
||||
@@ -306,7 +303,6 @@ async fn connection_handler(
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
connections: TaskTracker,
|
||||
cancellations: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_token: CancellationToken,
|
||||
@@ -347,7 +343,6 @@ async fn connection_handler(
|
||||
|
||||
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
|
||||
// By spawning the future, we ensure it never gets cancelled until it decides to.
|
||||
let cancellations = cancellations.clone();
|
||||
let handler = connections.spawn(
|
||||
request_handler(
|
||||
req,
|
||||
@@ -359,7 +354,6 @@ async fn connection_handler(
|
||||
conn_info2.clone(),
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
cancellations,
|
||||
)
|
||||
.in_current_span()
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
@@ -407,7 +401,6 @@ async fn request_handler(
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellations: TaskTracker,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
@@ -441,8 +434,8 @@ async fn request_handler(
|
||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
let cancellations = cancellations.clone();
|
||||
ws_connections.spawn(
|
||||
let tracker = ws_connections.token();
|
||||
tokio::spawn(
|
||||
async move {
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
config,
|
||||
@@ -452,7 +445,7 @@ async fn request_handler(
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
host,
|
||||
cancellations,
|
||||
tracker,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -10,6 +10,7 @@ use hyper::upgrade::OnUpgrade;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::cancellation::CancellationHandler;
|
||||
@@ -132,7 +133,7 @@ pub(crate) async fn serve_websocket(
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
hostname: Option<String>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
tracker: TaskTrackerToken,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
|
||||
@@ -151,7 +152,7 @@ pub(crate) async fn serve_websocket(
|
||||
ClientMode::Websockets { hostname },
|
||||
endpoint_rate_limiter,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
tracker,
|
||||
))
|
||||
.await;
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::control_plane::messages::ColdStartInfo;
|
||||
@@ -24,19 +25,22 @@ use crate::tls::TlsServerEndPoint;
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
pub(crate) framed: Framed<S>,
|
||||
pub(crate) tracker: TaskTrackerToken,
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper.
|
||||
pub fn new(stream: S) -> Self {
|
||||
pub fn new(stream: S, tracker: TaskTrackerToken) -> Self {
|
||||
Self {
|
||||
framed: Framed::new(stream),
|
||||
tracker,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the underlying stream and read buffer.
|
||||
pub fn into_inner(self) -> (S, BytesMut) {
|
||||
self.framed.into_inner()
|
||||
pub fn into_inner(self) -> (S, BytesMut, TaskTrackerToken) {
|
||||
let (stream, read) = self.framed.into_inner();
|
||||
(stream, read, self.tracker)
|
||||
}
|
||||
|
||||
/// Get a shared reference to the underlying stream.
|
||||
|
||||
Reference in New Issue
Block a user