diff --git a/libs/postgres_backend/src/lib.rs b/libs/postgres_backend/src/lib.rs
index 453c58431a..08c4e03d13 100644
--- a/libs/postgres_backend/src/lib.rs
+++ b/libs/postgres_backend/src/lib.rs
@@ -442,10 +442,20 @@ impl PostgresBackend {
             trace!("got message {:?}", msg);
 
             let result = self.process_message(handler, msg, &mut query_string).await;
-            self.flush().await?;
+            tokio::select!(
+                biased;
+                _ = shutdown_watcher() => {
+                    // We were requested to shut down.
+                    tracing::info!("shutdown request received during response flush");
+                    return Ok(())
+                },
+                flush_r = self.flush() => {
+                    flush_r?;
+                }
+            );
+
             match result? {
                 ProcessMsgResult::Continue => {
-                    self.flush().await?;
                     continue;
                 }
                 ProcessMsgResult::Break => break,
diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs
index decc82112d..5ab4fbbd4c 100644
--- a/pageserver/src/page_service.rs
+++ b/pageserver/src/page_service.rs
@@ -35,6 +35,7 @@ use std::time::Duration;
 use tokio::io::AsyncWriteExt;
 use tokio::io::{AsyncRead, AsyncWrite};
 use tokio_util::io::StreamReader;
+use tokio_util::sync::CancellationToken;
 use tracing::field;
 use tracing::*;
 use utils::id::ConnectionId;
@@ -64,69 +65,6 @@ use crate::trace::Tracer;
 use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
 use postgres_ffi::BLCKSZ;
 
-fn copyin_stream<IO>(pgb: &mut PostgresBackend<IO>) -> impl Stream<Item = io::Result<Bytes>> + '_
-where
-    IO: AsyncRead + AsyncWrite + Unpin,
-{
-    async_stream::try_stream! {
-        loop {
-            let msg = tokio::select! {
-                biased;
-
-                _ = task_mgr::shutdown_watcher() => {
-                    // We were requested to shut down.
-                    let msg = "pageserver is shutting down";
-                    let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, None));
-                    Err(QueryError::Other(anyhow::anyhow!(msg)))
-                }
-
-                msg = pgb.read_message() => { msg.map_err(QueryError::from)}
-            };
-
-            match msg {
-                Ok(Some(message)) => {
-                    let copy_data_bytes = match message {
-                        FeMessage::CopyData(bytes) => bytes,
-                        FeMessage::CopyDone => { break },
-                        FeMessage::Sync => continue,
-                        FeMessage::Terminate => {
-                            let msg = "client terminated connection with Terminate message during COPY";
-                            let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
-                            // error can't happen here, ErrorResponse serialization should be always ok
-                            pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
-                            Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
-                            break;
-                        }
-                        m => {
-                            let msg = format!("unexpected message {m:?}");
-                            // error can't happen here, ErrorResponse serialization should be always ok
-                            pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
-                            Err(io::Error::new(io::ErrorKind::Other, msg))?;
-                            break;
-                        }
-                    };
-
-                    yield copy_data_bytes;
-                }
-                Ok(None) => {
-                    let msg = "client closed connection during COPY";
-                    let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
-                    // error can't happen here, ErrorResponse serialization should be always ok
-                    pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
-                    pgb.flush().await?;
-                    Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
-                }
-                Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
-                    Err(io_error)?;
-                }
-                Err(other) => {
-                    Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
-                }
-            };
-        }
-    }
-}
-
 /// Read the end of a tar archive.
 ///
 /// A tar archive normally ends with two consecutive blocks of zeros, 512 bytes each.
@@ -284,7 +222,13 @@ async fn page_service_conn_main(
     // and create a child per-query context when it invokes process_query.
     // But it's in a shared crate, so, we store connection_ctx inside PageServerHandler
     // and create the per-query context in process_query ourselves.
-    let mut conn_handler = PageServerHandler::new(conf, broker_client, auth, connection_ctx);
+    let mut conn_handler = PageServerHandler::new(
+        conf,
+        broker_client,
+        auth,
+        connection_ctx,
+        task_mgr::shutdown_token(),
+    );
     let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?;
 
     match pgbackend
@@ -318,6 +262,10 @@ struct PageServerHandler {
     /// For each query received over the connection,
     /// `process_query` creates a child context from this one.
     connection_ctx: RequestContext,
+
+    /// A token that should fire when the tenant transitions from
+    /// attached state, or when the pageserver is shutting down.
+    cancel: CancellationToken,
 }
 
 impl PageServerHandler {
@@ -326,6 +274,7 @@ impl PageServerHandler {
         broker_client: storage_broker::BrokerClientChannel,
         auth: Option<Arc<JwtAuth>>,
         connection_ctx: RequestContext,
+        cancel: CancellationToken,
     ) -> Self {
         PageServerHandler {
             _conf: conf,
@@ -333,6 +282,91 @@ impl PageServerHandler {
             auth,
             claims: None,
             connection_ctx,
+            cancel,
+        }
+    }
+
+    /// Wrap PostgresBackend::flush to respect our CancellationToken: it is important to use
+    /// this rather than naked flush() in order to shut down promptly. Without this, we would
+    /// block shutdown of a tenant if a postgres client was failing to consume bytes we send
+    /// in the flush.
+    async fn flush_cancellable<IO>(&self, pgb: &mut PostgresBackend<IO>) -> Result<(), QueryError>
+    where
+        IO: AsyncRead + AsyncWrite + Send + Sync + Unpin,
+    {
+        tokio::select!(
+            flush_r = pgb.flush() => {
+                Ok(flush_r?)
+            },
+            _ = self.cancel.cancelled() => {
+                Err(QueryError::Other(anyhow::anyhow!("Shutting down")))
+            }
+        )
+    }
+
+    fn copyin_stream<'a, IO>(
+        &'a self,
+        pgb: &'a mut PostgresBackend<IO>,
+    ) -> impl Stream<Item = io::Result<Bytes>> + 'a
+    where
+        IO: AsyncRead + AsyncWrite + Send + Sync + Unpin,
+    {
+        async_stream::try_stream! {
+            loop {
+                let msg = tokio::select! {
+                    biased;
+
+                    _ = task_mgr::shutdown_watcher() => {
+                        // We were requested to shut down.
+ let msg = "pageserver is shutting down"; + let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, None)); + Err(QueryError::Other(anyhow::anyhow!(msg))) + } + + msg = pgb.read_message() => { msg.map_err(QueryError::from)} + }; + + match msg { + Ok(Some(message)) => { + let copy_data_bytes = match message { + FeMessage::CopyData(bytes) => bytes, + FeMessage::CopyDone => { break }, + FeMessage::Sync => continue, + FeMessage::Terminate => { + let msg = "client terminated connection with Terminate message during COPY"; + let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?; + Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; + break; + } + m => { + let msg = format!("unexpected message {m:?}"); + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?; + Err(io::Error::new(io::ErrorKind::Other, msg))?; + break; + } + }; + + yield copy_data_bytes; + } + Ok(None) => { + let msg = "client closed connection during COPY"; + let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?; + self.flush_cancellable(pgb).await.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?; + Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; + } + Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => { + Err(io_error)?; + } + Err(other) => { + Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?; + } + }; + } } } @@ -372,7 +406,7 @@ impl PageServerHandler { // switch client to COPYBOTH pgb.write_message_noflush(&BeMessage::CopyBothResponse)?; - pgb.flush().await?; + self.flush_cancellable(pgb).await?; let metrics = metrics::SmgrQueryTimePerTimeline::new(&tenant_id, &timeline_id); @@ -465,7 +499,7 @@ impl PageServerHandler { }); pgb.write_message_noflush(&BeMessage::CopyData(&response.serialize()))?; - pgb.flush().await?; + self.flush_cancellable(pgb).await?; } Ok(()) } @@ -508,9 +542,9 @@ impl PageServerHandler { // Import basebackup provided via CopyData info!("importing basebackup"); pgb.write_message_noflush(&BeMessage::CopyInResponse)?; - pgb.flush().await?; + self.flush_cancellable(pgb).await?; - let mut copyin_reader = pin!(StreamReader::new(copyin_stream(pgb))); + let mut copyin_reader = pin!(StreamReader::new(self.copyin_stream(pgb))); timeline .import_basebackup_from_tar( &mut copyin_reader, @@ -563,8 +597,8 @@ impl PageServerHandler { // Import wal provided via CopyData info!("importing wal"); pgb.write_message_noflush(&BeMessage::CopyInResponse)?; - pgb.flush().await?; - let mut copyin_reader = pin!(StreamReader::new(copyin_stream(pgb))); + self.flush_cancellable(pgb).await?; + let mut copyin_reader = pin!(StreamReader::new(self.copyin_stream(pgb))); import_wal_from_tar(&timeline, &mut copyin_reader, start_lsn, end_lsn, &ctx).await?; info!("wal import complete"); @@ -772,7 +806,7 @@ impl PageServerHandler { // switch client to COPYOUT pgb.write_message_noflush(&BeMessage::CopyOutResponse)?; - 
-        pgb.flush().await?;
+        self.flush_cancellable(pgb).await?;
 
         // Send a tarball of the latest layer on the timeline. Compress if not
         // fullbackup. TODO Compress in that case too (tests need to be updated)
@@ -824,7 +858,7 @@ impl PageServerHandler {
         }
 
         pgb.write_message_noflush(&BeMessage::CopyDone)?;
-        pgb.flush().await?;
+        self.flush_cancellable(pgb).await?;
 
         let basebackup_after = started
             .elapsed()
diff --git a/test_runner/regress/test_pageserver_restarts_under_workload.py b/test_runner/regress/test_pageserver_restarts_under_workload.py
index 65569f3bac..71058268a6 100644
--- a/test_runner/regress/test_pageserver_restarts_under_workload.py
+++ b/test_runner/regress/test_pageserver_restarts_under_workload.py
@@ -17,6 +17,8 @@ def test_pageserver_restarts_under_worload(neon_simple_env: NeonEnv, pg_bin: PgBin):
     n_restarts = 10
     scale = 10
 
+    env.pageserver.allowed_errors.append(".*query handler.*failed.*Shutting down")
+
    def run_pgbench(connstr: str):
         log.info(f"Start a pgbench workload on pg {connstr}")
         pg_bin.run_capture(["pgbench", "-i", f"-s{scale}", connstr])
diff --git a/test_runner/regress/test_tenant_detach.py b/test_runner/regress/test_tenant_detach.py
index a20523b1f3..519af1cbde 100644
--- a/test_runner/regress/test_tenant_detach.py
+++ b/test_runner/regress/test_tenant_detach.py
@@ -752,6 +752,9 @@ def test_ignore_while_attaching(
     env.pageserver.allowed_errors.append(
         f".*Tenant {tenant_id} will not become active\\. Current state: Stopping.*"
     )
+    # An endpoint is starting up concurrently with our detach; it can
+    # experience RPC failures due to shutdown.
+    env.pageserver.allowed_errors.append(".*query handler.*failed.*Shutting down")
 
     data_id = 1
     data_secret = "very secret secret"
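
The crux of the patch is that every flush which could block on a slow or dead client is raced against a cancellation signal. Below is a minimal, self-contained sketch of that same pattern; `flush_or_cancel` and the stuck-client simulation are illustrative stand-ins (assuming the `tokio`, `tokio-util`, and `anyhow` crates), not code from this diff:

```rust
use std::time::Duration;
use tokio_util::sync::CancellationToken;

/// Sketch of the flush_cancellable pattern: race an I/O future against a
/// CancellationToken so a client that never drains its socket cannot block
/// tenant or pageserver shutdown.
async fn flush_or_cancel<F>(cancel: &CancellationToken, flush: F) -> anyhow::Result<()>
where
    F: std::future::Future<Output = std::io::Result<()>>,
{
    tokio::select!(
        flush_r = flush => Ok(flush_r?),
        _ = cancel.cancelled() => Err(anyhow::anyhow!("Shutting down")),
    )
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let cancel = CancellationToken::new();

    // Simulate a stuck client: this "flush" never completes on its own.
    let stuck_flush = async {
        tokio::time::sleep(Duration::from_secs(3600)).await;
        Ok::<(), std::io::Error>(())
    };

    // Some other task (e.g. a tenant detach) requests shutdown shortly after.
    let c = cancel.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(100)).await;
        c.cancel();
    });

    // Returns the "Shutting down" error instead of hanging for an hour.
    assert!(flush_or_cancel(&cancel, stuck_flush).await.is_err());
    Ok(())
}
```

Returning an error from the cancelled branch rather than `Ok` matters: it makes the query handler unwind and surface the "Shutting down" message that the two regression tests above newly add to `allowed_errors`.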
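
The `postgres_backend` hunk also leans on `biased;`: by default `tokio::select!` polls its branches in random order, but `biased;` polls them strictly top to bottom, so a pending shutdown is always observed before the flush is polled again. A toy demonstration, with a `watch` channel standing in for `shutdown_watcher()` (hypothetical setup, not from this diff):

```rust
use tokio::sync::watch;

// With `biased;`, when both branches are ready the first one in source
// order wins deterministically, so shutdown is never starved by the flush.
#[tokio::main]
async fn main() {
    let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
    shutdown_tx.send(true).unwrap(); // shutdown already requested

    tokio::select!(
        biased;
        _ = shutdown_rx.changed() => {
            println!("shutdown observed before more work is polled");
        },
        _ = async { /* pretend this is self.flush() */ } => {
            println!("flush polled first (cannot happen here with biased)");
        }
    );
}
```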