pageserver: cancellation handling in writes to postgres client socket (#5503)
## Problem

Writes to the postgres client socket from the pageserver were not wrapped in cancellation handling, so a stuck client connection could prevent tenant shutdown.

## Summary of changes

Everywhere we call flush() to write to the socket, we should respect the task's cancellation token. In this PR, I explicitly pass around a CancellationToken rather than making inline `task_mgr::shutdown_token` calls, to avoid coupling to the global task_mgr state and to make later refactoring easier.

I have some follow-on commits that add a Shutdown variant to QueryError and use it more extensively, but that's pure refactoring, so I'll keep it separate from this bug-fix PR.

Closes: https://github.com/neondatabase/neon/issues/5341
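For illustration, a minimal self-contained sketch of the pattern this PR applies: race a socket flush against a CancellationToken so a stuck client cannot block shutdown. This is not the pageserver code itself; the helper name mirrors the PR's `flush_cancellable`, but the simulated socket and timing are invented for the example, and it assumes the `tokio`, `tokio-util`, and `anyhow` crates.

```rust
use std::time::Duration;

use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;

/// Race a socket flush against a cancellation token, so a client that stops
/// consuming bytes cannot block shutdown indefinitely.
async fn flush_cancellable<W: AsyncWrite + Unpin>(
    writer: &mut W,
    cancel: &CancellationToken,
) -> anyhow::Result<()> {
    tokio::select!(
        // If the flush completes first, propagate its result.
        flush_r = writer.flush() => Ok(flush_r?),
        // If shutdown is requested first, stop waiting on the client promptly.
        _ = cancel.cancelled() => Err(anyhow::anyhow!("Shutting down")),
    )
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let cancel = CancellationToken::new();

    // Simulate a shutdown request arriving while the connection is active.
    let trigger = cancel.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(10)).await;
        trigger.cancel();
    });

    // `sink()` stands in for the postgres client socket.
    let mut socket = tokio::io::sink();
    flush_cancellable(&mut socket, &cancel).await
}
```

The version in the PR takes a `PostgresBackend` and returns `QueryError`, but the `tokio::select!` shape, racing `flush()` against `cancelled()`, is the same.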
@@ -442,10 +442,20 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
             trace!("got message {:?}", msg);

             let result = self.process_message(handler, msg, &mut query_string).await;
-            self.flush().await?;
+            tokio::select!(
+                biased;
+                _ = shutdown_watcher() => {
+                    // We were requested to shut down.
+                    tracing::info!("shutdown request received during response flush");
+                    return Ok(())
+                },
+                flush_r = self.flush() => {
+                    flush_r?;
+                }
+            );
             match result? {
                 ProcessMsgResult::Continue => {
                     self.flush().await?;
                     continue;
                 }
                 ProcessMsgResult::Break => break,
@@ -35,6 +35,7 @@ use std::time::Duration;
 use tokio::io::AsyncWriteExt;
 use tokio::io::{AsyncRead, AsyncWrite};
 use tokio_util::io::StreamReader;
+use tokio_util::sync::CancellationToken;
 use tracing::field;
 use tracing::*;
 use utils::id::ConnectionId;
@@ -64,69 +65,6 @@ use crate::trace::Tracer;
 use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
 use postgres_ffi::BLCKSZ;

-fn copyin_stream<IO>(pgb: &mut PostgresBackend<IO>) -> impl Stream<Item = io::Result<Bytes>> + '_
-where
-    IO: AsyncRead + AsyncWrite + Unpin,
-{
-    async_stream::try_stream! {
-        loop {
-            let msg = tokio::select! {
-                biased;
-
-                _ = task_mgr::shutdown_watcher() => {
-                    // We were requested to shut down.
-                    let msg = "pageserver is shutting down";
-                    let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, None));
-                    Err(QueryError::Other(anyhow::anyhow!(msg)))
-                }
-
-                msg = pgb.read_message() => { msg.map_err(QueryError::from) }
-            };
-
-            match msg {
-                Ok(Some(message)) => {
-                    let copy_data_bytes = match message {
-                        FeMessage::CopyData(bytes) => bytes,
-                        FeMessage::CopyDone => { break },
-                        FeMessage::Sync => continue,
-                        FeMessage::Terminate => {
-                            let msg = "client terminated connection with Terminate message during COPY";
-                            let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
-                            // error can't happen here, ErrorResponse serialization should be always ok
-                            pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
-                            Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
-                            break;
-                        }
-                        m => {
-                            let msg = format!("unexpected message {m:?}");
-                            // error can't happen here, ErrorResponse serialization should be always ok
-                            pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
-                            Err(io::Error::new(io::ErrorKind::Other, msg))?;
-                            break;
-                        }
-                    };
-
-                    yield copy_data_bytes;
-                }
-                Ok(None) => {
-                    let msg = "client closed connection during COPY";
-                    let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
-                    // error can't happen here, ErrorResponse serialization should be always ok
-                    pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
-                    pgb.flush().await?;
-                    Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
-                }
-                Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
-                    Err(io_error)?;
-                }
-                Err(other) => {
-                    Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
-                }
-            };
-        }
-    }
-}
-
 /// Read the end of a tar archive.
 ///
 /// A tar archive normally ends with two consecutive blocks of zeros, 512 bytes each.
@@ -284,7 +222,13 @@ async fn page_service_conn_main(
     // and create a child per-query context when it invokes process_query.
     // But it's in a shared crate, so, we store connection_ctx inside PageServerHandler
     // and create the per-query context in process_query ourselves.
-    let mut conn_handler = PageServerHandler::new(conf, broker_client, auth, connection_ctx);
+    let mut conn_handler = PageServerHandler::new(
+        conf,
+        broker_client,
+        auth,
+        connection_ctx,
+        task_mgr::shutdown_token(),
+    );
     let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?;

     match pgbackend
@@ -318,6 +262,10 @@ struct PageServerHandler {
     /// For each query received over the connection,
     /// `process_query` creates a child context from this one.
     connection_ctx: RequestContext,
+
+    /// A token that should fire when the tenant transitions from
+    /// attached state, or when the pageserver is shutting down.
+    cancel: CancellationToken,
 }

 impl PageServerHandler {
@@ -326,6 +274,7 @@ impl PageServerHandler {
         broker_client: storage_broker::BrokerClientChannel,
         auth: Option<Arc<JwtAuth>>,
         connection_ctx: RequestContext,
+        cancel: CancellationToken,
     ) -> Self {
         PageServerHandler {
             _conf: conf,
@@ -333,6 +282,91 @@ impl PageServerHandler {
             auth,
             claims: None,
             connection_ctx,
+            cancel,
         }
     }

+    /// Wrap PostgresBackend::flush to respect our CancellationToken: it is important to use
+    /// this rather than a naked flush() in order to shut down promptly. Without this, we would
+    /// block shutdown of a tenant if a postgres client was failing to consume bytes we send
+    /// in the flush.
+    async fn flush_cancellable<IO>(&self, pgb: &mut PostgresBackend<IO>) -> Result<(), QueryError>
+    where
+        IO: AsyncRead + AsyncWrite + Send + Sync + Unpin,
+    {
+        tokio::select!(
+            flush_r = pgb.flush() => {
+                Ok(flush_r?)
+            },
+            _ = self.cancel.cancelled() => {
+                Err(QueryError::Other(anyhow::anyhow!("Shutting down")))
+            }
+        )
+    }
+
+    fn copyin_stream<'a, IO>(
+        &'a self,
+        pgb: &'a mut PostgresBackend<IO>,
+    ) -> impl Stream<Item = io::Result<Bytes>> + 'a
+    where
+        IO: AsyncRead + AsyncWrite + Send + Sync + Unpin,
+    {
+        async_stream::try_stream! {
+            loop {
+                let msg = tokio::select! {
+                    biased;
+
+                    _ = task_mgr::shutdown_watcher() => {
+                        // We were requested to shut down.
+                        let msg = "pageserver is shutting down";
+                        let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, None));
+                        Err(QueryError::Other(anyhow::anyhow!(msg)))
+                    }
+
+                    msg = pgb.read_message() => { msg.map_err(QueryError::from) }
+                };
+
+                match msg {
+                    Ok(Some(message)) => {
+                        let copy_data_bytes = match message {
+                            FeMessage::CopyData(bytes) => bytes,
+                            FeMessage::CopyDone => { break },
+                            FeMessage::Sync => continue,
+                            FeMessage::Terminate => {
+                                let msg = "client terminated connection with Terminate message during COPY";
+                                let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
+                                // error can't happen here, ErrorResponse serialization should be always ok
+                                pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
+                                Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
+                                break;
+                            }
+                            m => {
+                                let msg = format!("unexpected message {m:?}");
+                                // error can't happen here, ErrorResponse serialization should be always ok
+                                pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
+                                Err(io::Error::new(io::ErrorKind::Other, msg))?;
+                                break;
+                            }
+                        };
+
+                        yield copy_data_bytes;
+                    }
+                    Ok(None) => {
+                        let msg = "client closed connection during COPY";
+                        let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
+                        // error can't happen here, ErrorResponse serialization should be always ok
+                        pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
+                        self.flush_cancellable(pgb).await.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
+                        Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
+                    }
+                    Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
+                        Err(io_error)?;
+                    }
+                    Err(other) => {
+                        Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
+                    }
+                };
+            }
+        }
+    }
+
@@ -372,7 +406,7 @@ impl PageServerHandler {

         // switch client to COPYBOTH
         pgb.write_message_noflush(&BeMessage::CopyBothResponse)?;
-        pgb.flush().await?;
+        self.flush_cancellable(pgb).await?;

         let metrics = metrics::SmgrQueryTimePerTimeline::new(&tenant_id, &timeline_id);
@@ -465,7 +499,7 @@ impl PageServerHandler {
             });

             pgb.write_message_noflush(&BeMessage::CopyData(&response.serialize()))?;
-            pgb.flush().await?;
+            self.flush_cancellable(pgb).await?;
         }
         Ok(())
     }
@@ -508,9 +542,9 @@ impl PageServerHandler {
         // Import basebackup provided via CopyData
         info!("importing basebackup");
         pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
-        pgb.flush().await?;
+        self.flush_cancellable(pgb).await?;

-        let mut copyin_reader = pin!(StreamReader::new(copyin_stream(pgb)));
+        let mut copyin_reader = pin!(StreamReader::new(self.copyin_stream(pgb)));
         timeline
             .import_basebackup_from_tar(
                 &mut copyin_reader,
@@ -563,8 +597,8 @@ impl PageServerHandler {
         // Import wal provided via CopyData
         info!("importing wal");
         pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
-        pgb.flush().await?;
-        let mut copyin_reader = pin!(StreamReader::new(copyin_stream(pgb)));
+        self.flush_cancellable(pgb).await?;
+        let mut copyin_reader = pin!(StreamReader::new(self.copyin_stream(pgb)));
         import_wal_from_tar(&timeline, &mut copyin_reader, start_lsn, end_lsn, &ctx).await?;
         info!("wal import complete");

@@ -772,7 +806,7 @@ impl PageServerHandler {

         // switch client to COPYOUT
         pgb.write_message_noflush(&BeMessage::CopyOutResponse)?;
-        pgb.flush().await?;
+        self.flush_cancellable(pgb).await?;

         // Send a tarball of the latest layer on the timeline. Compress if not
         // fullbackup. TODO Compress in that case too (tests need to be updated)
@@ -824,7 +858,7 @@ impl PageServerHandler {
         }

         pgb.write_message_noflush(&BeMessage::CopyDone)?;
-        pgb.flush().await?;
+        self.flush_cancellable(pgb).await?;

         let basebackup_after = started
             .elapsed()
@@ -17,6 +17,8 @@ def test_pageserver_restarts_under_worload(neon_simple_env: NeonEnv, pg_bin: PgB
     n_restarts = 10
     scale = 10

+    env.pageserver.allowed_errors.append(".*query handler.*failed.*Shutting down")
+
     def run_pgbench(connstr: str):
         log.info(f"Start a pgbench workload on pg {connstr}")
         pg_bin.run_capture(["pgbench", "-i", f"-s{scale}", connstr])
@@ -752,6 +752,9 @@ def test_ignore_while_attaching(
     env.pageserver.allowed_errors.append(
         f".*Tenant {tenant_id} will not become active\\. Current state: Stopping.*"
     )
+    # An endpoint is starting up concurrently with our detach, it can
+    # experience RPC failure due to shutdown.
+    env.pageserver.allowed_errors.append(".*query handler.*failed.*Shutting down")

     data_id = 1
     data_secret = "very secret secret"