diff --git a/Cargo.lock b/Cargo.lock index 1649e28faa..b7b83f3b37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2510,6 +2510,7 @@ name = "pq_proto" version = "0.1.0" dependencies = [ "anyhow", + "byteorder", "bytes", "pin-project-lite", "postgres-protocol", @@ -2517,6 +2518,7 @@ dependencies = [ "serde", "thiserror", "tokio", + "tokio-util", "tracing", "workspace_hack", ] @@ -3074,6 +3076,7 @@ dependencies = [ "const_format", "crc32c", "fs2", + "futures", "git-version", "hex", "humantime", @@ -3082,6 +3085,7 @@ dependencies = [ "nix", "once_cell", "parking_lot", + "pin-project-lite", "postgres", "postgres-protocol", "postgres_ffi", @@ -4203,6 +4207,7 @@ dependencies = [ "byteorder", "bytes", "criterion", + "futures", "git-version", "hex", "hex-literal", @@ -4211,6 +4216,7 @@ dependencies = [ "metrics", "nix", "once_cell", + "pin-utils", "pq_proto", "rand", "routerify", @@ -4228,6 +4234,7 @@ dependencies = [ "thiserror", "tokio", "tokio-rustls", + "tokio-util", "tracing", "tracing-subscriber", "workspace_hack", diff --git a/control_plane/src/compute.rs b/control_plane/src/compute.rs index 8731cf2583..37402d735f 100644 --- a/control_plane/src/compute.rs +++ b/control_plane/src/compute.rs @@ -14,7 +14,7 @@ use anyhow::{Context, Result}; use utils::{ id::{TenantId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, + postgres_backend_async::AuthType, }; use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION}; diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index 003152c578..6da6b137c7 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -19,7 +19,7 @@ use std::process::{Command, Stdio}; use utils::{ auth::{encode_from_key_file, Claims, Scope}, id::{NodeId, TenantId, TenantTimelineId, TimelineId}, - postgres_backend::AuthType, + postgres_backend_async::AuthType, }; use crate::safekeeper::SafekeeperNode; diff --git a/libs/pq_proto/Cargo.toml b/libs/pq_proto/Cargo.toml index b9c6a1eab0..83e15cbd61 100644 --- 
a/libs/pq_proto/Cargo.toml +++ b/libs/pq_proto/Cargo.toml @@ -7,11 +7,13 @@ license = "Apache-2.0" [dependencies] anyhow = "1.0" bytes = "1.0.1" +byteorder = "1.4.3" pin-project-lite = "0.2.7" postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" } rand = "0.8.3" serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.17", features = ["macros"] } +tokio-util = { version = "0.7.3" } tracing = "0.1" thiserror = "1.0" diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index c5e4dbd1f0..90616a9090 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -3,9 +3,11 @@ //! on message formats. // Tools for calling certain async methods in sync contexts. +pub mod codec; pub mod sync; -use anyhow::{ensure, Context, Result}; +use anyhow::{anyhow, bail, ensure, Context, Result}; +use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use postgres_protocol::PG_EPOCH; use serde::{Deserialize, Serialize}; @@ -19,7 +21,7 @@ use std::{ time::{Duration, SystemTime}, }; use sync::{AsyncishRead, SyncFuture}; -use tokio::io::AsyncReadExt; +// use tokio::io::AsyncReadExt; use tracing::{trace, warn}; pub type Oid = u32; @@ -194,36 +196,108 @@ macro_rules! retry_read { }; } -/// An error occured during connection being open. +/// An error occurred while parsing or serializing raw stream into Postgres +/// messages. #[derive(thiserror::Error, Debug)] -pub enum ConnectionError { +pub enum ProtocolError { /// IO error during writing to or reading from the connection socket. + /// removeme #[error("Socket IO error: {0}")] Socket(std::io::Error), - /// Invalid packet was received from client + /// Invalid packet was received from the client (e.g. unexpected message + /// type or broken len). 
#[error("Protocol error: {0}")] Protocol(String), - /// Failed to parse a protocol mesage + /// Failed to parse or, (unlikely), serialize a protocol message. #[error("Message parse error: {0}")] MessageParse(anyhow::Error), } -impl From for ConnectionError { +// Allows to return anyhow error from msg parsing routines, meaning less typing. +impl From for ProtocolError { fn from(e: anyhow::Error) -> Self { Self::MessageParse(e) } } -impl ConnectionError { +impl ProtocolError { pub fn into_io_error(self) -> io::Error { match self { - ConnectionError::Socket(io) => io, + ProtocolError::Socket(io) => io, other => io::Error::new(io::ErrorKind::Other, other.to_string()), } } } impl FeMessage { + /// Read and parse one message from the `buf` input buffer. If there is at + /// least one valid message, returns it, advancing `buf`; redundant copies + /// are avoided, as thanks to `bytes` crate ptrs in parsed message point + /// directly into the `buf` (processed data is garbage collected after + /// parsed message is dropped). + /// + /// Returns None if `buf` doesn't contain enough data for a single message. + /// For efficiency, tries to reserve large enough space in `buf` for the + /// next message in this case. + /// + /// Returns Error if message is malformed: ProtocolError::Protocol for a bad + /// length or unknown tag, ProtocolError::MessageParse for an unparsable body. + // + // Inspired by rust-postgres Message::parse. + pub fn parse(buf: &mut BytesMut) -> Result, ProtocolError> { + // Every message contains message type byte and 4 bytes len; can't do + // much without them. + if buf.len() < 5 { + let to_read = 5 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + // We shouldn't advance `buf` as probably full message is not there yet, + // so can't directly use Bytes::get_u32 etc. + let tag = buf[0]; + let len = (&buf[1..5]).read_u32::().unwrap(); + if len < 4 { + return Err(ProtocolError::Protocol(format!( + "invalid message length {}", + len + ))); + } + + // length field includes itself, but not message type. 
+ let total_len = len as usize + 1; + if buf.len() < total_len { + // Don't have full message yet. + let to_read = total_len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + // got the message, advance buffer + let mut msg = buf.split_to(total_len).freeze(); + msg.advance(5); // consume message type and len + + match tag { + b'Q' => Ok(Some(FeMessage::Query(msg))), + b'P' => Ok(Some(FeParseMessage::parse(msg)?)), + b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)), + b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)), + b'B' => Ok(Some(FeBindMessage::parse(msg)?)), + b'C' => Ok(Some(FeCloseMessage::parse(msg)?)), + b'S' => Ok(Some(FeMessage::Sync)), + b'X' => Ok(Some(FeMessage::Terminate)), + b'd' => Ok(Some(FeMessage::CopyData(msg))), + b'c' => Ok(Some(FeMessage::CopyDone)), + b'f' => Ok(Some(FeMessage::CopyFail)), + b'p' => Ok(Some(FeMessage::PasswordMessage(msg))), + tag => { + return Err(ProtocolError::Protocol(format!( + "unknown message tag: {tag},'{msg:?}'" + ))) + } + } + } + /// Read one message from the stream. /// This function returns `Ok(None)` in case of EOF. /// One way to handle this properly: @@ -245,68 +319,8 @@ impl FeMessage { /// } /// ``` #[inline(never)] - pub fn read( - stream: &mut (impl io::Read + Unpin), - ) -> Result, ConnectionError> { - Self::read_fut(&mut AsyncishRead(stream)).wait() - } - - /// Read one message from the stream. - /// See documentation for `Self::read`. - pub fn read_fut( - stream: &mut Reader, - ) -> SyncFuture, ConnectionError>> + '_> - where - Reader: tokio::io::AsyncRead + Unpin, - { - // We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof. - // SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and - // AsyncReadExt methods of the stream. - SyncFuture::new(async move { - // Each libpq message begins with a message type byte, followed by message length - // If the client closes the connection, return None. 
But if the client closes the - // connection in the middle of a message, we will return an error. - let tag = match retry_read!(stream.read_u8().await) { - Ok(b) => b, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ConnectionError::Socket(e)), - }; - - // The message length includes itself, so it better be at least 4. - let len = retry_read!(stream.read_u32().await) - .map_err(ConnectionError::Socket)? - .checked_sub(4) - .ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?; - - let body = { - let mut buffer = vec![0u8; len as usize]; - stream - .read_exact(&mut buffer) - .await - .map_err(ConnectionError::Socket)?; - Bytes::from(buffer) - }; - - match tag { - b'Q' => Ok(Some(FeMessage::Query(body))), - b'P' => Ok(Some(FeParseMessage::parse(body)?)), - b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), - b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), - b'B' => Ok(Some(FeBindMessage::parse(body)?)), - b'C' => Ok(Some(FeCloseMessage::parse(body)?)), - b'S' => Ok(Some(FeMessage::Sync)), - b'X' => Ok(Some(FeMessage::Terminate)), - b'd' => Ok(Some(FeMessage::CopyData(body))), - b'c' => Ok(Some(FeMessage::CopyDone)), - b'f' => Ok(Some(FeMessage::CopyFail)), - b'p' => Ok(Some(FeMessage::PasswordMessage(body))), - tag => { - return Err(ConnectionError::Protocol(format!( - "unknown message tag: {tag},'{body:?}'" - ))) - } - } - }) + pub fn read(_stream: &mut (impl io::Read + Unpin)) -> Result, ProtocolError> { + Ok(None) // removeme } } @@ -314,21 +328,124 @@ impl FeStartupPacket { /// Read startup message from the stream. 
// XXX: It's tempting yet undesirable to accept `stream` by value, // since such a change will cause user-supplied &mut references to be consumed - pub fn read( - stream: &mut (impl io::Read + Unpin), - ) -> Result, ConnectionError> { + pub fn read(stream: &mut (impl io::Read + Unpin)) -> Result, ProtocolError> { Self::read_fut(&mut AsyncishRead(stream)).wait() } + /// Read and parse startup message from the `buf` input buffer. It is + /// different from [`FeMessage::parse`] because startup messages don't have + /// message type byte; otherwise, its comments apply. + pub fn parse(buf: &mut BytesMut) -> Result, ProtocolError> { + const MAX_STARTUP_PACKET_LENGTH: usize = 10000; + const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234; + const CANCEL_REQUEST_CODE: u32 = 5678; + const NEGOTIATE_SSL_CODE: u32 = 5679; + const NEGOTIATE_GSS_CODE: u32 = 5680; + + if buf.len() < 4 { + let to_read = 5 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + // We shouldn't advance `buf` as probably full message is not there yet, + // so can't directly use Bytes::get_u32 etc. + let len = (&buf[0..4]).read_u32::().unwrap() as usize; + if len < 8 || len > MAX_STARTUP_PACKET_LENGTH { + return Err(ProtocolError::Protocol(format!( + "invalid startup packet message length {}", + len + ))); + } + + if buf.len() < len { + // Don't have full message yet. + let to_read = len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + // got the message, advance buffer + let mut msg = buf.split_to(len).freeze(); + msg.advance(4); // consume len + + let request_code = msg.get_u32(); + let req_hi = request_code >> 16; + let req_lo = request_code & ((1 << 16) - 1); + // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code. 
+ let message = match (req_hi, req_lo) { + (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { + if msg.remaining() < 8 { + return Err(ProtocolError::MessageParse(anyhow!( + "CancelRequest message is malformed, backend PID / secret key missing" + ))); + } + FeStartupPacket::CancelRequest(CancelKeyData { + backend_pid: msg.get_i32(), + cancel_key: msg.get_i32(), + }) + } + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { + // Requested upgrade to SSL (aka TLS) + FeStartupPacket::SslRequest + } + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => { + // Requested upgrade to GSSAPI + FeStartupPacket::GssEncRequest + } + (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { + return Err(ProtocolError::Protocol(format!( + "Unrecognized request code {unrecognized_code}" + ))); + } + // TODO bail if protocol major_version is not 3? + (major_version, minor_version) => { + // StartupMessage + + // Parse pairs of null-terminated strings (key, value). + // See `postgres: ProcessStartupPacket, build_startup_packet`. + let mut tokens = str::from_utf8(&msg) + .context("StartupMessage params: invalid utf-8")? + .strip_suffix('\0') // drop packet's own null + .ok_or_else(|| { + ProtocolError::Protocol( + "StartupMessage params: missing null terminator".to_string(), + ) + })? + .split_terminator('\0'); + + let mut params = HashMap::new(); + while let Some(name) = tokens.next() { + let value = tokens.next().ok_or_else(|| { + ProtocolError::Protocol( + "StartupMessage params: key without value".to_string(), + ) + })?; + + params.insert(name.to_owned(), value.to_owned()); + } + + FeStartupPacket::StartupMessage { + major_version, + minor_version, + params: StartupMessageParams { params }, + } + } + }; + Ok(Some(FeMessage::StartupPacket(message))) + } + /// Read startup message from the stream. 
// XXX: It's tempting yet undesirable to accept `stream` by value, // since such a change will cause user-supplied &mut references to be consumed pub fn read_fut( stream: &mut Reader, - ) -> SyncFuture, ConnectionError>> + '_> + ) -> SyncFuture, ProtocolError>> + '_> where Reader: tokio::io::AsyncRead + Unpin, { + use tokio::io::AsyncReadExt; + const MAX_STARTUP_PACKET_LENGTH: usize = 10000; const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234; const CANCEL_REQUEST_CODE: u32 = 5678; @@ -343,18 +460,18 @@ impl FeStartupPacket { let len = match retry_read!(stream.read_u32().await) { Ok(len) => len as usize, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ConnectionError::Socket(e)), + Err(e) => return Err(ProtocolError::Socket(e)), }; #[allow(clippy::manual_range_contains)] if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { - return Err(ConnectionError::Protocol(format!( + return Err(ProtocolError::Protocol(format!( "invalid message length {len}" ))); } let request_code = - retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?; + retry_read!(stream.read_u32().await).map_err(ProtocolError::Socket)?; // the rest of startup packet are params let params_len = len - 8; @@ -362,7 +479,7 @@ impl FeStartupPacket { stream .read_exact(params_bytes.as_mut()) .await - .map_err(ConnectionError::Socket)?; + .map_err(ProtocolError::Socket)?; // Parse params depending on request code let req_hi = request_code >> 16; @@ -370,14 +487,16 @@ impl FeStartupPacket { let message = match (req_hi, req_lo) { (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { if params_len != 8 { - return Err(ConnectionError::Protocol( + return Err(ProtocolError::Protocol( "expected 8 bytes for CancelRequest params".to_string(), )); } let mut cursor = Cursor::new(params_bytes); FeStartupPacket::CancelRequest(CancelKeyData { - backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?, - cancel_key: 
cursor.read_i32().await.map_err(ConnectionError::Socket)?, + backend_pid: 2, + cancel_key: 2, + // backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?, + // cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?, }) } (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { @@ -389,7 +508,7 @@ impl FeStartupPacket { FeStartupPacket::GssEncRequest } (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { - return Err(ConnectionError::Protocol(format!( + return Err(ProtocolError::Protocol(format!( "Unrecognized request code {unrecognized_code}" ))); } @@ -401,7 +520,7 @@ impl FeStartupPacket { .context("StartupMessage params: invalid utf-8")? .strip_suffix('\0') // drop packet's own null .ok_or_else(|| { - ConnectionError::Protocol( + ProtocolError::Protocol( "StartupMessage params: missing null terminator".to_string(), ) })? @@ -410,7 +529,7 @@ impl FeStartupPacket { let mut params = HashMap::new(); while let Some(name) = tokens.next() { let value = tokens.next().ok_or_else(|| { - ConnectionError::Protocol( + ProtocolError::Protocol( "StartupMessage params: key without value".to_string(), ) })?; @@ -440,6 +559,9 @@ impl FeParseMessage { let _pstmt_name = read_cstr(&mut buf)?; let query_string = read_cstr(&mut buf)?; + if buf.remaining() < 2 { + bail!("Parse message is malformed, nparams missing"); + } let nparams = buf.get_i16(); ensure!(nparams == 0, "query params not implemented"); @@ -466,6 +588,9 @@ impl FeDescribeMessage { impl FeExecuteMessage { fn parse(mut buf: Bytes) -> anyhow::Result { let portal_name = read_cstr(&mut buf)?; + if buf.remaining() < 4 { + bail!("FeExecuteMessage message is malformed, maxrows missing"); + } let maxrows = buf.get_i32(); ensure!(portal_name.is_empty(), "named portals not implemented"); @@ -547,6 +672,11 @@ impl<'a> BeMessage<'a> { value: b"UTF8", }; + pub const INTEGER_DATETIMES: Self = Self::ParameterStatus { + name: b"integer_datetimes", + value: b"on", + }; + /// Build a 
[`BeMessage::ParameterStatus`] holding the server version. pub fn server_version(version: &'a str) -> Self { Self::ParameterStatus { @@ -665,13 +795,12 @@ fn write_body(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R { } /// Safe write of s into buf as cstring (String in the protocol). -fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> { +fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> { let bytes = s.as_ref(); if bytes.contains(&0) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "string contains embedded null", - )); + return Err(ProtocolError::MessageParse(anyhow!( + "string contains embedded null" + ))); } buf.put_slice(bytes); buf.put_u8(0); @@ -680,7 +809,7 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> { fn read_cstr(buf: &mut Bytes) -> anyhow::Result { let pos = buf.iter().position(|x| *x == 0); - let result = buf.split_to(pos.context("missing terminator")?); + let result = buf.split_to(pos.context("missing cstring terminator")?); buf.advance(1); // drop the null terminator Ok(result) } @@ -688,12 +817,12 @@ fn read_cstr(buf: &mut Bytes) -> anyhow::Result { pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000"; impl<'a> BeMessage<'a> { - /// Write message to the given buf. - // Unlike the reading side, we use BytesMut - // here as msg len precedes its body and it is handy to write it down first - // and then fill the length. With Write we would have to either calc it - // manually or have one more buffer. - pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> { + /// Serialize `message` to the given `buf`. + /// Apart from smart memory management, BytesMut is good here as msg len + /// precedes its body and it is handy to write it down first and then fill + /// the length. With Write we would have to either calc it manually or have + /// one more buffer. 
+ pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> { match message { BeMessage::AuthenticationOk => { buf.put_u8(b'R'); @@ -719,7 +848,7 @@ impl<'a> BeMessage<'a> { BeMessage::AuthenticationSasl(msg) => { buf.put_u8(b'R'); - write_body(buf, |buf| { + write_body(buf, |buf| -> Result<(), ProtocolError> { use BeAuthenticationSaslMessage::*; match msg { Methods(methods) => { @@ -738,7 +867,7 @@ impl<'a> BeMessage<'a> { buf.put_slice(extra); } } - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -829,7 +958,7 @@ impl<'a> BeMessage<'a> { BeMessage::ErrorResponse(error_msg, pg_error_code) => { // 'E' signalizes ErrorResponse messages buf.put_u8(b'E'); - write_body(buf, |buf| { + write_body(buf, |buf| -> Result<(), ProtocolError> { buf.put_u8(b'S'); // severity buf.put_slice(b"ERROR\0"); @@ -842,7 +971,7 @@ impl<'a> BeMessage<'a> { write_cstr(error_msg, buf)?; buf.put_u8(0); // terminator - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -854,7 +983,7 @@ impl<'a> BeMessage<'a> { // 'N' signalizes NoticeResponse messages buf.put_u8(b'N'); - write_body(buf, |buf| { + write_body(buf, |buf| -> Result<(), ProtocolError> { buf.put_u8(b'S'); // severity buf.put_slice(b"NOTICE\0"); @@ -865,7 +994,7 @@ impl<'a> BeMessage<'a> { write_cstr(error_msg.as_bytes(), buf)?; buf.put_u8(0); // terminator - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -909,7 +1038,7 @@ impl<'a> BeMessage<'a> { BeMessage::RowDescription(rows) => { buf.put_u8(b'T'); - write_body(buf, |buf| { + write_body(buf, |buf| -> Result<(), ProtocolError> { buf.put_i16(rows.len() as i16); // # of fields for row in rows.iter() { write_cstr(row.name, buf)?; @@ -920,7 +1049,7 @@ impl<'a> BeMessage<'a> { buf.put_i32(-1); /* typmod */ buf.put_i16(0); /* format code */ } - Ok::<_, io::Error>(()) + Ok(()) })?; } diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index 1091a8bd5c..901f849801 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -111,7 +111,7 
@@ pub trait RemoteStorage: Send + Sync + 'static { } pub struct Download { - pub download_stream: Pin>, + pub download_stream: Pin>, /// Extra key-value data, associated with the current remote file. pub metadata: Option, } diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index 9c7fcafe23..61f1dc93f5 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -10,13 +10,16 @@ async-trait = "0.1" anyhow = "1.0" bincode = "1.3" bytes = "1.0.1" +futures = "0.3" hyper = { version = "0.14.7", features = ["full"] } +pin-utils = "0.1" routerify = "3" serde = { version = "1.0", features = ["derive"] } serde_json = "1" thiserror = "1.0" tokio = { version = "1.17", features = ["macros"]} tokio-rustls = "0.23" +tokio-util = { version = "0.7.3" } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } nix = "0.25" diff --git a/libs/utils/src/lib.rs b/libs/utils/src/lib.rs index 6d35fd9f7b..ee228a7f1c 100644 --- a/libs/utils/src/lib.rs +++ b/libs/utils/src/lib.rs @@ -13,7 +13,7 @@ pub mod simple_rcu; pub mod vec_map; pub mod bin_ser; -pub mod postgres_backend; +// pub mod postgres_backend; pub mod postgres_backend_async; // helper functions for creating and fsyncing @@ -52,6 +52,8 @@ pub mod signals; pub mod fs_ext; +pub mod send_rc; + /// use with fail::cfg("$name", "return(2000)") #[macro_export] macro_rules! failpoint_sleep_millis_async { diff --git a/libs/utils/src/postgres_backend_async.rs b/libs/utils/src/postgres_backend_async.rs index 95b7b3fd15..dd27d911b3 100644 --- a/libs/utils/src/postgres_backend_async.rs +++ b/libs/utils/src/postgres_backend_async.rs @@ -2,29 +2,24 @@ //! To use, create PostgresBackend and run() it, passing the Handler //! implementation determining how to process the queries. Currently its API //! is rather narrow, but we can extend it once required. 
- -use crate::postgres_backend::AuthType; use anyhow::Context; use bytes::{Buf, Bytes, BytesMut}; -use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR}; -use std::future::Future; -use std::io; +use futures::stream::StreamExt; +use futures::{pin_mut, Sink, SinkExt}; +use serde::{Deserialize, Serialize}; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::task::Poll; -use tracing::{debug, error, info, trace}; - +use std::{fmt, io}; +use std::{future::Future, str::FromStr}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; use tokio_rustls::TlsAcceptor; +use tokio_util::codec::Framed; +use tracing::{debug, error, info, trace}; -pub fn is_expected_io_error(e: &io::Error) -> bool { - use io::ErrorKind::*; - matches!( - e.kind(), - ConnectionRefused | ConnectionAborted | ConnectionReset - ) -} +use pq_proto::codec::{ConnectionError, PostgresCodec}; +use pq_proto::{BeMessage, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR}; /// An error, occurred during query processing: /// either during the connection ([`ConnectionError`]) or before/after it. @@ -40,7 +35,7 @@ pub enum QueryError { impl From for QueryError { fn from(e: io::Error) -> Self { - Self::Disconnected(ConnectionError::Socket(e)) + Self::Disconnected(ConnectionError::Io(e)) } } @@ -53,6 +48,14 @@ impl QueryError { } } +pub fn is_expected_io_error(e: &io::Error) -> bool { + use io::ErrorKind::*; + matches!( + e.kind(), + ConnectionRefused | ConnectionAborted | ConnectionReset + ) +} + #[async_trait::async_trait] pub trait Handler { /// Handle single query. @@ -93,6 +96,7 @@ pub trait Handler { #[derive(Clone, Copy, PartialEq, Eq, PartialOrd)] pub enum ProtoState { Initialization, + // Encryption handshake is done; waiting for encrypted Startup message. Encrypted, Authentication, Established, @@ -105,15 +109,14 @@ pub enum ProcessMsgResult { Break, } -/// Always-writeable sock_split stream. -/// May not be readable. 
See [`PostgresBackend::take_stream_in`] -pub enum Stream { - Unencrypted(BufReader), - Tls(Box>>), - Broken, +/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite. +pub enum MaybeTlsStream { + Unencrypted(tokio::net::TcpStream), + Tls(Box>), + Broken, // temporary value for switch to TLS } -impl AsyncWrite for Stream { +impl AsyncWrite for MaybeTlsStream { fn poll_write( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -122,14 +125,14 @@ impl AsyncWrite for Stream { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf), Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf), - Self::Broken => unreachable!(), + _ => unreachable!(), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx), Self::Tls(stream) => Pin::new(stream).poll_flush(cx), - Self::Broken => unreachable!(), + _ => unreachable!(), } } fn poll_shutdown( @@ -139,11 +142,11 @@ impl AsyncWrite for Stream { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx), Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx), - Self::Broken => unreachable!(), + _ => unreachable!(), } } } -impl AsyncRead for Stream { +impl AsyncRead for MaybeTlsStream { fn poll_read( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -152,18 +155,49 @@ impl AsyncRead for Stream { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf), Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf), - Self::Broken => unreachable!(), + _ => unreachable!(), } } } -pub struct PostgresBackend { - stream: Stream, +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] +pub enum AuthType { + Trust, + // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT + NeonJWT, +} - // Output buffer. c.f. 
BeMessage::write why we are using BytesMut here. - // The data between 0 and "current position" as tracked by the bytes::Buf - // implementation of BytesMut, have already been written. - buf_out: BytesMut, +impl FromStr for AuthType { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s { + "Trust" => Ok(Self::Trust), + "NeonJWT" => Ok(Self::NeonJWT), + _ => anyhow::bail!("invalid value \"{s}\" for auth type"), + } + } +} + +impl fmt::Display for AuthType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + AuthType::Trust => "Trust", + AuthType::NeonJWT => "NeonJWT", + }) + } +} + +pub struct PostgresBackend { + // Provides serialization/deserialization to the underlying transport backed + // with buffers; implements Sink consuming messages and Stream reading them. + // + // Sink::start_send only queues message to the internal buffer. + // SinkExt::flush flushes buffer to the stream. + // + // StreamExt::next reads next message. In case of EOF without partial + // message it returns None. + stream: Framed, pub state: ProtoState, @@ -196,10 +230,10 @@ impl PostgresBackend { tls_config: Option>, ) -> io::Result { let peer_addr = socket.peer_addr()?; + let stream = MaybeTlsStream::Unencrypted(socket); Ok(Self { - stream: Stream::Unencrypted(BufReader::new(socket)), - buf_out: BytesMut::with_capacity(10 * 1024), + stream: Framed::new(stream, PostgresCodec::new()), state: ProtoState::Initialization, auth_type, tls_config, @@ -212,29 +246,60 @@ impl PostgresBackend { } /// Read full message or return None if connection is closed. 
- pub async fn read_message(&mut self) -> Result, QueryError> { - use ProtoState::*; - match self.state { - Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await, - Authentication | Established => FeMessage::read_fut(&mut self.stream).await, - Closed => Ok(None), + pub async fn read_message(&mut self) -> Result, ConnectionError> { + if let ProtoState::Closed = self.state { + Ok(None) + } else { + let msg = self.stream.next().await; + // Option>, so swap. + msg.map_or(Ok(None), |res| res.map(Some)) } - .map_err(QueryError::from) + } + + /// Polling version of read_message, saves the caller need to pin. + pub fn poll_read_message( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll, ConnectionError>> { + let read_fut = self.read_message(); + pin_mut!(read_fut); + read_fut.poll(cx) } /// Flush output buffer into the socket. pub async fn flush(&mut self) -> io::Result<()> { - while self.buf_out.has_remaining() { - let bytes_written = self.stream.write(self.buf_out.chunk()).await?; - self.buf_out.advance(bytes_written); - } - self.buf_out.clear(); - Ok(()) + self.stream.flush().await.map_err(|e| match e { + ConnectionError::Io(e) => e, + // the only error we can get from flushing is IO + _ => unreachable!(), + }) } - /// Write message into internal output buffer. - pub fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - BeMessage::write(&mut self.buf_out, message)?; + /// Polling version of `flush()`, saves the caller need to pin. + pub fn poll_flush( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let flush_fut = self.flush(); + pin_mut!(flush_fut); + flush_fut.poll(cx) + } + + /// Write message into internal output buffer. Technically error type can be + /// only ProtocolError here (if, unlikely, serialization fails), but callers + /// typically wrap it anyway. 
+ pub fn write_message(&mut self, message: &BeMessage<'_>) -> Result<&mut Self, ConnectionError> { + Pin::new(&mut self.stream).start_send(message)?; + Ok(self) + } + + /// Write message into internal output buffer and flush it to the stream. + pub async fn write_message_flush( + &mut self, + message: &BeMessage<'_>, + ) -> Result<&mut Self, ConnectionError> { + self.write_message(message)?; + self.flush().await?; Ok(self) } @@ -246,28 +311,6 @@ impl PostgresBackend { CopyDataWriter { pgb: self } } - /// A polling function that tries to write all the data from 'buf_out' to the - /// underlying stream. - fn poll_write_buf( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - while self.buf_out.has_remaining() { - match Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk()) { - Poll::Ready(Ok(bytes_written)) => { - self.buf_out.advance(bytes_written); - } - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - } - } - Poll::Ready(Ok(())) - } - - fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - Pin::new(&mut self.stream).poll_flush(cx) - } - // Wrapper for run_message_loop() that shuts down socket when we are done pub async fn run( mut self, @@ -279,7 +322,7 @@ impl PostgresBackend { S: Future, { let ret = self.run_message_loop(handler, shutdown_watcher).await; - let _ = self.stream.shutdown(); + let _ = self.stream.get_mut().shutdown(); ret } @@ -359,14 +402,22 @@ impl PostgresBackend { } async fn start_tls(&mut self) -> anyhow::Result<()> { - if let Stream::Unencrypted(plain_stream) = - std::mem::replace(&mut self.stream, Stream::Broken) + if let MaybeTlsStream::Unencrypted(plain_stream) = + // temporary replace stream with fake broken to prepare TLS one + std::mem::replace(self.stream.get_mut(), MaybeTlsStream::Broken) { let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap()); - let tls_stream = acceptor.accept(plain_stream).await?; - - self.stream = 
Stream::Tls(Box::new(tls_stream)); - return Ok(()); + match acceptor.accept(plain_stream).await { + Ok(tls_stream) => { + // push back ready TLS stream + *self.stream.get_mut() = MaybeTlsStream::Tls(Box::new(tls_stream)); + return Ok(()); + } + Err(e) => { + self.state = ProtoState::Closed; + return Err(e.into()); + } + } }; anyhow::bail!("TLS already started"); } @@ -380,13 +431,12 @@ impl PostgresBackend { let have_tls = self.tls_config.is_some(); match msg { FeMessage::StartupPacket(m) => { - trace!("got startup message {m:?}"); - match m { FeStartupPacket::SslRequest => { debug!("SSL requested"); self.write_message(&BeMessage::EncryptionResponse(have_tls))?; + if have_tls { self.start_tls().await?; self.state = ProtoState::Encrypted; @@ -415,6 +465,7 @@ impl PostgresBackend { AuthType::Trust => { self.write_message(&BeMessage::AuthenticationOk)? .write_message(&BeMessage::CLIENT_ENCODING)? + .write_message(&BeMessage::INTEGER_DATETIMES)? // The async python driver requires a valid server_version .write_message(&BeMessage::server_version("14.1"))? .write_message(&BeMessage::ReadyForQuery)?; @@ -454,6 +505,7 @@ impl PostgresBackend { } self.write_message(&BeMessage::AuthenticationOk)? .write_message(&BeMessage::CLIENT_ENCODING)? + .write_message(&BeMessage::INTEGER_DATETIMES)? .write_message(&BeMessage::ReadyForQuery)?; self.state = ProtoState::Established; } @@ -573,7 +625,7 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> { // It's not strictly required to flush between each message, but makes it easier // to view in wireshark, and usually the messages that the callers write are // decently-sized anyway. - match this.pgb.poll_write_buf(cx) { + match this.pgb.poll_flush(cx) { Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, @@ -583,7 +635,11 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> { // XXX: if the input is large, we should split it into multiple messages. 
// Not sure what the threshold should be, but the ultimate hard limit is that // the length cannot exceed u32. - this.pgb.write_message(&BeMessage::CopyData(buf))?; + this.pgb + .write_message(&BeMessage::CopyData(buf)) + // write_message only writes to buffer, so can fail iff message is + // invaid, but CopyData can't be invalid. + .expect("failed to serialize CopyData"); Poll::Ready(Ok(buf.len())) } @@ -593,23 +649,14 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> { cx: &mut std::task::Context<'_>, ) -> Poll> { let this = self.get_mut(); - match this.pgb.poll_write_buf(cx) { - Poll::Ready(Ok(())) => {} - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - } this.pgb.poll_flush(cx) } + fn poll_shutdown( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { let this = self.get_mut(); - match this.pgb.poll_write_buf(cx) { - Poll::Ready(Ok(())) => {} - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - } this.pgb.poll_flush(cx) } } @@ -623,7 +670,7 @@ pub fn short_error(e: &QueryError) -> String { pub(super) fn log_query_error(query: &str, e: &QueryError) { match e { - QueryError::Disconnected(ConnectionError::Socket(io_error)) => { + QueryError::Disconnected(ConnectionError::Io(io_error)) => { if is_expected_io_error(io_error) { info!("query handler for '{query}' failed with expected io error: {io_error}"); } else { diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 18ec1ac68b..3f7a4c01fd 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -24,7 +24,7 @@ use pageserver::{ use utils::{ auth::JwtAuth, logging, - postgres_backend::AuthType, + postgres_backend_async::AuthType, project_git_version, sentry_init::{init_sentry, release_name}, signals::{self, Signal}, diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 7b99d98581..342abb7bbc 100644 --- a/pageserver/src/config.rs 
+++ b/pageserver/src/config.rs @@ -24,7 +24,7 @@ use toml_edit::{Document, Item}; use utils::{ id::{NodeId, TenantId, TimelineId}, logging::LogFormat, - postgres_backend::AuthType, + postgres_backend_async::AuthType, }; use crate::tenant::config::TenantConf; diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index b266a07337..a24532d28c 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -19,7 +19,7 @@ use pageserver_api::models::{ PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamNblocksRequest, PagestreamNblocksResponse, }; -use pq_proto::ConnectionError; +use pq_proto::codec::ConnectionError; use pq_proto::FeStartupPacket; use pq_proto::{BeMessage, FeMessage, RowDescriptor}; use std::io; @@ -35,7 +35,7 @@ use utils::{ auth::{Claims, JwtAuth, Scope}, id::{TenantId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, + postgres_backend_async::AuthType, postgres_backend_async::{self, PostgresBackend}, simple_rcu::RcuReadGuard, }; @@ -67,7 +67,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { msg } + msg = pgb.read_message() => { msg.map_err(QueryError::from)} }; match msg { @@ -78,14 +78,16 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream continue, FeMessage::Terminate => { let msg = "client terminated connection with Terminate message during COPY"; - let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); - pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; + let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))) + .expect("failed to serialize ErrorResponse"); Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; break; } m => { let msg = format!("unexpected message {m:?}"); - 
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))?; + pgb.write_message(&BeMessage::ErrorResponse(&msg, None)) + .expect("failed to serialize ErrorResponse"); Err(io::Error::new(io::ErrorKind::Other, msg))?; break; } @@ -95,16 +97,17 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { let msg = "client closed connection during COPY"; - let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); - pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; + let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))) + .expect("failed to serialize ErrorResponse"); pgb.flush().await?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; } - Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => { + Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => { Err(io_error)?; } Err(other) => { - Err(io::Error::new(io::ErrorKind::Other, other))?; + Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?; } }; } @@ -202,7 +205,7 @@ async fn page_service_conn_main( // we've been requested to shut down Ok(()) } - Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => { + Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => { // `ConnectionReset` error happens when the Postgres client closes the connection. // As this disconnection happens quite often and is expected, // we decided to downgrade the logging level to `INFO`. 
diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 02a0fabe9a..d02a310b9c 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -2,7 +2,7 @@ use crate::error::UserFacingError; use anyhow::bail; use bytes::BytesMut; use pin_project_lite::pin_project; -use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket}; +use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; use std::pin::Pin; use std::sync::Arc; @@ -53,7 +53,7 @@ impl PqStream { // TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket` let msg = FeStartupPacket::read_fut(&mut self.stream) .await - .map_err(ConnectionError::into_io_error)? + .map_err(ProtocolError::into_io_error)? .ok_or_else(err_connection)?; match msg { @@ -75,7 +75,7 @@ impl PqStream { async fn read_message(&mut self) -> io::Result { FeMessage::read_fut(&mut self.stream) .await - .map_err(ConnectionError::into_io_error)? + .map_err(ProtocolError::into_io_error)? .ok_or_else(err_connection) } } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 7ee14a8f41..33e5af1160 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -4,7 +4,7 @@ # version, we can consider updating. # See https://tracker.debian.org/pkg/rustc for more details on Debian rustc package, # we use "unstable" version number as the highest version used in the project by default. -channel = "1.62.1" +channel = "1.66.1" profile = "default" # The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy. 
# https://rust-lang.github.io/rustup/concepts/profiles.html diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index d0c804fe4e..202aa6dac9 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -14,12 +14,14 @@ clap = { version = "4.0", features = ["derive"] } const_format = "0.2.21" crc32c = "0.6.0" fs2 = "0.4.3" +futures = "0.3" git-version = "0.3.5" hex = "0.4.3" humantime = "2.1.0" hyper = "0.14" nix = "0.25" once_cell = "1.13.0" +pin-project-lite = "0.2" parking_lot = "0.12.1" postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" } postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" } diff --git a/safekeeper/src/bin/safekeeper.rs b/safekeeper/src/bin/safekeeper.rs index b130ea86bd..2940ca99dd 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ b/safekeeper/src/bin/safekeeper.rs @@ -228,20 +228,20 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> { let conf_cloned = conf.clone(); let safekeeper_thread = thread::Builder::new() - .name("safekeeper thread".into()) + .name("WAL service thread".into()) .spawn(|| wal_service::thread_main(conf_cloned, pg_listener)) .unwrap(); threads.push(safekeeper_thread); let conf_ = conf.clone(); - threads.push( - thread::Builder::new() - .name("broker thread".into()) - .spawn(|| { - broker::thread_main(conf_); - })?, - ); + // threads.push( + // thread::Builder::new() + // .name("broker thread".into()) + // .spawn(|| { + // broker::thread_main(conf_); + // })?, + // ); let conf_ = conf.clone(); threads.push( diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index 60df5dd372..edfa911abb 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -1,27 +1,23 @@ //! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres //! protocol commands. 
+use anyhow::Context; +use std::str; +use tracing::{info, info_span, Instrument}; + use crate::auth::check_permission; use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage}; use crate::receive_wal::ReceiveWalConn; - -use crate::send_wal::ReplicationConn; - use crate::{GlobalTimelines, SafeKeeperConf}; -use anyhow::Context; - use postgres_ffi::PG_TLI; -use regex::Regex; - use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID}; -use std::str; -use tracing::info; +use regex::Regex; use utils::auth::{Claims, Scope}; use utils::postgres_backend_async::QueryError; use utils::{ id::{TenantId, TenantTimelineId, TimelineId}, lsn::Lsn, - postgres_backend::{self, PostgresBackend}, + postgres_backend_async::{self, PostgresBackend}, }; /// Safekeeper handler of postgres commands @@ -67,7 +63,8 @@ fn parse_cmd(cmd: &str) -> anyhow::Result { } } -impl postgres_backend::Handler for SafekeeperPostgresHandler { +#[async_trait::async_trait] +impl postgres_backend_async::Handler for SafekeeperPostgresHandler { // tenant_id and timeline_id are passed in connection string params fn startup( &mut self, @@ -137,7 +134,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { Ok(()) } - fn process_query( + async fn process_query( &mut self, pgb: &mut PostgresBackend, query_string: &str, @@ -147,9 +144,14 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { .starts_with("set datestyle to ") { // important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect - pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; + pgb.write_message_flush(&BeMessage::CommandComplete(b"SELECT 1")) + .await?; return Ok(()); } + info!( + "got unparsed query {:?} in timeline {:?}", + query_string, self.timeline_id + ); let cmd = parse_cmd(query_string)?; info!( @@ -161,14 +163,20 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { let timeline_id = self.timeline_id.context("timelineid is required")?; 
self.check_permission(Some(tenant_id))?; self.ttid = TenantTimelineId::new(tenant_id, timeline_id); + let span_ttid = self.ttid; // satisfy borrow checker let res = match cmd { - SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self), + // SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self), + SafekeeperPostgresCommand::StartWalPush => Ok(()), SafekeeperPostgresCommand::StartReplication { start_lsn } => { - ReplicationConn::new(pgb).run(self, pgb, start_lsn) + self.handle_start_replication(pgb, start_lsn) + .instrument(info_span!("WAL sender", ttid = %span_ttid)) + .await + } + SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await, + SafekeeperPostgresCommand::JSONCtrl { ref cmd } => { + handle_json_ctrl(self, pgb, cmd).await } - SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb), - SafekeeperPostgresCommand::JSONCtrl { ref cmd } => handle_json_ctrl(self, pgb, cmd), }; match res { @@ -217,7 +225,10 @@ impl SafekeeperPostgresHandler { /// /// Handle IDENTIFY_SYSTEM replication command /// - fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<(), QueryError> { + async fn handle_identify_system( + &mut self, + pgb: &mut PostgresBackend, + ) -> Result<(), QueryError> { let tli = GlobalTimelines::get(self.ttid)?; let lsn = if self.is_walproposer_recovery() { @@ -235,7 +246,7 @@ impl SafekeeperPostgresHandler { let tli_bytes = tli.as_bytes(); let sysid_bytes = sysid.as_bytes(); - pgb.write_message_noflush(&BeMessage::RowDescription(&[ + pgb.write_message(&BeMessage::RowDescription(&[ RowDescriptor { name: b"systemid", typoid: TEXT_OID, @@ -261,13 +272,14 @@ impl SafekeeperPostgresHandler { ..Default::default() }, ]))? - .write_message_noflush(&BeMessage::DataRow(&[ + .write_message(&BeMessage::DataRow(&[ Some(sysid_bytes), Some(tli_bytes), Some(lsn_bytes), None, ]))? 
- .write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?; + .write_message_flush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM")) + .await?; Ok(()) } diff --git a/safekeeper/src/json_ctrl.rs b/safekeeper/src/json_ctrl.rs index 32a24a4978..f80ab5954b 100644 --- a/safekeeper/src/json_ctrl.rs +++ b/safekeeper/src/json_ctrl.rs @@ -26,7 +26,7 @@ use crate::GlobalTimelines; use postgres_ffi::encode_logical_message; use postgres_ffi::WAL_SEGMENT_SIZE; use pq_proto::{BeMessage, RowDescriptor, TEXT_OID}; -use utils::{lsn::Lsn, postgres_backend::PostgresBackend}; +use utils::{lsn::Lsn, postgres_backend_async::PostgresBackend}; #[derive(Serialize, Deserialize, Debug)] pub struct AppendLogicalMessage { @@ -59,7 +59,7 @@ struct AppendResult { /// Handles command to craft logical message WAL record with given /// content, and then append it with specified term and lsn. This /// function is used to test safekeepers in different scenarios. -pub fn handle_json_ctrl( +pub async fn handle_json_ctrl( spg: &SafekeeperPostgresHandler, pgb: &mut PostgresBackend, append_request: &AppendLogicalMessage, @@ -82,14 +82,15 @@ pub fn handle_json_ctrl( let response_data = serde_json::to_vec(&response) .with_context(|| format!("Response {response:?} is not a json array"))?; - pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor { + pgb.write_message(&BeMessage::RowDescription(&[RowDescriptor { name: b"json", typoid: TEXT_OID, typlen: -1, ..Default::default() }]))? - .write_message_noflush(&BeMessage::DataRow(&[Some(&response_data)]))? - .write_message(&BeMessage::CommandComplete(b"JSON_CTRL"))?; + .write_message(&BeMessage::DataRow(&[Some(&response_data)]))? 
+ .write_message_flush(&BeMessage::CommandComplete(b"JSON_CTRL")) + .await?; Ok(()) } diff --git a/safekeeper/src/receive_wal.rs b/safekeeper/src/receive_wal.rs index 671e5470a0..b96aa53147 100644 --- a/safekeeper/src/receive_wal.rs +++ b/safekeeper/src/receive_wal.rs @@ -26,7 +26,7 @@ use crate::safekeeper::ProposerAcceptorMessage; use crate::handler::SafekeeperPostgresHandler; use pq_proto::{BeMessage, FeMessage}; -use utils::{postgres_backend::PostgresBackend, sock_split::ReadStream}; +use utils::{postgres_backend_async::PostgresBackend, sock_split::ReadStream}; pub struct ReceiveWalConn<'pg> { /// Postgres connection @@ -59,82 +59,83 @@ impl<'pg> ReceiveWalConn<'pg> { // Notify the libpq client that it's allowed to send `CopyData` messages self.pg_backend .write_message(&BeMessage::CopyBothResponse)?; - - let r = self - .pg_backend - .take_stream_in() - .ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?; - let mut poll_reader = ProposerPollStream::new(r)?; + Ok(()) + // let r = self + // .pg_backend + // .take_stream_in() + // .ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?; + // let mut poll_reader = ProposerPollStream::new(r)?; // Receive information about server - let next_msg = poll_reader.recv_msg()?; - let tli = match next_msg { - ProposerAcceptorMessage::Greeting(ref greeting) => { - info!( - "start handshake with walproposer {} sysid {} timeline {}", - self.peer_addr, greeting.system_id, greeting.tli, - ); - let server_info = ServerInfo { - pg_version: greeting.pg_version, - system_id: greeting.system_id, - wal_seg_size: greeting.wal_seg_size, - }; - GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)? 
- } - _ => { - return Err(QueryError::Other(anyhow::anyhow!( - "unexpected message {next_msg:?} instead of greeting" - ))) - } - }; + // let next_msg = poll_reader.recv_msg()?; + // let tli = match next_msg { + // ProposerAcceptorMessage::Greeting(ref greeting) => { + // info!( + // "start handshake with walproposer {} sysid {} timeline {}", + // self.peer_addr, greeting.system_id, greeting.tli, + // ); + // let server_info = ServerInfo { + // pg_version: greeting.pg_version, + // system_id: greeting.system_id, + // wal_seg_size: greeting.wal_seg_size, + // }; + // GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)? + // } + // _ => { + // return Err(QueryError::Other(anyhow::anyhow!( + // "unexpected message {next_msg:?} instead of greeting" + // ))) + // } + // }; - let mut next_msg = Some(next_msg); + // let mut next_msg = None; - let mut first_time_through = true; - let mut _guard: Option = None; - loop { - if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) { - // poll AppendRequest's without blocking and write WAL to disk without flushing, - // while it's readily available - while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg { - let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request); + // let mut first_time_through = true; + // let mut _guard: Option = None; + // loop { + // if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) { + // // poll AppendRequest's without blocking and write WAL to disk without flushing, + // // while it's readily available + // while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg { + // let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request); - let reply = tli.process_msg(&msg)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } + // let reply = tli.process_msg(&msg)?; + // if let Some(reply) = reply { + // self.write_msg(&reply)?; + // } - next_msg = 
poll_reader.poll_msg(); - } + // // next_msg = poll_reader.poll_msg(); + // next_msg = poll_reader.poll_msg(); + // } - // flush all written WAL to the disk - let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } - } else if let Some(msg) = next_msg.take() { - // process other message - let reply = tli.process_msg(&msg)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } - } - if first_time_through { - // Register the connection and defer unregister. Do that only - // after processing first message, as it sets wal_seg_size, - // wanted by many. - tli.on_compute_connect()?; - _guard = Some(ComputeConnectionGuard { - timeline: Arc::clone(&tli), - }); - first_time_through = false; - } + // // flush all written WAL to the disk + // let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?; + // if let Some(reply) = reply { + // self.write_msg(&reply)?; + // } + // } else if let Some(msg) = next_msg.take() { + // // process other message + // let reply = tli.process_msg(&msg)?; + // if let Some(reply) = reply { + // self.write_msg(&reply)?; + // } + // } + // if first_time_through { + // // Register the connection and defer unregister. Do that only + // // after processing first message, as it sets wal_seg_size, + // // wanted by many. 
+ // tli.on_compute_connect()?; + // _guard = Some(ComputeConnectionGuard { + // timeline: Arc::clone(&tli), + // }); + // first_time_through = false; + // } - // blocking wait for the next message - if next_msg.is_none() { - next_msg = Some(poll_reader.recv_msg()?); - } - } + // // blocking wait for the next message + // if next_msg.is_none() { + // next_msg = Some(poll_reader.recv_msg()?); + // } + // } } } @@ -144,37 +145,37 @@ struct ProposerPollStream { } impl ProposerPollStream { - fn new(mut r: ReadStream) -> anyhow::Result { - let (msg_tx, msg_rx) = channel(); + // fn new(mut r: ReadStream) -> anyhow::Result { + // let (msg_tx, msg_rx) = channel(); - let read_thread = thread::Builder::new() - .name("Read WAL thread".into()) - .spawn(move || -> Result<(), QueryError> { - loop { - let copy_data = match FeMessage::read(&mut r)? { - Some(FeMessage::CopyData(bytes)) => Ok(bytes), - Some(msg) => Err(QueryError::Other(anyhow::anyhow!( - "expected `CopyData` message, found {msg:?}" - ))), - None => Err(QueryError::from(std::io::Error::new( - std::io::ErrorKind::ConnectionAborted, - "walproposer closed the connection", - ))), - }?; + // let read_thread = thread::Builder::new() + // .name("Read WAL thread".into()) + // .spawn(move || -> Result<(), QueryError> { + // loop { + // let copy_data = match FeMessage::read(&mut r)? 
{ + // Some(FeMessage::CopyData(bytes)) => Ok(bytes), + // Some(msg) => Err(QueryError::Other(anyhow::anyhow!( + // "expected `CopyData` message, found {msg:?}" + // ))), + // None => Err(QueryError::from(std::io::Error::new( + // std::io::ErrorKind::ConnectionAborted, + // "walproposer closed the connection", + // ))), + // }?; - let msg = ProposerAcceptorMessage::parse(copy_data)?; - msg_tx - .send(msg) - .context("Failed to send the proposer message")?; - } - // msg_tx will be dropped here, this will also close msg_rx - })?; + // let msg = ProposerAcceptorMessage::parse(copy_data)?; + // msg_tx + // .send(msg) + // .context("Failed to send the proposer message")?; + // } + // // msg_tx will be dropped here, this will also close msg_rx + // })?; - Ok(Self { - msg_rx, - read_thread: Some(read_thread), - }) - } + // Ok(Self { + // msg_rx, + // read_thread: Some(read_thread), + // }) + // } fn recv_msg(&mut self) -> Result { self.msg_rx.recv().map_err(|_| { diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 20600ab694..bf46139e22 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -1,28 +1,35 @@ //! This module implements the streaming side of replication protocol, starting //! with the "START_REPLICATION" message. 
+use anyhow::Context as AnyhowContext; +use bytes::Bytes; +use futures::future::BoxFuture; +use postgres_ffi::get_current_timestamp; +use postgres_ffi::{TimestampTz, MAX_SEND_SIZE}; +use serde::{Deserialize, Serialize}; +use std::cell::RefCell; +use std::cmp::min; +use std::future::Future; +use std::pin::Pin; +use std::rc::Rc; +use std::sync::Arc; +use std::task::{ready, Context, Poll}; +use std::time::Duration; +use std::{io, str, thread}; +use tokio::sync::watch::Receiver; +use tokio::time::timeout; +use tracing::*; +use utils::postgres_backend_async::QueryError; +use utils::send_rc::RefCellSend; +use utils::send_rc::SendRc; + +use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody}; +use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend_async::PostgresBackend}; + use crate::handler::SafekeeperPostgresHandler; use crate::timeline::{ReplicaState, Timeline}; use crate::wal_storage::WalReader; use crate::GlobalTimelines; -use anyhow::Context; - -use bytes::Bytes; -use postgres_ffi::get_current_timestamp; -use postgres_ffi::{TimestampTz, MAX_SEND_SIZE}; -use serde::{Deserialize, Serialize}; -use std::cmp::min; -use std::net::Shutdown; -use std::sync::Arc; -use std::time::Duration; -use std::{io, str, thread}; -use utils::postgres_backend_async::QueryError; - -use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody}; -use tokio::sync::watch::Receiver; -use tokio::time::timeout; -use tracing::*; -use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend::PostgresBackend, sock_split::ReadStream}; // See: https://www.postgresql.org/docs/13/protocol-replication.html const HOT_STANDBY_FEEDBACK_TAG_BYTE: u8 = b'h'; @@ -60,13 +67,6 @@ pub struct StandbyReply { pub reply_requested: bool, } -/// A network connection that's speaking the replication protocol. -pub struct ReplicationConn { - /// This is an `Option` because we will spawn a background thread that will - /// `take` it from us. 
- stream_in: Option, -} - /// Scope guard to unregister replication connection from timeline struct ReplicationConnGuard { replica: usize, // replica internal ID assigned by timeline @@ -79,230 +79,330 @@ impl Drop for ReplicationConnGuard { } } -impl ReplicationConn { - /// Create a new `ReplicationConn` - pub fn new(pgb: &mut PostgresBackend) -> Self { - Self { - stream_in: pgb.take_stream_in(), - } - } - - /// Handle incoming messages from the network. - /// This is spawned into the background by `handle_start_replication`. - fn background_thread( - mut stream_in: ReadStream, - replica_guard: Arc, - ) -> anyhow::Result<()> { - let replica_id = replica_guard.replica; - let timeline = &replica_guard.timeline; - - let mut state = ReplicaState::new(); - // Wait for replica's feedback. - while let Some(msg) = FeMessage::read(&mut stream_in)? { - match &msg { - FeMessage::CopyData(m) => { - // There's three possible data messages that the client is supposed to send here: - // `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`. - - match m.first().cloned() { - Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => { - // Note: deserializing is on m[1..] because we skip the tag byte. - state.hs_feedback = HotStandbyFeedback::des(&m[1..]) - .context("failed to deserialize HotStandbyFeedback")?; - timeline.update_replica_state(replica_id, state); - } - Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => { - let _reply = StandbyReply::des(&m[1..]) - .context("failed to deserialize StandbyReply")?; - // This must be a regular postgres replica, - // because pageserver doesn't send this type of messages to safekeeper. - // Currently this is not implemented, so this message is ignored. - - warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet."); - // timeline.update_replica_state(replica_id, Some(state)); - } - Some(NEON_STATUS_UPDATE_TAG_BYTE) => { - // Note: deserializing is on m[9..] because we skip the tag byte and len bytes. 
- let buf = Bytes::copy_from_slice(&m[9..]); - let reply = ReplicationFeedback::parse(buf); - - trace!("ReplicationFeedback is {:?}", reply); - // Only pageserver sends ReplicationFeedback, so set the flag. - // This replica is the source of information to resend to compute. - state.pageserver_feedback = Some(reply); - - timeline.update_replica_state(replica_id, state); - } - _ => warn!("unexpected message {:?}", msg), - } - } - FeMessage::Sync => {} - FeMessage::CopyFail => { - // Shutdown the connection, because rust-postgres client cannot be dropped - // when connection is alive. - let _ = stream_in.shutdown(Shutdown::Both); - anyhow::bail!("Copy failed"); - } - _ => { - // We only handle `CopyData`, 'Sync', 'CopyFail' messages. Anything else is ignored. - info!("unexpected message {:?}", msg); - } - } - } - - Ok(()) - } - - /// - /// Handle START_REPLICATION replication command - /// - pub fn run( +impl SafekeeperPostgresHandler { + pub async fn handle_start_replication( &mut self, - spg: &mut SafekeeperPostgresHandler, pgb: &mut PostgresBackend, - mut start_pos: Lsn, + start_pos: Lsn, ) -> Result<(), QueryError> { - let _enter = info_span!("WAL sender", ttid = %spg.ttid).entered(); - - let tli = GlobalTimelines::get(spg.ttid)?; - - // spawn the background thread which receives HotStandbyFeedback messages. - let bg_timeline = Arc::clone(&tli); - let bg_stream_in = self.stream_in.take().unwrap(); - let bg_timeline_id = spg.timeline_id.unwrap(); + let appname = self.appname.clone(); + let tli = GlobalTimelines::get(self.ttid)?; let state = ReplicaState::new(); // This replica_id is used below to check if it's time to stop replication. - let replica_id = bg_timeline.add_replica(state); + let replica_id = tli.add_replica(state); // Use a guard object to remove our entry from the timeline, when the background // thread and us have both finished using it. 
- let replica_guard = Arc::new(ReplicationConnGuard { + let _guard = Arc::new(ReplicationConnGuard { replica: replica_id, - timeline: bg_timeline, + timeline: tli.clone(), }); - let bg_replica_guard = Arc::clone(&replica_guard); - // TODO: here we got two threads, one for writing WAL and one for receiving - // feedback. If one of them fails, we should shutdown the other one too. - let _ = thread::Builder::new() - .name("HotStandbyFeedback thread".into()) - .spawn(move || { - let _enter = - info_span!("HotStandbyFeedback thread", timeline = %bg_timeline_id).entered(); - if let Err(err) = Self::background_thread(bg_stream_in, bg_replica_guard) { - error!("Replication background thread failed: {}", err); + // Walproposer gets special handling: safekeeper must give proposer all + // local WAL till the end, whether committed or not (walproposer will + // hang otherwise). That's because walproposer runs the consensus and + // synchronizes safekeepers on the most advanced one. + // + // There is a small risk of this WAL getting concurrently garbaged if + // another compute rises which collects majority and starts fixing log + // on this safekeeper itself. That's ok as (old) proposer will never be + // able to commit such WAL. 
+ let stop_pos: Option = if self.is_walproposer_recovery() { + let wal_end = tli.get_flush_lsn(); + Some(wal_end) + } else { + None + }; + let end_pos = stop_pos.unwrap_or(Lsn::INVALID); + + info!( + "starting streaming from {:?} till {:?}", + start_pos, stop_pos + ); + + // switch to copy + pgb.write_message(&BeMessage::CopyBothResponse)?; + + let (_, persisted_state) = tli.get_state(); + let wal_reader = WalReader::new( + self.conf.workdir.clone(), + self.conf.timeline_dir(&tli.ttid), + &persisted_state, + start_pos, + self.conf.wal_backup_enabled, + )?; + let write_ctx = SendRc::new(WriteContext { + wal_reader: RefCell::new(wal_reader), + send_buf: RefCell::new([0; MAX_SEND_SIZE]), + }); + + ReplicationHandler { + tli, + replica_id, + appname, + pgb, + start_pos, + end_pos, + stop_pos, + write_ctx, + // Actually we start from reading WAL, but this way is easier to + // code, we'll just immediately switch. + write_state: WriteState::FlushWal, + feedback: ReplicaState::new(), + } + .await + } +} + +/// START_REPLICATION stream driver: sends WAL and receives feedback. +struct ReplicationHandler<'a> { + tli: Arc, + appname: Option, + replica_id: usize, + pgb: &'a mut PostgresBackend, + // Position since which we are sending next chunk. + start_pos: Lsn, + // WAL up to this position is known to be locally available. + end_pos: Lsn, + // If present, terminate after reaching this position; used by walproposer + // in recovery. + stop_pos: Option, + // This data is needed to create Future sending WAL, so we need to both + // have it here (to create new future) and borrow it to the future + // itself. Essentially this is a self referential struct. To satisfy + // borrow checker, use Rc. To make ReplicationHandler itself + // Send'able future, wrap it into SendRc; this is safe as + // ReplicationHandler is passed between threads only as a whole (during + // rescheduling). 
+ // + // Right now we're in CurrentThread runtime, so Send is somewhat + // redundant; however, we'd need to inconveniently have separate !Send + // version of pg backend Handler trait (and work with LocalSet). + write_ctx: SendRc, + write_state: WriteState, + feedback: ReplicaState, +} + +// State which ReplicationHandler needs to create futures sending data. +struct WriteContext { + wal_reader: RefCell, + // buffer for readling WAL into to send it + send_buf: RefCell<[u8; MAX_SEND_SIZE]>, +} + +// Yield points of WAL sending machinery. +enum WriteState { + // TODO: see if we can remove boxing here; with anon type of async fn this + // is untrivial (+ needs fiddling with pinning, pin_project and replace). + WaitWal(BoxFuture<'static, anyhow::Result>>), + ReadWal(BoxFuture<'static, anyhow::Result>), + FlushWal, +} + +impl Future for ReplicationHandler<'_> { + type Output = Result<(), QueryError>; + + // We need to read feedback from the socket and write data there at the same + // time. To avoid having to split socket, which creates messy split-join + // APIs, is problematic with TLS [1] and needs to manage two tasks, just run + // single task and use poll interfaces, basically manual state machine, + // which is simple here. + // + // [1] https://github.com/tokio-rs/tls/issues/40 + // + // Completes only when the stream is over, technically on error currently. + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Poll::Ready(r) = self.as_mut().poll_read(cx) { + return Poll::Ready(r); + } + self.as_mut().poll_write(cx) + } +} + +impl ReplicationHandler<'_> { + // Poll reading, i.e. getting feedback and processing it. Completes only on error/end of stream. 
+ fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.pgb.poll_read_message(cx)) { + Ok(Some(msg)) => self.as_mut().handle_feedback(&msg)?, + Ok(None) => { + return Poll::Ready(Err(QueryError::Other(anyhow::anyhow!( + "EOF on replication stream" + )))) } - })?; - - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; - - runtime.block_on(async move { - let (inmem_state, persisted_state) = tli.get_state(); - // add persisted_state.timeline_start_lsn == Lsn(0) check - - // Walproposer gets special handling: safekeeper must give proposer all - // local WAL till the end, whether committed or not (walproposer will - // hang otherwise). That's because walproposer runs the consensus and - // synchronizes safekeepers on the most advanced one. - // - // There is a small risk of this WAL getting concurrently garbaged if - // another compute rises which collects majority and starts fixing log - // on this safekeeper itself. That's ok as (old) proposer will never be - // able to commit such WAL. 
- let stop_pos: Option = if spg.is_walproposer_recovery() { - let wal_end = tli.get_flush_lsn(); - Some(wal_end) - } else { - None + Err(err) => return Poll::Ready(Err(err.into())), }; + } + } - info!("Start replication from {:?} till {:?}", start_pos, stop_pos); - - // switch to copy - pgb.write_message(&BeMessage::CopyBothResponse)?; - - let mut end_pos = stop_pos.unwrap_or(inmem_state.commit_lsn); - - let mut wal_reader = WalReader::new( - spg.conf.workdir.clone(), - spg.conf.timeline_dir(&tli.ttid), - &persisted_state, - start_pos, - spg.conf.wal_backup_enabled, - )?; - - // buffer for wal sending, limited by MAX_SEND_SIZE - let mut send_buf = vec![0u8; MAX_SEND_SIZE]; - - // watcher for commit_lsn updates - let mut commit_lsn_watch_rx = tli.get_commit_lsn_watch_rx(); - - loop { - if let Some(stop_pos) = stop_pos { - if start_pos >= stop_pos { - break; /* recovery finished */ + fn handle_feedback(mut self: Pin<&mut Self>, msg: &FeMessage) -> Result<(), QueryError> { + match &msg { + FeMessage::CopyData(m) => { + // There's three possible data messages that the client is supposed to send here: + // `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`. + match m.first().cloned() { + Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => { + // Note: deserializing is on m[1..] because we skip the tag byte. + self.feedback.hs_feedback = HotStandbyFeedback::des(&m[1..]) + .context("failed to deserialize HotStandbyFeedback")?; + self.tli + .update_replica_state(self.replica_id, self.feedback); } - end_pos = stop_pos; - } else { - /* Wait until we have some data to stream */ - let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?; + Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => { + let _reply = StandbyReply::des(&m[1..]) + .context("failed to deserialize StandbyReply")?; + // This must be a regular postgres replica, + // because pageserver doesn't send this type of messages to safekeeper. + // Currently this is not implemented, so this message is ignored. 
- if let Some(lsn) = lsn { - end_pos = lsn; - } else { - // TODO: also check once in a while whether we are walsender - // to right pageserver. - if tli.should_walsender_stop(replica_id) { - // Shut down, timeline is suspended. - return Err(QueryError::from(io::Error::new( - io::ErrorKind::ConnectionAborted, - format!("end streaming to {:?}", spg.appname), - ))); - } + warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet."); + // timeline.update_replica_state(replica_id, Some(state)); + } + Some(NEON_STATUS_UPDATE_TAG_BYTE) => { + // Note: deserializing is on m[9..] because we skip the tag byte and len bytes. + let buf = Bytes::copy_from_slice(&m[9..]); + let reply = ReplicationFeedback::parse(buf); - // timeout expired: request pageserver status - pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive { - sent_ptr: end_pos.0, - timestamp: get_current_timestamp(), - request_reply: true, - }))?; + trace!("ReplicationFeedback is {:?}", reply); + // Only pageserver sends ReplicationFeedback, so set the flag. + // This replica is the source of information to resend to compute. + self.feedback.pageserver_feedback = Some(reply); + + self.tli + .update_replica_state(self.replica_id, self.feedback); + } + _ => warn!("unexpected message {:?}", msg), + } + } + FeMessage::CopyFail => { + // XXX we should probably (tell pgb to) close the socket, as + // CopyFail in duplex copy is somewhat unexpected (at least to + // PG walsender; evidently client should finish it with + // CopyDone). Note that sync rust-postgres client (which we + // don't use anymore) hangs otherwise. + // https://github.com/sfackler/rust-postgres/issues/755 + // https://github.com/neondatabase/neon/issues/935 + // + return Err(anyhow::anyhow!("unexpected CopyFail").into()); + } + _ => { + return Err( + anyhow::anyhow!("unexpected message {:?} in replication stream", msg).into(), + ); + } + }; + Ok(()) + } + + // Poll writing, i.e. sending more WAL. 
Completes only on error or when we + // decide to shutdown connection -- receiver is caughtup and there is no + // active computes; this is still handled as Err though. + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // send while we don't block or error out + loop { + match &mut self.write_state { + WriteState::WaitWal(fut) => match ready!(fut.as_mut().poll(cx))? { + Some(lsn) => { + self.end_pos = lsn; + self.as_mut().start_read_wal(); continue; } + // Timed out waiting for WAL, send keepalive and possibly terminate. + None => { + if self.tli.should_walsender_stop(self.replica_id) { + // Terminate if there is nothing more to send. + // TODO close the stream properly + return Poll::Ready(Err(anyhow::anyhow!(format!( + "ending streaming to {:?} at {}, receiver is caughtup and there is no computes", + self.appname, self.start_pos, + )).into())); + } + let end_pos = self.end_pos.0; + self.pgb + .write_message(&BeMessage::KeepAlive(WalSndKeepAlive { + sent_ptr: end_pos, + timestamp: get_current_timestamp(), + request_reply: true, + }))?; + self.write_state = WriteState::FlushWal; /* flush KA */ + } + }, + WriteState::ReadWal(fut) => { + let read_len = ready!(fut.as_mut().poll(cx))?; + assert!(read_len > 0, "read_len={}", read_len); + let write_ctx_clone = self.write_ctx.clone(); + let send_buf = &write_ctx_clone.send_buf.borrow()[..read_len]; + let (start_pos, end_pos) = (self.start_pos.0, self.end_pos.0); + // write data to the output buffer + self.pgb + .write_message(&BeMessage::XLogData(XLogDataBody { + wal_start: start_pos, + wal_end: end_pos, + timestamp: get_current_timestamp(), + data: send_buf, + })) + .context("Failed to write XLogData")?; + // and flush it + self.write_state = WriteState::FlushWal; + } + WriteState::FlushWal => { + ready!(self.pgb.poll_flush(cx))?; + // If we are streaming to walproposer, check it is time to stop. 
+ if let Some(stop_pos) = self.stop_pos { + if self.start_pos >= stop_pos { + // recovery finished + // TODO close the stream properly + return Poll::Ready(Err(anyhow::anyhow!(format!( + "ending streaming to walproposer at {}, receiver is caughtup and there is no computes", + self.start_pos)).into())); + } + self.as_mut().start_read_wal(); + continue; + } else { + // if we don't know next portion is already available, wait + // for it; otherwise proceed to sending + if self.end_pos <= self.start_pos { + self.as_mut().start_wait_wal(); + } else { + self.as_mut().start_read_wal(); + } + } } - - let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize; - let send_size = min(send_size, send_buf.len()); - - let send_buf = &mut send_buf[..send_size]; - - // read wal into buffer - let send_size = wal_reader.read(send_buf).await?; - let send_buf = &send_buf[..send_size]; - - // Write some data to the network socket. - pgb.write_message(&BeMessage::XLogData(XLogDataBody { - wal_start: start_pos.0, - wal_end: end_pos.0, - timestamp: get_current_timestamp(), - data: send_buf, - })) - .context("Failed to send XLogData")?; - - start_pos += send_size as u64; - trace!("sent WAL up to {}", start_pos); } + } + } - Ok(()) - }) + // Start waiting for WAL, creating future doing that. + fn start_wait_wal(mut self: Pin<&mut Self>) { + let mut commit_lsn_watch_rx = self.tli.get_commit_lsn_watch_rx(); + let start_pos = self.start_pos; + let wait_wal_fut = async move { wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await }; + self.write_state = WriteState::WaitWal(Box::pin(wait_wal_fut)); + } + + // Switch into reading WAL state, creating Future doing that. 
+ fn start_read_wal(mut self: Pin<&mut Self>) { + let mut send_size = self.end_pos.checked_sub(self.start_pos).unwrap().0 as usize; + send_size = min(send_size, self.write_ctx.send_buf.borrow().len()); + let write_ctx_fut = self.write_ctx.clone(); + let read_wal_fut = async move { + let mut wal_reader_ref = write_ctx_fut.wal_reader.borrow_mut_send(); + let mut send_buf_ref = write_ctx_fut.send_buf.borrow_mut_send(); + + let send_buf = &mut send_buf_ref[..send_size]; + wal_reader_ref.read(send_buf).await + }; + self.write_state = WriteState::ReadWal(Box::pin(read_wal_fut)); } } const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1); -// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn. +// Wait until we have commit_lsn > lsn or timeout expires. Returns +// - Ok(Some(commit_lsn)) if needed lsn is successfully observed; +// - Ok(None) if timeout expired; +// - Err in case of error (if watch channel is in trouble, shouldn't happen). async fn wait_for_lsn(rx: &mut Receiver, lsn: Lsn) -> anyhow::Result> { let commit_lsn: Lsn = *rx.borrow(); if commit_lsn > lsn { diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index fc971ca753..95262c15d5 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -346,7 +346,9 @@ impl WalBackupTask { backup_lsn, commit_lsn, e ); - retry_attempt = retry_attempt.saturating_add(1); + if retry_attempt < u32::MAX { + retry_attempt += 1; + } } } } @@ -385,7 +387,7 @@ async fn backup_single_segment( ) -> Result<()> { let segment_file_path = seg.file_path(timeline_dir)?; let remote_segment_path = segment_file_path - .strip_prefix(workspace_dir) + .strip_prefix(&workspace_dir) .context("Failed to strip workspace dir prefix") .and_then(RemotePath::new) .with_context(|| { @@ -467,7 +469,7 @@ async fn backup_object(source_file: &Path, target_file: &RemotePath, size: usize pub async fn read_object( file_path: &RemotePath, offset: u64, -) -> anyhow::Result>> { +) -> 
anyhow::Result>> { let storage = REMOTE_STORAGE .get() .context("Failed to get remote storage")? diff --git a/safekeeper/src/wal_service.rs b/safekeeper/src/wal_service.rs index 3ca651d060..e43f1d7cb2 100644 --- a/safekeeper/src/wal_service.rs +++ b/safekeeper/src/wal_service.rs @@ -2,36 +2,54 @@ //! WAL service listens for client connections and //! receive WAL from wal_proposer and send it to WAL receivers //! +use anyhow::{Context, Result}; use regex::Regex; -use std::net::{TcpListener, TcpStream}; -use std::thread; +use std::{future, thread}; +use tokio::net::TcpStream; use tracing::*; use utils::postgres_backend_async::QueryError; use crate::handler::SafekeeperPostgresHandler; use crate::SafeKeeperConf; -use utils::postgres_backend::{AuthType, PostgresBackend}; +use utils::postgres_backend_async::{AuthType, PostgresBackend}; /// Accept incoming TCP connections and spawn them into a background thread. -pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> ! { - loop { - match listener.accept() { - Ok((socket, peer_addr)) => { - debug!("accepted connection from {}", peer_addr); - let conf = conf.clone(); +pub fn thread_main(conf: SafeKeeperConf, pg_listener: std::net::TcpListener) { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .context("create runtime") + // todo catch error in main thread + .expect("failed to create runtime"); - let _ = thread::Builder::new() - .name("WAL service thread".into()) - .spawn(move || { - if let Err(err) = handle_socket(socket, conf) { - error!("connection handler exited: {}", err); - } - }) - .unwrap(); + runtime + .block_on(async move { + // Tokio's from_std won't do this for us, per its comment. 
+ pg_listener.set_nonblocking(true)?; + let listener = tokio::net::TcpListener::from_std(pg_listener)?; + + loop { + match listener.accept().await { + Ok((socket, peer_addr)) => { + debug!("accepted connection from {}", peer_addr); + let conf = conf.clone(); + + let _ = thread::Builder::new() + .name("WAL service thread".into()) + .spawn(move || { + if let Err(err) = handle_socket(socket, conf) { + error!("connection handler exited: {}", err); + } + }) + .unwrap(); + } + Err(e) => error!("Failed to accept connection: {}", e), + } } - Err(e) => error!("Failed to accept connection: {}", e), - } - } + #[allow(unreachable_code)] // hint compiler the closure return type + Ok::<(), anyhow::Error>(()) + }) + .expect("listener failed") } // Get unique thread id (Rust internal), with ThreadId removed for shorter printing @@ -44,9 +62,14 @@ fn get_tid() -> u64 { /// This is run by `thread_main` above, inside a background thread. /// -fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> { +fn handle_socket(mut socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> { let _enter = info_span!("", tid = ?get_tid()).entered(); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + let local = tokio::task::LocalSet::new(); + socket.set_nodelay(true)?; let auth_type = match conf.auth { @@ -54,9 +77,13 @@ fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryErr Some(_) => AuthType::NeonJWT, }; let mut conn_handler = SafekeeperPostgresHandler::new(conf); - let pgbackend = PostgresBackend::new(socket, auth_type, None, false)?; - // libpq replication protocol between safekeeper and replicas/pagers - pgbackend.run(&mut conn_handler)?; + let pgbackend = PostgresBackend::new(socket, auth_type, None)?; + // libpq protocol between safekeeper and walproposer / pageserver + // We don't use shutdown. 
+ local.block_on( + &runtime, + pgbackend.run(&mut conn_handler, || future::pending::<()>()), + )?; Ok(()) } diff --git a/safekeeper/src/wal_storage.rs b/safekeeper/src/wal_storage.rs index 41457868fe..ed2ee6bfc7 100644 --- a/safekeeper/src/wal_storage.rs +++ b/safekeeper/src/wal_storage.rs @@ -450,7 +450,7 @@ pub struct WalReader { timeline_dir: PathBuf, wal_seg_size: usize, pos: Lsn, - wal_segment: Option>>, + wal_segment: Option>>, // S3 will be used to read WAL if LSN is not available locally enable_remote_read: bool, @@ -491,6 +491,11 @@ impl WalReader { }) } + pub async fn fake_read(&mut self) -> Result { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + Ok(self.pos.0 as usize) + } + pub async fn read(&mut self, buf: &mut [u8]) -> Result { let mut wal_segment = match self.wal_segment.take() { Some(reader) => reader, @@ -517,7 +522,7 @@ impl WalReader { } /// Open WAL segment at the current position of the reader. - async fn open_segment(&self) -> Result>> { + async fn open_segment(&self) -> Result>> { let xlogoff = self.pos.segment_offset(self.wal_seg_size); let segno = self.pos.segment_number(self.wal_seg_size); let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size);