diff --git a/Cargo.lock b/Cargo.lock index c5b64b235a..5b99e93e76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3367,6 +3367,7 @@ dependencies = [ "tempfile", "thiserror", "tokio", + "tokio-io-timeout", "tokio-postgres", "toml_edit", "tracing", diff --git a/Cargo.toml b/Cargo.toml index d563324c86..679605dc1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -101,6 +101,7 @@ test-context = "0.1" thiserror = "1.0" tls-listener = { version = "0.6", features = ["rustls", "hyper-h1"] } tokio = { version = "1.17", features = ["macros"] } +tokio-io-timeout = "1.2.0" tokio-postgres-rustls = "0.9.0" tokio-rustls = "0.23" tokio-stream = "0.1" diff --git a/libs/postgres_backend/src/lib.rs b/libs/postgres_backend/src/lib.rs index 60932a5950..f6bf7c6fc2 100644 --- a/libs/postgres_backend/src/lib.rs +++ b/libs/postgres_backend/src/lib.rs @@ -54,7 +54,7 @@ pub fn is_expected_io_error(e: &io::Error) -> bool { use io::ErrorKind::*; matches!( e.kind(), - ConnectionRefused | ConnectionAborted | ConnectionReset + ConnectionRefused | ConnectionAborted | ConnectionReset | TimedOut ) } @@ -320,9 +320,17 @@ impl PostgresBackend { if let ProtoState::Closed = self.state { Ok(None) } else { - let m = self.framed.read_message().await?; - trace!("read msg {:?}", m); - Ok(m) + match self.framed.read_message().await { + Ok(m) => { + trace!("read msg {:?}", m); + Ok(m) + } + Err(e) => { + // remember not to try to read anymore + self.state = ProtoState::Closed; + Err(e) + } + } } } @@ -493,7 +501,10 @@ impl PostgresBackend { MaybeWriteOnly::Full(framed) => { let (reader, writer) = framed.split(); self.framed = MaybeWriteOnly::WriteOnly(writer); - Ok(PostgresBackendReader(reader)) + Ok(PostgresBackendReader { + reader, + closed: false, + }) } MaybeWriteOnly::WriteOnly(_) => { anyhow::bail!("PostgresBackend is already split") @@ -510,8 +521,12 @@ impl PostgresBackend { anyhow::bail!("PostgresBackend is not split") } MaybeWriteOnly::WriteOnly(writer) => { - let joined = Framed::unsplit(reader.0, writer); 
+ let joined = Framed::unsplit(reader.reader, writer); self.framed = MaybeWriteOnly::Full(joined); + // if reader encountered connection error, do not attempt reading anymore + if reader.closed { + self.state = ProtoState::Closed; + } Ok(()) } MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"), @@ -797,15 +812,25 @@ impl PostgresBackend { } } -pub struct PostgresBackendReader(FramedReader>); +pub struct PostgresBackendReader { + reader: FramedReader>, + closed: bool, // true if a read returned an error, after which the connection is treated as closed +} impl PostgresBackendReader { /// Read full message or return None if connection is cleanly closed with no /// unprocessed data. pub async fn read_message(&mut self) -> Result, ConnectionError> { - let m = self.0.read_message().await?; - trace!("read msg {:?}", m); - Ok(m) + match self.reader.read_message().await { + Ok(m) => { + trace!("read msg {:?}", m); + Ok(m) + } + Err(e) => { + self.closed = true; + Err(e) + } + } } /// Get CopyData contents of the next message in COPY stream or error @@ -923,7 +948,7 @@ pub enum CopyStreamHandlerEnd { #[error("EOF on COPY stream")] EOF, /// The connection was lost - #[error(transparent)] + #[error("connection error: {0}")] Disconnected(#[from] ConnectionError), /// Some other error #[error(transparent)] diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 8b0733832a..00cd111da5 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -30,6 +30,7 @@ serde_with.workspace = true signal-hook.workspace = true thiserror.workspace = true tokio = { workspace = true, features = ["fs"] } +tokio-io-timeout.workspace = true tokio-postgres.workspace = true toml_edit.workspace = true tracing.workspace = true diff --git a/safekeeper/src/wal_service.rs b/safekeeper/src/wal_service.rs index 22f50c3428..fb0d77a9f2 100644 --- a/safekeeper/src/wal_service.rs +++ b/safekeeper/src/wal_service.rs @@ -4,8 +4,9 @@ //! 
use anyhow::{Context, Result}; use postgres_backend::QueryError; -use std::{future, thread}; +use std::{future, thread, time::Duration}; use tokio::net::TcpStream; +use tokio_io_timeout::TimeoutReader; use tracing::*; use utils::measured_stream::MeasuredStream; @@ -67,41 +68,52 @@ fn handle_socket( let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; - let local = tokio::task::LocalSet::new(); socket.set_nodelay(true)?; let peer_addr = socket.peer_addr()?; - let traffic_metrics = TrafficMetrics::new(); - if let Some(current_az) = conf.availability_zone.as_deref() { - traffic_metrics.set_sk_az(current_az); - } + // TimeoutReader wants an async runtime during creation. + runtime.block_on(async move { + // Set a timeout on reading from the socket. It prevents a hung connection + // if the client suddenly disappears. Note that TCP_KEEPALIVE is not enabled by + // default, and tokio doesn't provide the ability to set it out of the box. + let mut socket = TimeoutReader::new(socket); + let wal_service_timeout = Duration::from_secs(60 * 10); + socket.set_timeout(Some(wal_service_timeout)); + // pin! is here because TimeoutReader (due to storing sleep future inside) + // is not Unpin, and all pgbackend/framed/tokio dependencies require the stream + // to be Unpin. Which is reasonable, as indeed something like TimeoutReader + // shouldn't be moved. 
+ tokio::pin!(socket); - let socket = MeasuredStream::new( - socket, - |cnt| { - traffic_metrics.observe_read(cnt); - }, - |cnt| { - traffic_metrics.observe_write(cnt); - }, - ); + let traffic_metrics = TrafficMetrics::new(); + if let Some(current_az) = conf.availability_zone.as_deref() { + traffic_metrics.set_sk_az(current_az); + } - let auth_type = match conf.auth { - None => AuthType::Trust, - Some(_) => AuthType::NeonJWT, - }; - let mut conn_handler = - SafekeeperPostgresHandler::new(conf, conn_id, Some(traffic_metrics.clone())); - let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?; - // libpq protocol between safekeeper and walproposer / pageserver - // We don't use shutdown. - local.block_on( - &runtime, - pgbackend.run(&mut conn_handler, future::pending::<()>), - )?; + let socket = MeasuredStream::new( + socket, + |cnt| { + traffic_metrics.observe_read(cnt); + }, + |cnt| { + traffic_metrics.observe_write(cnt); + }, + ); - Ok(()) + let auth_type = match conf.auth { + None => AuthType::Trust, + Some(_) => AuthType::NeonJWT, + }; + let mut conn_handler = + SafekeeperPostgresHandler::new(conf, conn_id, Some(traffic_metrics.clone())); + let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?; + // libpq protocol between safekeeper and walproposer / pageserver + // We don't use shutdown. + pgbackend + .run(&mut conn_handler, future::pending::<()>) + .await + }) } /// Unique WAL service connection ids are logged in spans for observability.