diff --git a/Cargo.lock b/Cargo.lock index 246d481ef9..fbf018e1c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2502,6 +2502,7 @@ dependencies = [ "postgres-protocol", "rand", "serde", + "thiserror", "tokio", "tracing", "workspace_hack", diff --git a/libs/pq_proto/Cargo.toml b/libs/pq_proto/Cargo.toml index daa0b593be..b9c6a1eab0 100644 --- a/libs/pq_proto/Cargo.toml +++ b/libs/pq_proto/Cargo.toml @@ -13,5 +13,6 @@ rand = "0.8.3" serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.17", features = ["macros"] } tracing = "0.1" +thiserror = "1.0" workspace_hack = { version = "0.1", path = "../../workspace_hack" } diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index d31a2d51f2..c5e4dbd1f0 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -5,7 +5,7 @@ // Tools for calling certain async methods in sync contexts. pub mod sync; -use anyhow::{bail, ensure, Context, Result}; +use anyhow::{ensure, Context, Result}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use postgres_protocol::PG_EPOCH; use serde::{Deserialize, Serialize}; @@ -194,6 +194,35 @@ macro_rules! retry_read { }; } +/// An error occured during connection being open. +#[derive(thiserror::Error, Debug)] +pub enum ConnectionError { + /// IO error during writing to or reading from the connection socket. + #[error("Socket IO error: {0}")] + Socket(std::io::Error), + /// Invalid packet was received from client + #[error("Protocol error: {0}")] + Protocol(String), + /// Failed to parse a protocol mesage + #[error("Message parse error: {0}")] + MessageParse(anyhow::Error), +} + +impl From for ConnectionError { + fn from(e: anyhow::Error) -> Self { + Self::MessageParse(e) + } +} + +impl ConnectionError { + pub fn into_io_error(self) -> io::Error { + match self { + ConnectionError::Socket(io) => io, + other => io::Error::new(io::ErrorKind::Other, other.to_string()), + } + } +} + impl FeMessage { /// Read one message from the stream. /// This function returns `Ok(None)` in case of EOF. @@ -216,7 +245,9 @@ impl FeMessage { /// } /// ``` #[inline(never)] - pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result> { + pub fn read( + stream: &mut (impl io::Read + Unpin), + ) -> Result, ConnectionError> { Self::read_fut(&mut AsyncishRead(stream)).wait() } @@ -224,7 +255,7 @@ impl FeMessage { /// See documentation for `Self::read`. pub fn read_fut( stream: &mut Reader, - ) -> SyncFuture>> + '_> + ) -> SyncFuture, ConnectionError>> + '_> where Reader: tokio::io::AsyncRead + Unpin, { @@ -238,17 +269,21 @@ impl FeMessage { let tag = match retry_read!(stream.read_u8().await) { Ok(b) => b, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(e.into()), + Err(e) => return Err(ConnectionError::Socket(e)), }; // The message length includes itself, so it better be at least 4. - let len = retry_read!(stream.read_u32().await)? + let len = retry_read!(stream.read_u32().await) + .map_err(ConnectionError::Socket)? .checked_sub(4) - .context("invalid message length")?; + .ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?; let body = { let mut buffer = vec![0u8; len as usize]; - stream.read_exact(&mut buffer).await?; + stream + .read_exact(&mut buffer) + .await + .map_err(ConnectionError::Socket)?; Bytes::from(buffer) }; @@ -265,7 +300,11 @@ impl FeMessage { b'c' => Ok(Some(FeMessage::CopyDone)), b'f' => Ok(Some(FeMessage::CopyFail)), b'p' => Ok(Some(FeMessage::PasswordMessage(body))), - tag => bail!("unknown message tag: {},'{:?}'", tag, body), + tag => { + return Err(ConnectionError::Protocol(format!( + "unknown message tag: {tag},'{body:?}'" + ))) + } } }) } @@ -275,7 +314,9 @@ impl FeStartupPacket { /// Read startup message from the stream. // XXX: It's tempting yet undesirable to accept `stream` by value, // since such a change will cause user-supplied &mut references to be consumed - pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result> { + pub fn read( + stream: &mut (impl io::Read + Unpin), + ) -> Result, ConnectionError> { Self::read_fut(&mut AsyncishRead(stream)).wait() } @@ -284,7 +325,7 @@ impl FeStartupPacket { // since such a change will cause user-supplied &mut references to be consumed pub fn read_fut( stream: &mut Reader, - ) -> SyncFuture>> + '_> + ) -> SyncFuture, ConnectionError>> + '_> where Reader: tokio::io::AsyncRead + Unpin, { @@ -302,31 +343,41 @@ impl FeStartupPacket { let len = match retry_read!(stream.read_u32().await) { Ok(len) => len as usize, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(e.into()), + Err(e) => return Err(ConnectionError::Socket(e)), }; #[allow(clippy::manual_range_contains)] if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { - bail!("invalid message length"); + return Err(ConnectionError::Protocol(format!( + "invalid message length {len}" + ))); } - let request_code = retry_read!(stream.read_u32().await)?; + let request_code = + retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?; // the rest of startup packet are params let params_len = len - 8; let mut params_bytes = vec![0u8; params_len]; - stream.read_exact(params_bytes.as_mut()).await?; + stream + .read_exact(params_bytes.as_mut()) + .await + .map_err(ConnectionError::Socket)?; // Parse params depending on request code let req_hi = request_code >> 16; let req_lo = request_code & ((1 << 16) - 1); let message = match (req_hi, req_lo) { (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { - ensure!(params_len == 8, "expected 8 bytes for CancelRequest params"); + if params_len != 8 { + return Err(ConnectionError::Protocol( + "expected 8 bytes for CancelRequest params".to_string(), + )); + } let mut cursor = Cursor::new(params_bytes); FeStartupPacket::CancelRequest(CancelKeyData { - backend_pid: cursor.read_i32().await?, - cancel_key: cursor.read_i32().await?, + backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?, + cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?, }) } (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { @@ -338,7 +389,9 @@ impl FeStartupPacket { FeStartupPacket::GssEncRequest } (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { - bail!("Unrecognized request code {}", unrecognized_code) + return Err(ConnectionError::Protocol(format!( + "Unrecognized request code {unrecognized_code}" + ))); } // TODO bail if protocol major_version is not 3? (major_version, minor_version) => { @@ -346,15 +399,21 @@ impl FeStartupPacket { // See `postgres: ProcessStartupPacket, build_startup_packet`. let mut tokens = str::from_utf8(¶ms_bytes) .context("StartupMessage params: invalid utf-8")? - .strip_suffix('\0') // drop packet's own null terminator - .context("StartupMessage params: missing null terminator")? + .strip_suffix('\0') // drop packet's own null + .ok_or_else(|| { + ConnectionError::Protocol( + "StartupMessage params: missing null terminator".to_string(), + ) + })? .split_terminator('\0'); let mut params = HashMap::new(); while let Some(name) = tokens.next() { - let value = tokens - .next() - .context("StartupMessage params: key without value")?; + let value = tokens.next().ok_or_else(|| { + ConnectionError::Protocol( + "StartupMessage params: key without value".to_string(), + ) + })?; params.insert(name.to_owned(), value.to_owned()); } @@ -458,7 +517,7 @@ pub enum BeMessage<'a> { CloseComplete, // None means column is NULL DataRow(&'a [Option<&'a [u8]>]), - ErrorResponse(&'a str), + ErrorResponse(&'a str, Option<&'a [u8; 5]>), /// Single byte - used in response to SSLRequest/GSSENCRequest. EncryptionResponse(bool), NoData, @@ -606,7 +665,7 @@ fn write_body(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R { } /// Safe write of s into buf as cstring (String in the protocol). -fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), io::Error> { +fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> { let bytes = s.as_ref(); if bytes.contains(&0) { return Err(io::Error::new( @@ -626,7 +685,7 @@ fn read_cstr(buf: &mut Bytes) -> anyhow::Result { Ok(result) } -const SQLSTATE_INTERNAL_ERROR: &str = "XX000\0"; +pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000"; impl<'a> BeMessage<'a> { /// Write message to the given buf. @@ -767,10 +826,7 @@ impl<'a> BeMessage<'a> { // First byte of each field represents type of this field. Set just enough fields // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error // message text. - BeMessage::ErrorResponse(error_msg) => { - // For all the errors set Severity to Error and error code to - // 'internal error'. - + BeMessage::ErrorResponse(error_msg, pg_error_code) => { // 'E' signalizes ErrorResponse messages buf.put_u8(b'E'); write_body(buf, |buf| { @@ -778,7 +834,9 @@ impl<'a> BeMessage<'a> { buf.put_slice(b"ERROR\0"); buf.put_u8(b'C'); // SQLSTATE error code - buf.put_slice(SQLSTATE_INTERNAL_ERROR.as_bytes()); + buf.put_slice(&terminate_code( + pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR), + )); buf.put_u8(b'M'); // the message write_cstr(error_msg, buf)?; @@ -801,7 +859,7 @@ impl<'a> BeMessage<'a> { buf.put_slice(b"NOTICE\0"); buf.put_u8(b'C'); // SQLSTATE error code - buf.put_slice(SQLSTATE_INTERNAL_ERROR.as_bytes()); + buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR)); buf.put_u8(b'M'); // the message write_cstr(error_msg.as_bytes(), buf)?; @@ -1089,3 +1147,12 @@ mod tests { let _ = FeStartupPacket::read_fut(stream).await; } } + +fn terminate_code(code: &[u8; 5]) -> [u8; 6] { + let mut terminated = [0; 6]; + for (i, &elem) in code.iter().enumerate() { + terminated[i] = elem; + } + + terminated +} diff --git a/libs/utils/src/postgres_backend.rs b/libs/utils/src/postgres_backend.rs index bac6f861c3..f3e3835bda 100644 --- a/libs/utils/src/postgres_backend.rs +++ b/libs/utils/src/postgres_backend.rs @@ -3,8 +3,9 @@ //! implementation determining how to process the queries. Currently its API //! is rather narrow, but we can extend it once required. +use crate::postgres_backend_async::{log_query_error, short_error, QueryError}; use crate::sock_split::{BidiStream, ReadStream, WriteStream}; -use anyhow::{bail, ensure, Context, Result}; +use anyhow::Context; use bytes::{Bytes, BytesMut}; use pq_proto::{BeMessage, FeMessage, FeStartupPacket}; use serde::{Deserialize, Serialize}; @@ -21,20 +22,32 @@ pub trait Handler { /// postgres_backend will issue ReadyForQuery after calling this (this /// might be not what we want after CopyData streaming, but currently we don't /// care). - fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: &str) -> Result<()>; + fn process_query( + &mut self, + pgb: &mut PostgresBackend, + query_string: &str, + ) -> Result<(), QueryError>; /// Called on startup packet receival, allows to process params. /// /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow /// to override whole init logic in implementations. - fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupPacket) -> Result<()> { + fn startup( + &mut self, + _pgb: &mut PostgresBackend, + _sm: &FeStartupPacket, + ) -> Result<(), QueryError> { Ok(()) } /// Check auth jwt - fn check_auth_jwt(&mut self, _pgb: &mut PostgresBackend, _jwt_response: &[u8]) -> Result<()> { - bail!("JWT auth failed") + fn check_auth_jwt( + &mut self, + _pgb: &mut PostgresBackend, + _jwt_response: &[u8], + ) -> Result<(), QueryError> { + Err(QueryError::Other(anyhow::anyhow!("JWT auth failed"))) } fn is_shutdown_requested(&self) -> bool { @@ -66,7 +79,7 @@ impl FromStr for AuthType { match s { "Trust" => Ok(Self::Trust), "NeonJWT" => Ok(Self::NeonJWT), - _ => bail!("invalid value \"{s}\" for auth type"), + _ => anyhow::bail!("invalid value \"{s}\" for auth type"), } } } @@ -154,7 +167,7 @@ pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool { } // Cast a byte slice to a string slice, dropping null terminator if there's one. -fn cstr_to_str(bytes: &[u8]) -> Result<&str> { +fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> { let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes); std::str::from_utf8(without_null).map_err(|e| e.into()) } @@ -188,10 +201,10 @@ impl PostgresBackend { } /// Get direct reference (into the Option) to the read stream. - fn get_stream_in(&mut self) -> Result<&mut BidiStream> { + fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> { match &mut self.stream { Some(Stream::Bidirectional(stream)) => Ok(stream), - _ => bail!("reader taken"), + _ => anyhow::bail!("reader taken"), } } @@ -215,7 +228,7 @@ impl PostgresBackend { } /// Read full message or return None if connection is closed. - pub fn read_message(&mut self) -> Result> { + pub fn read_message(&mut self) -> Result, QueryError> { let (state, stream) = (self.state, self.get_stream_in()?); use ProtoState::*; @@ -223,6 +236,7 @@ impl PostgresBackend { Initialization | Encrypted => FeStartupPacket::read(stream), Authentication | Established => FeMessage::read(stream), } + .map_err(QueryError::from) } /// Write message into internal output buffer. @@ -246,7 +260,7 @@ impl PostgresBackend { } // Wrapper for run_message_loop() that shuts down socket when we are done - pub fn run(mut self, handler: &mut impl Handler) -> Result<()> { + pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> { let ret = self.run_message_loop(handler); if let Some(stream) = self.stream.as_mut() { let _ = stream.shutdown(Shutdown::Both); @@ -254,7 +268,7 @@ impl PostgresBackend { ret } - fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<()> { + fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> { trace!("postgres backend to {:?} started", self.peer_addr); let mut unnamed_query_string = Bytes::new(); @@ -263,7 +277,7 @@ impl PostgresBackend { match self.read_message() { Ok(message) => { if let Some(msg) = message { - trace!("got message {:?}", msg); + trace!("got message {msg:?}"); match self.process_message(handler, msg, &mut unnamed_query_string)? { ProcessMsgResult::Continue => continue, @@ -274,10 +288,12 @@ impl PostgresBackend { } } Err(e) => { - // If it is a timeout error, continue the loop - if !is_socket_read_timed_out(&e) { - return Err(e); + if let QueryError::Other(e) = &e { + if is_socket_read_timed_out(e) { + continue; + } } + return Err(e); } } } @@ -295,7 +311,7 @@ impl PostgresBackend { } stream => { self.stream = stream; - bail!("can't start TLs without bidi stream"); + anyhow::bail!("can't start TLs without bidi stream"); } } } @@ -305,17 +321,16 @@ impl PostgresBackend { handler: &mut impl Handler, msg: FeMessage, unnamed_query_string: &mut Bytes, - ) -> Result { + ) -> Result { // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth // TODO: change that to proper top-level match of protocol state with separate message handling for each state - if self.state < ProtoState::Established { - ensure!( - matches!( - msg, - FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_) - ), - "protocol violation" - ); + if self.state < ProtoState::Established + && !matches!( + msg, + FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_) + ) + { + return Err(QueryError::Other(anyhow::anyhow!("protocol violation"))); } let have_tls = self.tls_config.is_some(); @@ -339,8 +354,13 @@ impl PostgresBackend { } FeStartupPacket::StartupMessage { .. } => { if have_tls && !matches!(self.state, ProtoState::Encrypted) { - self.write_message(&BeMessage::ErrorResponse("must connect with TLS"))?; - bail!("client did not connect with TLS"); + self.write_message(&BeMessage::ErrorResponse( + "must connect with TLS", + None, + ))?; + return Err(QueryError::Other(anyhow::anyhow!( + "client did not connect with TLS" + ))); } // NB: startup() may change self.auth_type -- we are using that in proxy code @@ -379,8 +399,11 @@ impl PostgresBackend { let (_, jwt_response) = m.split_last().context("protocol violation")?; if let Err(e) = handler.check_auth_jwt(self, jwt_response) { - self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?; - bail!("auth failed: {}", e); + self.write_message(&BeMessage::ErrorResponse( + &e.to_string(), + Some(e.pg_error_code()), + ))?; + return Err(e); } } } @@ -394,33 +417,14 @@ impl PostgresBackend { // remove null terminator let query_string = cstr_to_str(&body)?; - trace!("got query {:?}", query_string); - // xxx distinguish fatal and recoverable errors? + trace!("got query {query_string:?}"); if let Err(e) = handler.process_query(self, query_string) { - // ":?" uses the alternate formatting style, which makes anyhow display the - // full cause of the error, not just the top-level context + its trace. - // We don't want to send that in the ErrorResponse though, - // because it's not relevant to the compute node logs. - // - // We also don't want to log full stacktrace when the error is primitive, - // such as usual connection closed. - let short_error = format!("{:#}", e); - let root_cause = e.root_cause().to_string(); - if root_cause.contains("connection closed unexpectedly") - || root_cause.contains("Broken pipe (os error 32)") - { - error!( - "query handler for '{}' failed: {}", - query_string, short_error - ); - } else { - error!("query handler for '{}' failed: {:?}", query_string, e); - } - self.write_message_noflush(&BeMessage::ErrorResponse(&short_error))?; - // TODO: untangle convoluted control flow - if e.to_string().contains("failed to run") { - return Ok(ProcessMsgResult::Break); - } + log_query_error(query_string, &e); + let short_error = short_error(&e); + self.write_message_noflush(&BeMessage::ErrorResponse( + &short_error, + Some(e.pg_error_code()), + ))?; } self.write_message(&BeMessage::ReadyForQuery)?; } @@ -445,11 +449,13 @@ impl PostgresBackend { FeMessage::Execute(_) => { let query_string = cstr_to_str(unnamed_query_string)?; - trace!("got execute {:?}", query_string); - // xxx distinguish fatal and recoverable errors? + trace!("got execute {query_string:?}"); if let Err(e) = handler.process_query(self, query_string) { - error!("query handler for '{}' failed: {:?}", query_string, e); - self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?; + log_query_error(query_string, &e); + self.write_message(&BeMessage::ErrorResponse( + &e.to_string(), + Some(e.pg_error_code()), + ))?; } // NOTE there is no ReadyForQuery message. This handler is used // for basebackup and it uses CopyOut which doesn't require @@ -468,7 +474,9 @@ impl PostgresBackend { // We prefer explicit pattern matching to wildcards, because // this helps us spot the places where new variants are missing FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => { - bail!("unexpected message type: {:?}", msg); + return Err(QueryError::Other(anyhow::anyhow!( + "unexpected message type: {msg:?}" + ))); } } diff --git a/libs/utils/src/postgres_backend_async.rs b/libs/utils/src/postgres_backend_async.rs index de547c3242..a4f523da04 100644 --- a/libs/utils/src/postgres_backend_async.rs +++ b/libs/utils/src/postgres_backend_async.rs @@ -4,39 +4,84 @@ //! is rather narrow, but we can extend it once required. use crate::postgres_backend::AuthType; -use anyhow::{bail, Context, Result}; +use anyhow::Context; use bytes::{Buf, Bytes, BytesMut}; -use pq_proto::{BeMessage, FeMessage, FeStartupPacket}; +use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR}; use std::future::Future; +use std::io; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::task::Poll; -use tracing::{debug, error, trace}; +use tracing::{debug, error, info, trace}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; use tokio_rustls::TlsAcceptor; +pub fn is_expected_io_error(e: &io::Error) -> bool { + use io::ErrorKind::*; + matches!(e.kind(), ConnectionRefused | ConnectionAborted) +} + +/// An error, occurred during query processing: +/// either during the connection ([`ConnectionError`]) or before/after it. +#[derive(thiserror::Error, Debug)] +pub enum QueryError { + /// The connection was lost while processing the query. + #[error(transparent)] + Disconnected(#[from] ConnectionError), + /// Some other error + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl From for QueryError { + fn from(e: io::Error) -> Self { + Self::Disconnected(ConnectionError::Socket(e)) + } +} + +impl QueryError { + pub fn pg_error_code(&self) -> &'static [u8; 5] { + match self { + Self::Disconnected(_) => b"08006", // connection failure + Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error + } + } +} + #[async_trait::async_trait] pub trait Handler { /// Handle single query. /// postgres_backend will issue ReadyForQuery after calling this (this /// might be not what we want after CopyData streaming, but currently we don't /// care). - async fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: &str) -> Result<()>; + async fn process_query( + &mut self, + pgb: &mut PostgresBackend, + query_string: &str, + ) -> Result<(), QueryError>; /// Called on startup packet receival, allows to process params. /// /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow /// to override whole init logic in implementations. - fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupPacket) -> Result<()> { + fn startup( + &mut self, + _pgb: &mut PostgresBackend, + _sm: &FeStartupPacket, + ) -> Result<(), QueryError> { Ok(()) } /// Check auth jwt - fn check_auth_jwt(&mut self, _pgb: &mut PostgresBackend, _jwt_response: &[u8]) -> Result<()> { - bail!("JWT auth failed") + fn check_auth_jwt( + &mut self, + _pgb: &mut PostgresBackend, + _jwt_response: &[u8], + ) -> Result<(), QueryError> { + Err(QueryError::Other(anyhow::anyhow!("JWT auth failed"))) } } @@ -70,17 +115,14 @@ impl AsyncWrite for Stream { self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> Poll> { + ) -> Poll> { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf), Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf), Self::Broken => unreachable!(), } } - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx), Self::Tls(stream) => Pin::new(stream).poll_flush(cx), @@ -90,7 +132,7 @@ impl AsyncWrite for Stream { fn poll_shutdown( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> Poll> { + ) -> Poll> { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx), Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx), @@ -103,7 +145,7 @@ impl AsyncRead for Stream { self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { + ) -> Poll> { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf), Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf), @@ -139,7 +181,7 @@ pub fn query_from_cstring(query_string: Bytes) -> Vec { } // Cast a byte slice to a string slice, dropping null terminator if there's one. -fn cstr_to_str(bytes: &[u8]) -> Result<&str> { +fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> { let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes); std::str::from_utf8(without_null).map_err(|e| e.into()) } @@ -149,7 +191,7 @@ impl PostgresBackend { socket: tokio::net::TcpStream, auth_type: AuthType, tls_config: Option>, - ) -> std::io::Result { + ) -> io::Result { let peer_addr = socket.peer_addr()?; Ok(Self { @@ -167,17 +209,18 @@ impl PostgresBackend { } /// Read full message or return None if connection is closed. - pub async fn read_message(&mut self) -> Result> { + pub async fn read_message(&mut self) -> Result, QueryError> { use ProtoState::*; match self.state { Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await, Authentication | Established => FeMessage::read_fut(&mut self.stream).await, Closed => Ok(None), } + .map_err(QueryError::from) } /// Flush output buffer into the socket. - pub async fn flush(&mut self) -> std::io::Result<()> { + pub async fn flush(&mut self) -> io::Result<()> { while self.buf_out.has_remaining() { let bytes_written = self.stream.write(self.buf_out.chunk()).await?; self.buf_out.advance(bytes_written); @@ -187,7 +230,7 @@ impl PostgresBackend { } /// Write message into internal output buffer. - pub fn write_message(&mut self, message: &BeMessage<'_>) -> Result<&mut Self, std::io::Error> { + pub fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { BeMessage::write(&mut self.buf_out, message)?; Ok(self) } @@ -223,7 +266,11 @@ impl PostgresBackend { } // Wrapper for run_message_loop() that shuts down socket when we are done - pub async fn run(mut self, handler: &mut impl Handler, shutdown_watcher: F) -> Result<()> + pub async fn run( + mut self, + handler: &mut impl Handler, + shutdown_watcher: F, + ) -> Result<(), QueryError> where F: Fn() -> S, S: Future, @@ -237,7 +284,7 @@ impl PostgresBackend { &mut self, handler: &mut impl Handler, shutdown_watcher: F, - ) -> Result<()> + ) -> Result<(), QueryError> where F: Fn() -> S, S: Future, @@ -273,7 +320,7 @@ impl PostgresBackend { return Ok(()); } } - Ok::<(), anyhow::Error>(()) + Ok::<(), QueryError>(()) } => { // Handshake complete. result?; @@ -318,14 +365,14 @@ impl PostgresBackend { self.stream = Stream::Tls(Box::new(tls_stream)); return Ok(()); }; - bail!("TLS already started"); + anyhow::bail!("TLS already started"); } async fn process_handshake_message( &mut self, handler: &mut impl Handler, msg: FeMessage, - ) -> Result { + ) -> Result { assert!(self.state < ProtoState::Established); let have_tls = self.tls_config.is_some(); match msg { @@ -348,8 +395,13 @@ impl PostgresBackend { } FeStartupPacket::StartupMessage { .. } => { if have_tls && !matches!(self.state, ProtoState::Encrypted) { - self.write_message(&BeMessage::ErrorResponse("must connect with TLS"))?; - bail!("client did not connect with TLS"); + self.write_message(&BeMessage::ErrorResponse( + "must connect with TLS", + None, + ))?; + return Err(QueryError::Other(anyhow::anyhow!( + "client did not connect with TLS" + ))); } // NB: startup() may change self.auth_type -- we are using that in proxy code @@ -389,8 +441,11 @@ impl PostgresBackend { let (_, jwt_response) = m.split_last().context("protocol violation")?; if let Err(e) = handler.check_auth_jwt(self, jwt_response) { - self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?; - bail!("auth failed: {}", e); + self.write_message(&BeMessage::ErrorResponse( + &e.to_string(), + Some(e.pg_error_code()), + ))?; + return Err(e); } } } @@ -413,33 +468,28 @@ impl PostgresBackend { handler: &mut impl Handler, msg: FeMessage, unnamed_query_string: &mut Bytes, - ) -> Result { + ) -> Result { // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth // TODO: change that to proper top-level match of protocol state with separate message handling for each state assert!(self.state == ProtoState::Established); match msg { FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => { - bail!("protocol violation"); + return Err(QueryError::Other(anyhow::anyhow!("protocol violation"))); } FeMessage::Query(body) => { // remove null terminator let query_string = cstr_to_str(&body)?; - trace!("got query {:?}", query_string); - // xxx distinguish fatal and recoverable errors? + trace!("got query {query_string:?}"); if let Err(e) = handler.process_query(self, query_string).await { - // ":?" uses the alternate formatting style, which makes anyhow display the - // full cause of the error, not just the top-level context + its trace. - // We don't want to send that in the ErrorResponse though, - // because it's not relevant to the compute node logs. - error!("query handler for '{}' failed: {:?}", query_string, e); - self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?; - // TODO: untangle convoluted control flow - if e.to_string().contains("failed to run") { - return Ok(ProcessMsgResult::Break); - } + log_query_error(query_string, &e); + let short_error = short_error(&e); + self.write_message(&BeMessage::ErrorResponse( + &short_error, + Some(e.pg_error_code()), + ))?; } self.write_message(&BeMessage::ReadyForQuery)?; } @@ -464,11 +514,13 @@ impl PostgresBackend { FeMessage::Execute(_) => { let query_string = cstr_to_str(unnamed_query_string)?; - trace!("got execute {:?}", query_string); - // xxx distinguish fatal and recoverable errors? + trace!("got execute {query_string:?}"); if let Err(e) = handler.process_query(self, query_string).await { - error!("query handler for '{}' failed: {:?}", query_string, e); - self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?; + log_query_error(query_string, &e); + self.write_message(&BeMessage::ErrorResponse( + &e.to_string(), + Some(e.pg_error_code()), + ))?; } // NOTE there is no ReadyForQuery message. This handler is used // for basebackup and it uses CopyOut which doesn't require @@ -487,7 +539,10 @@ impl PostgresBackend { // We prefer explicit pattern matching to wildcards, because // this helps us spot the places where new variants are missing FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => { - bail!("unexpected message type: {:?}", msg); + return Err(QueryError::Other(anyhow::anyhow!( + "unexpected message type: {:?}", + msg + ))); } } @@ -555,3 +610,28 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> { this.pgb.poll_flush(cx) } } + +pub fn short_error(e: &QueryError) -> String { + match e { + QueryError::Disconnected(connection_error) => connection_error.to_string(), + QueryError::Other(e) => format!("{e:#}"), + } +} + +pub(super) fn log_query_error(query: &str, e: &QueryError) { + match e { + QueryError::Disconnected(ConnectionError::Socket(io_error)) => { + if is_expected_io_error(io_error) { + info!("query handler for '{query}' failed with expected io error: {io_error}"); + } else { + error!("query handler for '{query}' failed with io error: {io_error}"); + } + } + QueryError::Disconnected(other_connection_error) => { + error!("query handler for '{query}' failed with connection error: {other_connection_error:?}") + } + QueryError::Other(e) => { + error!("query handler for '{query}' failed: {e:?}"); + } + } +} diff --git a/libs/utils/tests/ssl_test.rs b/libs/utils/tests/ssl_test.rs index 248400c2c1..fae707f049 100644 --- a/libs/utils/tests/ssl_test.rs +++ b/libs/utils/tests/ssl_test.rs @@ -9,7 +9,10 @@ use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use once_cell::sync::Lazy; -use utils::postgres_backend::{AuthType, Handler, PostgresBackend}; +use utils::{ + postgres_backend::{AuthType, Handler, PostgresBackend}, + postgres_backend_async::QueryError, +}; fn make_tcp_pair() -> (TcpStream, TcpStream) { let listener = TcpListener::bind("127.0.0.1:0").unwrap(); @@ -105,7 +108,7 @@ fn ssl() { &mut self, _pgb: &mut PostgresBackend, query_string: &str, - ) -> anyhow::Result<()> { + ) -> Result<(), QueryError> { self.got_query = query_string == QUERY; Ok(()) } @@ -152,7 +155,7 @@ fn no_ssl() { &mut self, _pgb: &mut PostgresBackend, _query_string: &str, - ) -> anyhow::Result<()> { + ) -> Result<(), QueryError> { panic!() } } @@ -212,7 +215,7 @@ fn server_forces_ssl() { &mut self, _pgb: &mut PostgresBackend, _query_string: &str, - ) -> anyhow::Result<()> { + ) -> Result<(), QueryError> { panic!() } } diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index f123168211..4087a8f90c 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -9,7 +9,7 @@ // custom protocol. // -use anyhow::{bail, ensure, Context, Result}; +use anyhow::Context; use bytes::Buf; use bytes::Bytes; use futures::{Stream, StreamExt}; @@ -19,6 +19,8 @@ use pageserver_api::models::{ PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamNblocksRequest, PagestreamNblocksResponse, }; +use pq_proto::ConnectionError; +use pq_proto::FeStartupPacket; use pq_proto::{BeMessage, FeMessage, RowDescriptor}; use std::io; use std::net::TcpListener; @@ -28,6 +30,7 @@ use std::sync::Arc; use std::time::Duration; use tracing::*; use utils::id::ConnectionId; +use utils::postgres_backend_async::QueryError; use utils::{ auth::{Claims, JwtAuth, Scope}, id::{TenantId, TimelineId}, @@ -60,8 +63,8 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { // We were requested to shut down. let msg = format!("pageserver is shutting down"); - let _ = pgb.write_message(&BeMessage::ErrorResponse(&msg)); - Err(anyhow::anyhow!(msg)) + let _ = pgb.write_message(&BeMessage::ErrorResponse(&msg, None)); + Err(QueryError::Other(anyhow::anyhow!(msg))) } msg = pgb.read_message() => { msg } @@ -74,14 +77,15 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { break }, FeMessage::Sync => continue, FeMessage::Terminate => { - let msg = format!("client terminated connection with Terminate message during COPY"); - pgb.write_message(&BeMessage::ErrorResponse(&msg))?; + let msg = "client terminated connection with Terminate message during COPY"; + let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; break; } m => { - let msg = format!("unexpected message {:?}", m); - pgb.write_message(&BeMessage::ErrorResponse(&msg))?; + let msg = format!("unexpected message {m:?}"); + pgb.write_message(&BeMessage::ErrorResponse(&msg, None))?; Err(io::Error::new(io::ErrorKind::Other, msg))?; break; } @@ -91,12 +95,16 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { let msg = "client closed connection during COPY"; - pgb.write_message(&BeMessage::ErrorResponse(msg))?; + let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; pgb.flush().await?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; } - Err(e) => { - Err(io::Error::new(io::ErrorKind::Other, e))?; + Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => { + Err(io_error)?; + } + Err(other) => { + Err(io::Error::new(io::ErrorKind::Other, other))?; } }; } @@ -194,23 +202,19 @@ async fn page_service_conn_main( // we've been requested to shut down Ok(()) } - Err(err) => { - let root_cause_io_err_kind = err - .root_cause() - .downcast_ref::() - .map(|e| e.kind()); - + Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => { // `ConnectionReset` error happens when the Postgres client closes the connection. // As this disconnection happens quite often and is expected, // we decided to downgrade the logging level to `INFO`. // See: https://github.com/neondatabase/neon/issues/1683. - if root_cause_io_err_kind == Some(io::ErrorKind::ConnectionReset) { + if io_error.kind() == io::ErrorKind::ConnectionReset { info!("Postgres client disconnected"); Ok(()) } else { - Err(err) + Err(io_error).context("Postgres connection error") } } + other => other.context("Postgres query error"), } } @@ -312,7 +316,7 @@ impl PageServerHandler { Some(FeMessage::CopyData(bytes)) => bytes, Some(FeMessage::Terminate) => break, Some(m) => { - bail!("unexpected message: {m:?} during COPY"); + anyhow::bail!("unexpected message: {m:?} during COPY"); } None => break, // client disconnected }; @@ -369,7 +373,7 @@ impl PageServerHandler { base_lsn: Lsn, _end_lsn: Lsn, pg_version: u32, - ) -> anyhow::Result<()> { + ) -> Result<(), QueryError> { task_mgr::associate_with(Some(tenant_id), Some(timeline_id)); // Create empty timeline info!("creating new timeline"); @@ -423,11 +427,16 @@ impl PageServerHandler { timeline_id: TimelineId, start_lsn: Lsn, end_lsn: Lsn, - ) -> anyhow::Result<()> { + ) -> Result<(), QueryError> { task_mgr::associate_with(Some(tenant_id), Some(timeline_id)); let timeline = get_active_timeline_with_timeout(tenant_id, timeline_id).await?; - ensure!(timeline.get_last_record_lsn() == start_lsn); + let last_record_lsn = timeline.get_last_record_lsn(); + if last_record_lsn != start_lsn { + return Err(QueryError::Other( + anyhow::anyhow!("Cannot import WAL from Lsn {start_lsn} because timeline does not start from the same lsn: {last_record_lsn}")) + ); + } // TODO leave clean state on error. For now you can use detach to clean // up broken state from a failed import. @@ -451,7 +460,11 @@ impl PageServerHandler { } // TODO Does it make sense to overshoot? - ensure!(timeline.get_last_record_lsn() >= end_lsn); + if timeline.get_last_record_lsn() < end_lsn { + return Err(QueryError::Other( + anyhow::anyhow!("Cannot import WAL from Lsn {start_lsn} because timeline does not start from the same lsn: {last_record_lsn}")) + ); + } // Flush data to disk, then upload to s3. No need for a forced checkpoint. // We only want to persist the data, and it doesn't matter if it's in the @@ -480,7 +493,7 @@ impl PageServerHandler { mut lsn: Lsn, latest: bool, latest_gc_cutoff_lsn: &RcuReadGuard, - ) -> Result { + ) -> anyhow::Result { if latest { // Latest page version was requested. If LSN is given, it is a hint // to the page server that there have been no modifications to the @@ -511,11 +524,11 @@ impl PageServerHandler { } } else { if lsn == Lsn(0) { - bail!("invalid LSN(0) in request"); + anyhow::bail!("invalid LSN(0) in request"); } timeline.wait_lsn(lsn).await?; } - ensure!( + anyhow::ensure!( lsn >= **latest_gc_cutoff_lsn, "tried to request a page version that was garbage collected. requested at {} gc cutoff {}", lsn, **latest_gc_cutoff_lsn @@ -528,7 +541,7 @@ impl PageServerHandler { &self, timeline: &Timeline, req: &PagestreamExistsRequest, - ) -> Result { + ) -> anyhow::Result { let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn) .await?; @@ -548,7 +561,7 @@ impl PageServerHandler { &self, timeline: &Timeline, req: &PagestreamNblocksRequest, - ) -> Result { + ) -> anyhow::Result { let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn) .await?; @@ -568,7 +581,7 @@ impl PageServerHandler { &self, timeline: &Timeline, req: &PagestreamDbSizeRequest, - ) -> Result { + ) -> anyhow::Result { let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn) .await?; @@ -589,7 +602,7 @@ impl PageServerHandler { &self, timeline: &Timeline, req: &PagestreamGetPageRequest, - ) -> Result { + ) -> anyhow::Result { let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn) .await?; @@ -656,7 +669,7 @@ impl PageServerHandler { // when accessing management api supply None as an argument // when using to authorize tenant pass corresponding tenant id - fn check_permission(&self, tenant_id: Option) -> Result<()> { + fn check_permission(&self, tenant_id: Option) -> anyhow::Result<()> { if self.auth.is_none() { // auth is set to Trust, nothing to check so just return ok return Ok(()); @@ -678,20 +691,19 @@ impl postgres_backend_async::Handler for PageServerHandler { &mut self, _pgb: &mut PostgresBackend, jwt_response: &[u8], - ) -> anyhow::Result<()> { + ) -> Result<(), QueryError> { // this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT // which requires auth to be present let data = self .auth .as_ref() .unwrap() - .decode(str::from_utf8(jwt_response)?)?; + .decode(str::from_utf8(jwt_response).context("jwt response is not UTF-8")?)?; - if matches!(data.claims.scope, Scope::Tenant) { - ensure!( - data.claims.tenant_id.is_some(), + if matches!(data.claims.scope, Scope::Tenant) && data.claims.tenant_id.is_none() { + return Err(QueryError::Other(anyhow::anyhow!( "jwt token scope is Tenant, but tenant id is missing" - ) + ))); } info!( @@ -703,22 +715,33 @@ impl postgres_backend_async::Handler for PageServerHandler { Ok(()) } + fn startup( + &mut self, + _pgb: &mut PostgresBackend, + _sm: &FeStartupPacket, + ) -> Result<(), QueryError> { + Ok(()) + } + async fn process_query( &mut self, pgb: &mut PostgresBackend, query_string: &str, - ) -> anyhow::Result<()> { - debug!("process query {:?}", query_string); + ) -> Result<(), QueryError> { + debug!("process query {query_string:?}"); if query_string.starts_with("pagestream ") { let (_, params_raw) = query_string.split_at("pagestream ".len()); let params = params_raw.split(' ').collect::>(); - ensure!( - params.len() == 2, - "invalid param number for pagestream command" - ); - let tenant_id = TenantId::from_str(params[0])?; - let timeline_id = TimelineId::from_str(params[1])?; + if params.len() != 2 { + return Err(QueryError::Other(anyhow::anyhow!( + "invalid param number for pagestream command" + ))); + } + let tenant_id = TenantId::from_str(params[0]) + .with_context(|| format!("Failed to parse tenant id from {}", params[0]))?; + let timeline_id = TimelineId::from_str(params[1]) + .with_context(|| format!("Failed to parse timeline id from {}", params[1]))?; self.check_permission(Some(tenant_id))?; @@ -728,18 +751,24 @@ impl postgres_backend_async::Handler for PageServerHandler { let (_, params_raw) = query_string.split_at("basebackup ".len()); let params = params_raw.split_whitespace().collect::>(); - ensure!( - params.len() >= 2, - "invalid param number for basebackup command" - ); + if params.len() < 2 { + return Err(QueryError::Other(anyhow::anyhow!( + "invalid param number for basebackup command" + ))); + } - let tenant_id = TenantId::from_str(params[0])?; - let timeline_id = TimelineId::from_str(params[1])?; + let tenant_id = TenantId::from_str(params[0]) + .with_context(|| format!("Failed to parse tenant id from {}", params[0]))?; + let timeline_id = TimelineId::from_str(params[1]) + .with_context(|| format!("Failed to parse timeline id from {}", params[1]))?; self.check_permission(Some(tenant_id))?; let lsn = if params.len() == 3 { - Some(Lsn::from_str(params[2])?) + Some( + Lsn::from_str(params[2]) + .with_context(|| format!("Failed to parse Lsn from {}", params[2]))?, + ) } else { None }; @@ -754,13 +783,16 @@ impl postgres_backend_async::Handler for PageServerHandler { let (_, params_raw) = query_string.split_at("get_last_record_rlsn ".len()); let params = params_raw.split_whitespace().collect::>(); - ensure!( - params.len() == 2, - "invalid param number for get_last_record_rlsn command" - ); + if params.len() != 2 { + return Err(QueryError::Other(anyhow::anyhow!( + "invalid param number for get_last_record_rlsn command" + ))); + } - let tenant_id = TenantId::from_str(params[0])?; - let timeline_id = TimelineId::from_str(params[1])?; + let tenant_id = TenantId::from_str(params[0]) + .with_context(|| format!("Failed to parse tenant id from {}", params[0]))?; + let timeline_id = TimelineId::from_str(params[1]) + .with_context(|| format!("Failed to parse timeline id from {}", params[1]))?; self.check_permission(Some(tenant_id))?; let timeline = get_active_timeline_with_timeout(tenant_id, timeline_id).await?; @@ -782,22 +814,31 @@ impl postgres_backend_async::Handler for PageServerHandler { let (_, params_raw) = query_string.split_at("fullbackup ".len()); let params = params_raw.split_whitespace().collect::>(); - ensure!( - params.len() >= 2, - "invalid param number for fullbackup command" - ); + if params.len() < 2 { + return Err(QueryError::Other(anyhow::anyhow!( + "invalid param number for fullbackup command" + ))); + } - let tenant_id = TenantId::from_str(params[0])?; - let timeline_id = TimelineId::from_str(params[1])?; + let tenant_id = TenantId::from_str(params[0]) + .with_context(|| format!("Failed to parse tenant id from {}", params[0]))?; + let timeline_id = TimelineId::from_str(params[1]) + .with_context(|| format!("Failed to parse timeline id from {}", params[1]))?; // The caller is responsible for providing correct lsn and prev_lsn. let lsn = if params.len() > 2 { - Some(Lsn::from_str(params[2])?) + Some( + Lsn::from_str(params[2]) + .with_context(|| format!("Failed to parse Lsn from {}", params[2]))?, + ) } else { None }; let prev_lsn = if params.len() > 3 { - Some(Lsn::from_str(params[3])?) + Some( + Lsn::from_str(params[3]) + .with_context(|| format!("Failed to parse Lsn from {}", params[3]))?, + ) } else { None }; @@ -822,12 +863,21 @@ impl postgres_backend_async::Handler for PageServerHandler { // -c "import basebackup $TENANT $TIMELINE $START_LSN $END_LSN $PG_VERSION" let (_, params_raw) = query_string.split_at("import basebackup ".len()); let params = params_raw.split_whitespace().collect::>(); - ensure!(params.len() == 5); - let tenant_id = TenantId::from_str(params[0])?; - let timeline_id = TimelineId::from_str(params[1])?; - let base_lsn = Lsn::from_str(params[2])?; - let end_lsn = Lsn::from_str(params[3])?; - let pg_version = u32::from_str(params[4])?; + if params.len() != 5 { + return Err(QueryError::Other(anyhow::anyhow!( + "invalid param number for import basebackup command" + ))); + } + let tenant_id = TenantId::from_str(params[0]) + .with_context(|| format!("Failed to parse tenant id from {}", params[0]))?; + let timeline_id = TimelineId::from_str(params[1]) + .with_context(|| format!("Failed to parse timeline id from {}", params[1]))?; + let base_lsn = Lsn::from_str(params[2]) + .with_context(|| format!("Failed to parse Lsn from {}", params[2]))?; + let end_lsn = Lsn::from_str(params[3]) + .with_context(|| format!("Failed to parse Lsn from {}", params[3]))?; + let pg_version = u32::from_str(params[4]) + .with_context(|| format!("Failed to parse pg_version from {}", params[4]))?; self.check_permission(Some(tenant_id))?; @@ -845,7 +895,10 @@ impl postgres_backend_async::Handler for PageServerHandler { Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?, Err(e) => { error!("error importing base backup between {base_lsn} and {end_lsn}: {e:?}"); - pgb.write_message(&BeMessage::ErrorResponse(&e.to_string()))? + pgb.write_message(&BeMessage::ErrorResponse( + &e.to_string(), + Some(e.pg_error_code()), + ))? } }; } else if query_string.starts_with("import wal ") { @@ -855,11 +908,19 @@ impl postgres_backend_async::Handler for PageServerHandler { // caller should poll the http api to check when that is done. let (_, params_raw) = query_string.split_at("import wal ".len()); let params = params_raw.split_whitespace().collect::>(); - ensure!(params.len() == 4); - let tenant_id = TenantId::from_str(params[0])?; - let timeline_id = TimelineId::from_str(params[1])?; - let start_lsn = Lsn::from_str(params[2])?; - let end_lsn = Lsn::from_str(params[3])?; + if params.len() != 4 { + return Err(QueryError::Other(anyhow::anyhow!( + "invalid param number for import wal command" + ))); + } + let tenant_id = TenantId::from_str(params[0]) + .with_context(|| format!("Failed to parse tenant id from {}", params[0]))?; + let timeline_id = TimelineId::from_str(params[1]) + .with_context(|| format!("Failed to parse timeline id from {}", params[1]))?; + let start_lsn = Lsn::from_str(params[2]) + .with_context(|| format!("Failed to parse Lsn from {}", params[2]))?; + let end_lsn = Lsn::from_str(params[3]) + .with_context(|| format!("Failed to parse Lsn from {}", params[3]))?; self.check_permission(Some(tenant_id))?; @@ -870,7 +931,10 @@ impl postgres_backend_async::Handler for PageServerHandler { Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?, Err(e) => { error!("error importing WAL between {start_lsn} and {end_lsn}: {e:?}"); - pgb.write_message(&BeMessage::ErrorResponse(&e.to_string()))? + pgb.write_message(&BeMessage::ErrorResponse( + &e.to_string(), + Some(e.pg_error_code()), + ))? } }; } else if query_string.to_ascii_lowercase().starts_with("set ") { @@ -881,8 +945,13 @@ impl postgres_backend_async::Handler for PageServerHandler { // show let (_, params_raw) = query_string.split_at("show ".len()); let params = params_raw.split(' ').collect::>(); - ensure!(params.len() == 1, "invalid param number for config command"); - let tenant_id = TenantId::from_str(params[0])?; + if params.len() != 1 { + return Err(QueryError::Other(anyhow::anyhow!( + "invalid param number for config command" + ))); + } + let tenant_id = TenantId::from_str(params[0]) + .with_context(|| format!("Failed to parse tenant id from {}", params[0]))?; self.check_permission(Some(tenant_id))?; @@ -923,7 +992,9 @@ impl postgres_backend_async::Handler for PageServerHandler { ]))? .write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; } else { - bail!("unknown command"); + return Err(QueryError::Other(anyhow::anyhow!( + "unknown command {query_string}" + ))); } Ok(()) @@ -935,7 +1006,7 @@ impl postgres_backend_async::Handler for PageServerHandler { /// If the tenant is Loading, waits for it to become Active, for up to 30 s. That /// ensures that queries don't fail immediately after pageserver startup, because /// all tenants are still loading. -async fn get_active_tenant_with_timeout(tenant_id: TenantId) -> Result> { +async fn get_active_tenant_with_timeout(tenant_id: TenantId) -> anyhow::Result> { let tenant = mgr::get_tenant(tenant_id, false).await?; match tokio::time::timeout(Duration::from_secs(30), tenant.wait_to_become_active()).await { Ok(wait_result) => wait_result @@ -949,7 +1020,7 @@ async fn get_active_tenant_with_timeout(tenant_id: TenantId) -> Result Result> { +) -> anyhow::Result> { get_active_tenant_with_timeout(tenant_id) .await .and_then(|tenant| tenant.get_timeline(timeline_id, true)) diff --git a/pageserver/src/walreceiver/walreceiver_connection.rs b/pageserver/src/walreceiver/walreceiver_connection.rs index 06aa132365..aca5e8e019 100644 --- a/pageserver/src/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/walreceiver/walreceiver_connection.rs @@ -1,6 +1,7 @@ //! Actual Postgres connection handler to stream WAL to the server. use std::{ + error::Error, str::FromStr, sync::Arc, time::{Duration, SystemTime}, @@ -11,7 +12,7 @@ use bytes::BytesMut; use chrono::{NaiveDateTime, Utc}; use fail::fail_point; use futures::StreamExt; -use postgres::{SimpleQueryMessage, SimpleQueryRow}; +use postgres::{error::SqlState, SimpleQueryMessage, SimpleQueryRow}; use postgres_ffi::v14::xlog_utils::normalize_lsn; use postgres_ffi::WAL_SEGMENT_SIZE; use postgres_protocol::message::backend::ReplicationMessage; @@ -32,7 +33,7 @@ use crate::{ use postgres_connection::PgConnectionConfig; use postgres_ffi::waldecoder::WalStreamDecoder; use pq_proto::ReplicationFeedback; -use utils::lsn::Lsn; +use utils::{lsn::Lsn, postgres_backend_async::is_expected_io_error}; /// Status of the connection. #[derive(Debug, Clone, Copy)] @@ -68,10 +69,17 @@ pub async fn handle_walreceiver_connection( let mut config = wal_source_connconf.to_tokio_postgres_config(); config.application_name("pageserver"); config.replication_mode(tokio_postgres::config::ReplicationMode::Physical); - time::timeout(connect_timeout, config.connect(postgres::NoTls)) - .await - .context("Timed out while waiting for walreceiver connection to open")? - .context("Failed to open walreceiver connection")? + match time::timeout(connect_timeout, config.connect(postgres::NoTls)).await { + Ok(Ok(client_and_conn)) => client_and_conn, + Ok(Err(conn_err)) => { + let expected_error = ignore_expected_errors(conn_err)?; + info!("DB connection stream finished: {expected_error}"); + return Ok(()); + } + Err(elapsed) => anyhow::bail!( + "Timed out while waiting {elapsed} for walreceiver connection to open" + ), + } }; info!("connected!"); @@ -103,10 +111,8 @@ pub async fn handle_walreceiver_connection( connection_result = connection => match connection_result{ Ok(()) => info!("Walreceiver db connection closed"), Err(connection_error) => { - if connection_error.is_closed() { - info!("Connection closed regularly: {connection_error}") - } else { - warn!("Connection aborted: {connection_error}") + if let Err(e) = ignore_expected_errors(connection_error) { + warn!("Connection aborted: {e:#}") } } }, @@ -187,14 +193,9 @@ pub async fn handle_walreceiver_connection( let replication_message = match replication_message { Ok(message) => message, Err(replication_error) => { - if replication_error.is_closed() { - info!("Replication stream got closed"); - return Ok(()); - } else { - return Err( - anyhow::Error::new(replication_error).context("replication stream error") - ); - } + let expected_error = ignore_expected_errors(replication_error)?; + info!("Replication stream finished: {expected_error}"); + return Ok(()); } }; @@ -400,3 +401,32 @@ async fn identify_system(client: &mut Client) -> anyhow::Result Err(IdentifyError.into()) } } + +/// We don't want to report connectivity problems as real errors towards connection manager because +/// 1. they happen frequently enough to make server logs hard to read and +/// 2. the connection manager can retry other safekeeper. +/// +/// If this function returns `Ok(pg_error)`, it's such an error. +/// The caller should log it at info level and then report to connection manager that we're done handling this connection. +/// Connection manager will then handle reconnections. +/// +/// If this function returns an `Err()`, the caller can bubble it up using `?`. +/// The connection manager will log the error at ERROR level. +fn ignore_expected_errors(pg_error: postgres::Error) -> anyhow::Result { + if pg_error.is_closed() + || pg_error + .source() + .and_then(|source| source.downcast_ref::()) + .map(is_expected_io_error) + .unwrap_or(false) + { + return Ok(pg_error); + } else if let Some(db_error) = pg_error.as_db_error() { + if db_error.code() == &SqlState::CONNECTION_FAILURE + && db_error.message().contains("end streaming") + { + return Ok(pg_error); + } + } + Err(pg_error).context("connection error") +} diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index 2e0a502e7f..cf83b48ae0 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -9,7 +9,10 @@ use std::{ thread, }; use tracing::{error, info, info_span}; -use utils::postgres_backend::{self, AuthType, PostgresBackend}; +use utils::{ + postgres_backend::{self, AuthType, PostgresBackend}, + postgres_backend_async::QueryError, +}; /// Console management API listener thread. /// It spawns console response handlers needed for the link auth. @@ -47,7 +50,7 @@ pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> { } } -fn handle_connection(socket: TcpStream) -> anyhow::Result<()> { +fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?; pgbackend.run(&mut MgmtHandler) } @@ -58,7 +61,7 @@ pub type ComputeReady = Result; // TODO: replace with an http-based protocol. struct MgmtHandler; impl postgres_backend::Handler for MgmtHandler { - fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<()> { + fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> { try_process_query(pgb, query).map_err(|e| { error!("failed to process response: {e:?}"); e @@ -66,8 +69,8 @@ impl postgres_backend::Handler for MgmtHandler { } } -fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<()> { - let resp: KickSession = serde_json::from_str(query)?; +fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> { + let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?; let span = info_span!("event", session_id = resp.session_id); let _enter = span.enter(); @@ -81,7 +84,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<( } Err(e) => { error!("failed to deliver response to per-client task"); - pgb.write_message(&BeMessage::ErrorResponse(&e.to_string()))?; + pgb.write_message(&BeMessage::ErrorResponse(&e.to_string(), None))?; } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 19e1479068..02a0fabe9a 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -2,7 +2,7 @@ use crate::error::UserFacingError; use anyhow::bail; use bytes::BytesMut; use pin_project_lite::pin_project; -use pq_proto::{BeMessage, FeMessage, FeStartupPacket}; +use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket}; use rustls::ServerConfig; use std::pin::Pin; use std::sync::Arc; @@ -47,18 +47,13 @@ fn err_connection() -> io::Error { io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost") } -// TODO: change error type of `FeMessage::read_fut` -fn from_anyhow(e: anyhow::Error) -> io::Error { - io::Error::new(io::ErrorKind::Other, e.to_string()) -} - impl PqStream { /// Receive [`FeStartupPacket`], which is a first packet sent by a client. pub async fn read_startup_packet(&mut self) -> io::Result { // TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket` let msg = FeStartupPacket::read_fut(&mut self.stream) .await - .map_err(from_anyhow)? + .map_err(ConnectionError::into_io_error)? .ok_or_else(err_connection)?; match msg { @@ -80,7 +75,7 @@ impl PqStream { async fn read_message(&mut self) -> io::Result { FeMessage::read_fut(&mut self.stream) .await - .map_err(from_anyhow)? + .map_err(ConnectionError::into_io_error)? .ok_or_else(err_connection) } } @@ -112,7 +107,8 @@ impl PqStream { /// This method exists due to `&str` not implementing `Into`. pub async fn throw_error_str(&mut self, error: &'static str) -> anyhow::Result { tracing::info!("forwarding error to user: {error}"); - self.write_message(&BeMessage::ErrorResponse(error)).await?; + self.write_message(&BeMessage::ErrorResponse(error, None)) + .await?; bail!(error) } @@ -124,7 +120,8 @@ impl PqStream { { let msg = error.to_string_client(); tracing::info!("forwarding error to user: {msg}"); - self.write_message(&BeMessage::ErrorResponse(&msg)).await?; + self.write_message(&BeMessage::ErrorResponse(&msg, None)) + .await?; bail!(error) } } diff --git a/safekeeper/src/bin/safekeeper.rs b/safekeeper/src/bin/safekeeper.rs index 394a4815bb..b130ea86bd 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ b/safekeeper/src/bin/safekeeper.rs @@ -229,11 +229,7 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> { let conf_cloned = conf.clone(); let safekeeper_thread = thread::Builder::new() .name("safekeeper thread".into()) - .spawn(|| { - if let Err(e) = wal_service::thread_main(conf_cloned, pg_listener) { - info!("safekeeper thread terminated: {e}"); - } - }) + .spawn(|| wal_service::thread_main(conf_cloned, pg_listener)) .unwrap(); threads.push(safekeeper_thread); diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index c692e9fc12..60df5dd372 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -8,7 +8,7 @@ use crate::receive_wal::ReceiveWalConn; use crate::send_wal::ReplicationConn; use crate::{GlobalTimelines, SafeKeeperConf}; -use anyhow::{bail, ensure, Context, Result}; +use anyhow::Context; use postgres_ffi::PG_TLI; use regex::Regex; @@ -17,6 +17,7 @@ use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID}; use std::str; use tracing::info; use utils::auth::{Claims, Scope}; +use utils::postgres_backend_async::QueryError; use utils::{ id::{TenantId, TenantTimelineId, TimelineId}, lsn::Lsn, @@ -42,7 +43,7 @@ enum SafekeeperPostgresCommand { JSONCtrl { cmd: AppendLogicalMessage }, } -fn parse_cmd(cmd: &str) -> Result { +fn parse_cmd(cmd: &str) -> anyhow::Result { if cmd.starts_with("START_WAL_PUSH") { Ok(SafekeeperPostgresCommand::StartWalPush) } else if cmd.starts_with("START_REPLICATION") { @@ -62,13 +63,17 @@ fn parse_cmd(cmd: &str) -> Result { cmd: serde_json::from_str(cmd)?, }) } else { - bail!("unsupported command {}", cmd); + anyhow::bail!("unsupported command {cmd}"); } } impl postgres_backend::Handler for SafekeeperPostgresHandler { // tenant_id and timeline_id are passed in connection string params - fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupPacket) -> Result<()> { + fn startup( + &mut self, + _pgb: &mut PostgresBackend, + sm: &FeStartupPacket, + ) -> Result<(), QueryError> { if let FeStartupPacket::StartupMessage { params, .. } = sm { if let Some(options) = params.options_raw() { for opt in options { @@ -77,10 +82,14 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { // https://github.com/neondatabase/neon/pull/2433#discussion_r970005064 match opt.split_once('=') { Some(("ztenantid", value)) | Some(("tenant_id", value)) => { - self.tenant_id = Some(value.parse()?); + self.tenant_id = Some(value.parse().with_context(|| { + format!("Failed to parse {value} as tenant id") + })?); } Some(("ztimelineid", value)) | Some(("timeline_id", value)) => { - self.timeline_id = Some(value.parse()?); + self.timeline_id = Some(value.parse().with_context(|| { + format!("Failed to parse {value} as timeline id") + })?); } _ => continue, } @@ -93,7 +102,9 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { Ok(()) } else { - bail!("Safekeeper received unexpected initial message: {:?}", sm); + Err(QueryError::Other(anyhow::anyhow!( + "Safekeeper received unexpected initial message: {sm:?}" + ))) } } @@ -101,7 +112,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { &mut self, _pgb: &mut PostgresBackend, jwt_response: &[u8], - ) -> anyhow::Result<()> { + ) -> Result<(), QueryError> { // this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT // which requires auth to be present let data = self @@ -109,13 +120,12 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { .auth .as_ref() .unwrap() - .decode(str::from_utf8(jwt_response)?)?; + .decode(str::from_utf8(jwt_response).context("jwt response is not UTF-8")?)?; - if matches!(data.claims.scope, Scope::Tenant) { - ensure!( - data.claims.tenant_id.is_some(), + if matches!(data.claims.scope, Scope::Tenant) && data.claims.tenant_id.is_none() { + return Err(QueryError::Other(anyhow::anyhow!( "jwt token scope is Tenant, but tenant id is missing" - ) + ))); } info!( @@ -127,7 +137,11 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { Ok(()) } - fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: &str) -> Result<()> { + fn process_query( + &mut self, + pgb: &mut PostgresBackend, + query_string: &str, + ) -> Result<(), QueryError> { if query_string .to_ascii_lowercase() .starts_with("set datestyle to ") @@ -148,19 +162,26 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { self.check_permission(Some(tenant_id))?; self.ttid = TenantTimelineId::new(tenant_id, timeline_id); - match cmd { + let res = match cmd { SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self), SafekeeperPostgresCommand::StartReplication { start_lsn } => { ReplicationConn::new(pgb).run(self, pgb, start_lsn) } SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb), SafekeeperPostgresCommand::JSONCtrl { ref cmd } => handle_json_ctrl(self, pgb, cmd), - } - .context(format!( - "Failed to process query for timeline {timeline_id}" - ))?; + }; - Ok(()) + match res { + Ok(()) => Ok(()), + Err(QueryError::Disconnected(connection_error)) => { + info!("Timeline {tenant_id}/{timeline_id} query failed with connection error: {connection_error}"); + Err(QueryError::Disconnected(connection_error)) + } + Err(QueryError::Other(e)) => Err(QueryError::Other(e.context(format!( + "Failed to process query for timeline {}", + self.ttid + )))), + } } } @@ -178,7 +199,7 @@ impl SafekeeperPostgresHandler { // when accessing management api supply None as an argument // when using to authorize tenant pass corresponding tenant id - fn check_permission(&self, tenant_id: Option) -> Result<()> { + fn check_permission(&self, tenant_id: Option) -> anyhow::Result<()> { if self.conf.auth.is_none() { // auth is set to Trust, nothing to check so just return ok return Ok(()); @@ -196,7 +217,7 @@ impl SafekeeperPostgresHandler { /// /// Handle IDENTIFY_SYSTEM replication command /// - fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<()> { + fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<(), QueryError> { let tli = GlobalTimelines::get(self.ttid)?; let lsn = if self.is_walproposer_recovery() { diff --git a/safekeeper/src/json_ctrl.rs b/safekeeper/src/json_ctrl.rs index 746b4461b7..32a24a4978 100644 --- a/safekeeper/src/json_ctrl.rs +++ b/safekeeper/src/json_ctrl.rs @@ -8,11 +8,12 @@ use std::sync::Arc; -use anyhow::Result; +use anyhow::Context; use bytes::Bytes; use serde::{Deserialize, Serialize}; use tracing::*; use utils::id::TenantTimelineId; +use utils::postgres_backend_async::QueryError; use crate::handler::SafekeeperPostgresHandler; use crate::safekeeper::{AcceptorProposerMessage, AppendResponse, ServerInfo}; @@ -47,7 +48,7 @@ pub struct AppendLogicalMessage { pg_version: u32, } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] struct AppendResult { // safekeeper state after append state: SafeKeeperState, @@ -62,8 +63,8 @@ pub fn handle_json_ctrl( spg: &SafekeeperPostgresHandler, pgb: &mut PostgresBackend, append_request: &AppendLogicalMessage, -) -> Result<()> { - info!("JSON_CTRL request: {:?}", append_request); +) -> Result<(), QueryError> { + info!("JSON_CTRL request: {append_request:?}"); // need to init safekeeper state before AppendRequest let tli = prepare_safekeeper(spg.ttid, append_request.pg_version)?; @@ -78,7 +79,8 @@ pub fn handle_json_ctrl( state: tli.get_state().1, inserted_wal, }; - let response_data = serde_json::to_vec(&response)?; + let response_data = serde_json::to_vec(&response) + .with_context(|| format!("Response {response:?} is not a json array"))?; pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor { name: b"json", @@ -93,7 +95,7 @@ pub fn handle_json_ctrl( /// Prepare safekeeper to process append requests without crashes, /// by sending ProposerGreeting with default server.wal_seg_size. -fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> Result> { +fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> anyhow::Result> { GlobalTimelines::create( ttid, ServerInfo { @@ -106,7 +108,7 @@ fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> Result, term: Term, lsn: Lsn) -> Result<()> { +fn send_proposer_elected(tli: &Arc, term: Term, lsn: Lsn) -> anyhow::Result<()> { // add new term to existing history let history = tli.get_state().1.acceptor_state.term_history; let history = history.up_to(lsn.checked_sub(1u64).unwrap()); @@ -125,7 +127,7 @@ fn send_proposer_elected(tli: &Arc, term: Term, lsn: Lsn) -> Result<() Ok(()) } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] struct InsertedWAL { begin_lsn: Lsn, end_lsn: Lsn, @@ -134,7 +136,10 @@ struct InsertedWAL { /// Extend local WAL with new LogicalMessage record. To do that, /// create AppendRequest with new WAL and pass it to safekeeper. -fn append_logical_message(tli: &Arc, msg: &AppendLogicalMessage) -> Result { +fn append_logical_message( + tli: &Arc, + msg: &AppendLogicalMessage, +) -> anyhow::Result { let wal_data = encode_logical_message(&msg.lm_prefix, &msg.lm_message); let sk_state = tli.get_state().1; diff --git a/safekeeper/src/receive_wal.rs b/safekeeper/src/receive_wal.rs index be7f071abb..671e5470a0 100644 --- a/safekeeper/src/receive_wal.rs +++ b/safekeeper/src/receive_wal.rs @@ -2,11 +2,13 @@ //! Gets messages from the network, passes them down to consensus module and //! sends replies back. -use anyhow::{anyhow, bail, Result}; +use anyhow::anyhow; +use anyhow::Context; use bytes::BytesMut; use tracing::*; use utils::lsn::Lsn; +use utils::postgres_backend_async::QueryError; use crate::safekeeper::ServerInfo; use crate::timeline::Timeline; @@ -43,7 +45,7 @@ impl<'pg> ReceiveWalConn<'pg> { } // Send message to the postgres - fn write_msg(&mut self, msg: &AcceptorProposerMessage) -> Result<()> { + fn write_msg(&mut self, msg: &AcceptorProposerMessage) -> anyhow::Result<()> { let mut buf = BytesMut::with_capacity(128); msg.serialize(&mut buf)?; self.pg_backend.write_message(&BeMessage::CopyData(&buf))?; @@ -51,7 +53,7 @@ impl<'pg> ReceiveWalConn<'pg> { } /// Receive WAL from wal_proposer - pub fn run(&mut self, spg: &mut SafekeeperPostgresHandler) -> Result<()> { + pub fn run(&mut self, spg: &mut SafekeeperPostgresHandler) -> Result<(), QueryError> { let _enter = info_span!("WAL acceptor", ttid = %spg.ttid).entered(); // Notify the libpq client that it's allowed to send `CopyData` messages @@ -79,7 +81,11 @@ impl<'pg> ReceiveWalConn<'pg> { }; GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)? } - _ => bail!("unexpected message {:?} instead of greeting", next_msg), + _ => { + return Err(QueryError::Other(anyhow::anyhow!( + "unexpected message {next_msg:?} instead of greeting" + ))) + } }; let mut next_msg = Some(next_msg); @@ -134,25 +140,32 @@ impl<'pg> ReceiveWalConn<'pg> { struct ProposerPollStream { msg_rx: Receiver, - read_thread: Option>>, + read_thread: Option>>, } impl ProposerPollStream { - fn new(mut r: ReadStream) -> Result { + fn new(mut r: ReadStream) -> anyhow::Result { let (msg_tx, msg_rx) = channel(); let read_thread = thread::Builder::new() .name("Read WAL thread".into()) - .spawn(move || -> Result<()> { + .spawn(move || -> Result<(), QueryError> { loop { let copy_data = match FeMessage::read(&mut r)? { - Some(FeMessage::CopyData(bytes)) => bytes, - Some(msg) => bail!("expected `CopyData` message, found {:?}", msg), - None => bail!("connection closed unexpectedly"), - }; + Some(FeMessage::CopyData(bytes)) => Ok(bytes), + Some(msg) => Err(QueryError::Other(anyhow::anyhow!( + "expected `CopyData` message, found {msg:?}" + ))), + None => Err(QueryError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "walproposer closed the connection", + ))), + }?; let msg = ProposerAcceptorMessage::parse(copy_data)?; - msg_tx.send(msg)?; + msg_tx + .send(msg) + .context("Failed to send the proposer message")?; } // msg_tx will be dropped here, this will also close msg_rx })?; @@ -163,17 +176,19 @@ impl ProposerPollStream { }) } - fn recv_msg(&mut self) -> Result { + fn recv_msg(&mut self) -> Result { self.msg_rx.recv().map_err(|_| { // return error from the read thread let res = match self.read_thread.take() { Some(thread) => thread.join(), - None => return anyhow!("read thread is gone"), + None => return QueryError::Other(anyhow::anyhow!("read thread is gone")), }; match res { - Ok(Ok(())) => anyhow!("unexpected result from read thread"), - Err(err) => anyhow!("read thread panicked: {:?}", err), + Ok(Ok(())) => { + QueryError::Other(anyhow::anyhow!("unexpected result from read thread")) + } + Err(err) => QueryError::Other(anyhow::anyhow!("read thread panicked: {err:?}")), Ok(Err(err)) => err, } }) diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index a054b8fe14..20600ab694 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -5,7 +5,7 @@ use crate::handler::SafekeeperPostgresHandler; use crate::timeline::{ReplicaState, Timeline}; use crate::wal_storage::WalReader; use crate::GlobalTimelines; -use anyhow::{bail, Context, Result}; +use anyhow::Context; use bytes::Bytes; use postgres_ffi::get_current_timestamp; @@ -15,7 +15,8 @@ use std::cmp::min; use std::net::Shutdown; use std::sync::Arc; use std::time::Duration; -use std::{str, thread}; +use std::{io, str, thread}; +use utils::postgres_backend_async::QueryError; use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody}; use tokio::sync::watch::Receiver; @@ -91,7 +92,7 @@ impl ReplicationConn { fn background_thread( mut stream_in: ReadStream, replica_guard: Arc, - ) -> Result<()> { + ) -> anyhow::Result<()> { let replica_id = replica_guard.replica; let timeline = &replica_guard.timeline; @@ -140,7 +141,7 @@ impl ReplicationConn { // Shutdown the connection, because rust-postgres client cannot be dropped // when connection is alive. let _ = stream_in.shutdown(Shutdown::Both); - bail!("Copy failed"); + anyhow::bail!("Copy failed"); } _ => { // We only handle `CopyData`, 'Sync', 'CopyFail' messages. Anything else is ignored. @@ -160,7 +161,7 @@ impl ReplicationConn { spg: &mut SafekeeperPostgresHandler, pgb: &mut PostgresBackend, mut start_pos: Lsn, - ) -> Result<()> { + ) -> Result<(), QueryError> { let _enter = info_span!("WAL sender", ttid = %spg.ttid).entered(); let tli = GlobalTimelines::get(spg.ttid)?; @@ -256,8 +257,10 @@ impl ReplicationConn { // to right pageserver. if tli.should_walsender_stop(replica_id) { // Shut down, timeline is suspended. - // TODO create proper error type for this - bail!("end streaming to {:?}", spg.appname); + return Err(QueryError::from(io::Error::new( + io::ErrorKind::ConnectionAborted, + format!("end streaming to {:?}", spg.appname), + ))); } // timeout expired: request pageserver status @@ -265,8 +268,7 @@ impl ReplicationConn { sent_ptr: end_pos.0, timestamp: get_current_timestamp(), request_reply: true, - })) - .context("Failed to send KeepAlive message")?; + }))?; continue; } } @@ -301,7 +303,7 @@ impl ReplicationConn { const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1); // Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn. -async fn wait_for_lsn(rx: &mut Receiver, lsn: Lsn) -> Result> { +async fn wait_for_lsn(rx: &mut Receiver, lsn: Lsn) -> anyhow::Result> { let commit_lsn: Lsn = *rx.borrow(); if commit_lsn > lsn { return Ok(Some(commit_lsn)); diff --git a/safekeeper/src/wal_service.rs b/safekeeper/src/wal_service.rs index 0fea00fe1b..3ca651d060 100644 --- a/safekeeper/src/wal_service.rs +++ b/safekeeper/src/wal_service.rs @@ -2,18 +2,18 @@ //! WAL service listens for client connections and //! receive WAL from wal_proposer and send it to WAL receivers //! -use anyhow::Result; use regex::Regex; use std::net::{TcpListener, TcpStream}; use std::thread; use tracing::*; +use utils::postgres_backend_async::QueryError; use crate::handler::SafekeeperPostgresHandler; use crate::SafeKeeperConf; use utils::postgres_backend::{AuthType, PostgresBackend}; /// Accept incoming TCP connections and spawn them into a background thread. -pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> Result<()> { +pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> ! { loop { match listener.accept() { Ok((socket, peer_addr)) => { @@ -44,7 +44,7 @@ fn get_tid() -> u64 { /// This is run by `thread_main` above, inside a background thread. /// -fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<()> { +fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> { let _enter = info_span!("", tid = ?get_tid()).entered(); socket.set_nodelay(true)?; diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 705ab70ab4..eb15278ba7 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -1903,13 +1903,15 @@ class NeonPageserver(PgProtocol): ".*wal receiver task finished with an error: walreceiver connection handling failure.*", ".*Shutdown task error: walreceiver connection handling failure.*", ".*wal_connection_manager.*tcp connect error: Connection refused.*", - ".*query handler for .* failed: Connection reset by peer.*", - ".*serving compute connection task.*exited with error: Broken pipe.*", - ".*Connection aborted: error communicating with the server: Broken pipe.*", - ".*Connection aborted: error communicating with the server: Transport endpoint is not connected.*", - ".*Connection aborted: error communicating with the server: Connection reset by peer.*", + ".*query handler for .* failed: Socket IO error: Connection reset by peer.*", + ".*serving compute connection task.*exited with error: Postgres connection error.*", + ".*serving compute connection task.*exited with error: Connection reset by peer.*", + ".*serving compute connection task.*exited with error: Postgres query error.*", + ".*Connection aborted: connection error: error communicating with the server: Broken pipe.*", + ".*Connection aborted: connection error: error communicating with the server: Transport endpoint is not connected.*", + ".*Connection aborted: connection error: error communicating with the server: Connection reset by peer.*", ".*kill_and_wait_impl.*: wait successful.*", - ".*end streaming to Some.*", + ".*Replication stream finished: db error: ERROR: Socket IO error: end streaming to Some.*", ".*query handler for 'pagestream.*failed: Broken pipe.*", # pageserver notices compute shut down # safekeeper connection can fail with this, in the window between timeline creation # and streaming start diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index 77ec33f8b0..72d27c3aba 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -1105,7 +1105,6 @@ def test_delete_force(neon_env_builder: NeonEnvBuilder, auth_enabled: bool): env.pageserver.allowed_errors.extend( [ ".*Failed to process query for timeline .*: Timeline .* was not found in global map.*", - ".*end streaming to Some.*", ] )