diff --git a/Cargo.lock b/Cargo.lock index ab2f69929e..e380e72dc0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -913,6 +913,7 @@ dependencies = [ "once_cell", "pageserver_api", "postgres", + "postgres_backend", "postgres_connection", "regex", "reqwest", @@ -2696,7 +2697,6 @@ dependencies = [ "tokio-postgres-rustls", "tokio-rustls", "tracing", - "utils", "workspace_hack", ] @@ -2922,6 +2922,7 @@ dependencies = [ "opentelemetry", "parking_lot", "pin-project-lite", + "postgres_backend", "pq_proto", "prometheus", "rand", @@ -3301,15 +3302,6 @@ dependencies = [ "base64 0.21.0", ] -[[package]] -name = "rustls-split" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78802c9612b4689d207acff746f38132ca1b12dadb55d471aa5f10fd580f47d3" -dependencies = [ - "rustls", -] - [[package]] name = "rustversion" version = "1.0.11" @@ -3346,6 +3338,7 @@ dependencies = [ "parking_lot", "postgres", "postgres-protocol", + "postgres_backend", "postgres_ffi", "pq_proto", "regex", @@ -4539,12 +4532,8 @@ dependencies = [ "metrics", "nix", "once_cell", - "pq_proto", "rand", "routerify", - "rustls", - "rustls-pemfile", - "rustls-split", "sentry", "serde", "serde_json", @@ -4858,14 +4847,19 @@ name = "workspace_hack" version = "0.1.0" dependencies = [ "anyhow", + "byteorder", "bytes", "chrono", "clap 4.1.4", "crossbeam-utils", + "digest", "either", "fail", "futures", + "futures-channel", + "futures-core", "futures-executor", + "futures-sink", "futures-util", "hashbrown 0.12.3", "indexmap", @@ -4890,6 +4884,7 @@ dependencies = [ "socket2", "syn", "tokio", + "tokio-rustls", "tokio-util", "tonic", "tower", diff --git a/control_plane/Cargo.toml b/control_plane/Cargo.toml index 309887e1fa..ba39747e03 100644 --- a/control_plane/Cargo.toml +++ b/control_plane/Cargo.toml @@ -24,6 +24,7 @@ url.workspace = true # Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api # instead, so that recompile times are better. 
pageserver_api.workspace = true +postgres_backend.workspace = true safekeeper_api.workspace = true postgres_connection.workspace = true storage_broker.workspace = true diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 4b2aa3c957..49b1d31dbc 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -17,6 +17,7 @@ use pageserver_api::{ DEFAULT_HTTP_LISTEN_ADDR as DEFAULT_PAGESERVER_HTTP_ADDR, DEFAULT_PG_LISTEN_ADDR as DEFAULT_PAGESERVER_PG_ADDR, }; +use postgres_backend::AuthType; use safekeeper_api::{ DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT, DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT, @@ -30,7 +31,6 @@ use utils::{ auth::{Claims, Scope}, id::{NodeId, TenantId, TenantTimelineId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, project_git_version, }; diff --git a/control_plane/src/compute.rs b/control_plane/src/compute.rs index 8731cf2583..b7029aabc5 100644 --- a/control_plane/src/compute.rs +++ b/control_plane/src/compute.rs @@ -11,10 +11,10 @@ use std::sync::Arc; use std::time::Duration; use anyhow::{Context, Result}; +use postgres_backend::AuthType; use utils::{ id::{TenantId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, }; use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION}; diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index 003152c578..09180d96c4 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -5,6 +5,7 @@ use anyhow::{bail, ensure, Context}; +use postgres_backend::AuthType; use reqwest::Url; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; @@ -19,7 +20,6 @@ use std::process::{Command, Stdio}; use utils::{ auth::{encode_from_key_file, Claims, Scope}, id::{NodeId, TenantId, TenantTimelineId, TimelineId}, - postgres_backend::AuthType, }; use crate::safekeeper::SafekeeperNode; diff --git a/control_plane/src/pageserver.rs 
b/control_plane/src/pageserver.rs index c49bd39f09..4b7180c250 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -11,6 +11,7 @@ use anyhow::{bail, Context}; use pageserver_api::models::{ TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo, }; +use postgres_backend::AuthType; use postgres_connection::{parse_host_port, PgConnectionConfig}; use reqwest::blocking::{Client, RequestBuilder, Response}; use reqwest::{IntoUrl, Method}; @@ -20,7 +21,6 @@ use utils::{ http::error::HttpErrorBody, id::{TenantId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, }; use crate::{background_process, local_env::LocalEnv}; diff --git a/libs/postgres_backend/Cargo.toml b/libs/postgres_backend/Cargo.toml index bead77c4d6..8e249c09f7 100644 --- a/libs/postgres_backend/Cargo.toml +++ b/libs/postgres_backend/Cargo.toml @@ -17,7 +17,6 @@ tokio-rustls.workspace = true tracing.workspace = true pq_proto.workspace = true -utils.workspace = true workspace_hack.workspace = true [dev-dependencies] diff --git a/libs/postgres_backend/src/lib.rs b/libs/postgres_backend/src/lib.rs index 6e96e65a52..ba28add9f9 100644 --- a/libs/postgres_backend/src/lib.rs +++ b/libs/postgres_backend/src/lib.rs @@ -2,29 +2,26 @@ //! To use, create PostgresBackend and run() it, passing the Handler //! implementation determining how to process the queries. Currently its API //! is rather narrow, but we can extend it once required. 
- use anyhow::Context; -use bytes::{Buf, Bytes, BytesMut}; -use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR}; -use std::io; +use bytes::Bytes; +use futures::pin_mut; +use serde::{Deserialize, Serialize}; +use std::io::ErrorKind; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; -use std::task::Poll; -use std::{future::Future, task::ready}; -use tracing::{debug, error, info, trace}; -use utils::postgres_backend::AuthType; - -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +use std::task::{ready, Poll}; +use std::{fmt, io}; +use std::{future::Future, str::FromStr}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; +use tracing::{debug, error, info, trace}; -pub fn is_expected_io_error(e: &io::Error) -> bool { - use io::ErrorKind::*; - matches!( - e.kind(), - ConnectionRefused | ConnectionAborted | ConnectionReset - ) -} +use pq_proto::framed::{Framed, FramedReader, FramedWriter}; +use pq_proto::{ + BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, + SQLSTATE_SUCCESSFUL_COMPLETION, +}; /// An error, occurred during query processing: /// either during the connection ([`ConnectionError`]) or before/after it. @@ -53,12 +50,20 @@ impl QueryError { } } +pub fn is_expected_io_error(e: &io::Error) -> bool { + use io::ErrorKind::*; + matches!( + e.kind(), + ConnectionRefused | ConnectionAborted | ConnectionReset + ) +} + #[async_trait::async_trait] pub trait Handler { /// Handle single query. /// postgres_backend will issue ReadyForQuery after calling this (this /// might be not what we want after CopyData streaming, but currently we don't - /// care). + /// care). It will also flush out the output buffer. async fn process_query( &mut self, pgb: &mut PostgresBackend, @@ -92,9 +97,13 @@ pub trait Handler { /// XXX: The order of the constructors matters. 
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)] pub enum ProtoState { + /// Nothing happened yet. Initialization, + /// Encryption handshake is done; waiting for encrypted Startup message. Encrypted, + /// Waiting for password (auth token). Authentication, + /// Performed handshake and auth, ReadyForQuery is issued. Established, Closed, } @@ -105,15 +114,13 @@ pub enum ProcessMsgResult { Break, } -/// Always-writeable sock_split stream. -/// May not be readable. See [`PostgresBackend::take_stream_in`] -pub enum Stream { - Unencrypted(BufReader), - Tls(Box>>), - Broken, +/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite. +pub enum MaybeTlsStream { + Unencrypted(tokio::net::TcpStream), + Tls(Box>), } -impl AsyncWrite for Stream { +impl AsyncWrite for MaybeTlsStream { fn poll_write( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -122,14 +129,12 @@ impl AsyncWrite for Stream { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf), Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf), - Self::Broken => unreachable!(), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx), Self::Tls(stream) => Pin::new(stream).poll_flush(cx), - Self::Broken => unreachable!(), } } fn poll_shutdown( @@ -139,11 +144,10 @@ impl AsyncWrite for Stream { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx), Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx), - Self::Broken => unreachable!(), } } } -impl AsyncRead for Stream { +impl AsyncRead for MaybeTlsStream { fn poll_read( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -152,18 +156,96 @@ impl AsyncRead for Stream { match self.get_mut() { Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf), Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf), - Self::Broken => unreachable!(), + } 
+ } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] +pub enum AuthType { + Trust, + // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT + NeonJWT, +} + +impl FromStr for AuthType { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s { + "Trust" => Ok(Self::Trust), + "NeonJWT" => Ok(Self::NeonJWT), + _ => anyhow::bail!("invalid value \"{s}\" for auth type"), + } + } +} + +impl fmt::Display for AuthType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + AuthType::Trust => "Trust", + AuthType::NeonJWT => "NeonJWT", + }) + } +} + +/// Either full duplex Framed or write only half; the latter is left in +/// PostgresBackend after call to `split`. In principle we could always store a +/// pair of split handles, but that would force us to pay the splitting price +/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver). +enum MaybeWriteOnly { + Full(Framed), + WriteOnly(FramedWriter), + Broken, // temporary value palmed off during the split +} + +impl MaybeWriteOnly { + async fn read_startup_message(&mut self) -> Result, ConnectionError> { + match self { + MaybeWriteOnly::Full(framed) => framed.read_startup_message().await, + MaybeWriteOnly::WriteOnly(_) => { + Err(io::Error::new(ErrorKind::Other, "reading from write only half").into()) + } + MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"), + } + } + + async fn read_message(&mut self) -> Result, ConnectionError> { + match self { + MaybeWriteOnly::Full(framed) => framed.read_message().await, + MaybeWriteOnly::WriteOnly(_) => { + Err(io::Error::new(ErrorKind::Other, "reading from write only half").into()) + } + MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"), + } + } + + fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> { + match self { + MaybeWriteOnly::Full(framed) => framed.write_message(msg), + 
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg), + MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"), + } + } + + async fn flush(&mut self) -> io::Result<()> { + match self { + MaybeWriteOnly::Full(framed) => framed.flush().await, + MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await, + MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"), + } + } + + async fn shutdown(&mut self) -> io::Result<()> { + match self { + MaybeWriteOnly::Full(framed) => framed.shutdown().await, + MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await, + MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"), } } } pub struct PostgresBackend { - stream: Stream, - - // Output buffer. c.f. BeMessage::write why we are using BytesMut here. - // The data between 0 and "current position" as tracked by the bytes::Buf - // implementation of BytesMut, have already been written. - buf_out: BytesMut, + framed: MaybeWriteOnly, pub state: ProtoState, @@ -183,7 +265,7 @@ pub fn query_from_cstring(query_string: Bytes) -> Vec { query_string } -// Cast a byte slice to a string slice, dropping null terminator if there's one. +/// Cast a byte slice to a string slice, dropping null terminator if there's one. fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> { let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes); std::str::from_utf8(without_null).map_err(|e| e.into()) @@ -196,10 +278,10 @@ impl PostgresBackend { tls_config: Option>, ) -> io::Result { let peer_addr = socket.peer_addr()?; + let stream = MaybeTlsStream::Unencrypted(socket); Ok(Self { - stream: Stream::Unencrypted(BufReader::new(socket)), - buf_out: BytesMut::with_capacity(10 * 1024), + framed: MaybeWriteOnly::Full(Framed::new(stream)), state: ProtoState::Initialization, auth_type, tls_config, @@ -211,30 +293,52 @@ impl PostgresBackend { &self.peer_addr } - /// Read full message or return None if connection is closed. 
- pub async fn read_message(&mut self) -> Result, QueryError> { - use ProtoState::*; - match self.state { - Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await, - Authentication | Established => FeMessage::read_fut(&mut self.stream).await, - Closed => Ok(None), + /// Read full message or return None if connection is cleanly closed with no + /// unprocessed data. + pub async fn read_message(&mut self) -> Result, ConnectionError> { + if let ProtoState::Closed = self.state { + Ok(None) + } else { + let m = self.framed.read_message().await?; + trace!("read msg {:?}", m); + Ok(m) } - .map_err(QueryError::from) + } + + /// Write message into internal output buffer, doesn't flush it. Technically + /// error type can be only ProtocolError here (if, unlikely, serialization + /// fails), but callers typically wrap it anyway. + pub fn write_message_noflush( + &mut self, + message: &BeMessage<'_>, + ) -> Result<&mut Self, ConnectionError> { + self.framed.write_message_noflush(message)?; + trace!("wrote msg {:?}", message); + Ok(self) } /// Flush output buffer into the socket. pub async fn flush(&mut self) -> io::Result<()> { - while self.buf_out.has_remaining() { - let bytes_written = self.stream.write(self.buf_out.chunk()).await?; - self.buf_out.advance(bytes_written); - } - self.buf_out.clear(); - Ok(()) + self.framed.flush().await } - /// Write message into internal output buffer. - pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - BeMessage::write(&mut self.buf_out, message)?; + /// Polling version of `flush()`, saves the caller need to pin. + pub fn poll_flush( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let flush_fut = self.flush(); + pin_mut!(flush_fut); + flush_fut.poll(cx) + } + + /// Write message into internal output buffer and flush it to the stream. 
+ pub async fn write_message( + &mut self, + message: &BeMessage<'_>, + ) -> Result<&mut Self, ConnectionError> { + self.write_message_noflush(message)?; + self.flush().await?; Ok(self) } @@ -246,26 +350,7 @@ impl PostgresBackend { CopyDataWriter { pgb: self } } - /// A polling function that tries to write all the data from 'buf_out' to the - /// underlying stream. - fn poll_write_buf( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - while self.buf_out.has_remaining() { - match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) { - Ok(bytes_written) => self.buf_out.advance(bytes_written), - Err(err) => return Poll::Ready(Err(err)), - } - } - Poll::Ready(Ok(())) - } - - fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - Pin::new(&mut self.stream).poll_flush(cx) - } - - // Wrapper for run_message_loop() that shuts down socket when we are done + /// Wrapper for run_message_loop() that shuts down socket when we are done pub async fn run( mut self, handler: &mut impl Handler, @@ -276,7 +361,9 @@ impl PostgresBackend { S: Future, { let ret = self.run_message_loop(handler, shutdown_watcher).await; - let _ = self.stream.shutdown(); + // socket might be already closed, e.g. if previously received error, + // so ignore result. + self.framed.shutdown().await.ok(); ret } @@ -300,30 +387,12 @@ impl PostgresBackend { return Ok(()) }, - result = async { - while self.state < ProtoState::Established { - if let Some(msg) = self.read_message().await? { - trace!("got message {msg:?} during handshake"); - - match self.process_handshake_message(handler, msg).await? 
{ - ProcessMsgResult::Continue => { - self.flush().await?; - continue; - } - ProcessMsgResult::Break => { - trace!("postgres backend to {:?} exited during handshake", self.peer_addr); - return Ok(()); - } - } - } else { - trace!("postgres backend to {:?} exited during handshake", self.peer_addr); - return Ok(()); - } - } - Ok::<(), QueryError>(()) - } => { + result = self.handshake(handler) => { // Handshake complete. result?; + if self.state == ProtoState::Closed { + return Ok(()); // EOF during handshake + } } ); @@ -355,114 +424,207 @@ impl PostgresBackend { Ok(()) } - async fn start_tls(&mut self) -> anyhow::Result<()> { - if let Stream::Unencrypted(plain_stream) = - std::mem::replace(&mut self.stream, Stream::Broken) - { - let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap()); - let tls_stream = acceptor.accept(plain_stream).await?; - - self.stream = Stream::Tls(Box::new(tls_stream)); - return Ok(()); - }; - anyhow::bail!("TLS already started"); - } - - async fn process_handshake_message( - &mut self, - handler: &mut impl Handler, - msg: FeMessage, - ) -> Result { - assert!(self.state < ProtoState::Established); - let have_tls = self.tls_config.is_some(); - match msg { - FeMessage::StartupPacket(m) => { - trace!("got startup message {m:?}"); - - match m { - FeStartupPacket::SslRequest => { - debug!("SSL requested"); - - self.write_message_noflush(&BeMessage::EncryptionResponse(have_tls))?; - if have_tls { - self.start_tls().await?; - self.state = ProtoState::Encrypted; - } - } - FeStartupPacket::GssEncRequest => { - debug!("GSS requested"); - self.write_message_noflush(&BeMessage::EncryptionResponse(false))?; - } - FeStartupPacket::StartupMessage { .. 
} => { - if have_tls && !matches!(self.state, ProtoState::Encrypted) { - self.write_message_noflush(&BeMessage::ErrorResponse( - "must connect with TLS", - None, - ))?; - return Err(QueryError::Other(anyhow::anyhow!( - "client did not connect with TLS" - ))); - } - - // NB: startup() may change self.auth_type -- we are using that in proxy code - // to bypass auth for new users. - handler.startup(self, &m)?; - - match self.auth_type { - AuthType::Trust => { - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? - // The async python driver requires a valid server_version - .write_message_noflush(&BeMessage::server_version("14.1"))? - .write_message_noflush(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - AuthType::NeonJWT => { - self.write_message_noflush( - &BeMessage::AuthenticationCleartextPassword, - )?; - self.state = ProtoState::Authentication; - } - } - } - FeStartupPacket::CancelRequest { .. } => { - self.state = ProtoState::Closed; - return Ok(ProcessMsgResult::Break); - } - } + /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake. 
+ async fn tls_upgrade( + src: MaybeTlsStream, + tls_config: Arc, + ) -> anyhow::Result { + match src { + MaybeTlsStream::Unencrypted(s) => { + let acceptor = TlsAcceptor::from(tls_config); + let tls_stream = acceptor.accept(s).await?; + Ok(MaybeTlsStream::Tls(Box::new(tls_stream))) } - - FeMessage::PasswordMessage(m) => { - trace!("got password message '{:?}'", m); - - assert!(self.state == ProtoState::Authentication); - - match self.auth_type { - AuthType::Trust => unreachable!(), - AuthType::NeonJWT => { - let (_, jwt_response) = m.split_last().context("protocol violation")?; - - if let Err(e) = handler.check_auth_jwt(self, jwt_response) { - self.write_message_noflush(&BeMessage::ErrorResponse( - &e.to_string(), - Some(e.pg_error_code()), - ))?; - return Err(e); - } - } - } - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? - .write_message_noflush(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - - _ => { - self.state = ProtoState::Closed; - return Ok(ProcessMsgResult::Break); + MaybeTlsStream::Tls(_) => { + anyhow::bail!("TLS already started"); } } - Ok(ProcessMsgResult::Continue) + } + + async fn start_tls(&mut self) -> anyhow::Result<()> { + // temporary replace stream with fake to cook TLS one, Indiana Jones style + match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) { + MaybeWriteOnly::Full(framed) => { + let tls_config = self + .tls_config + .as_ref() + .context("start_tls called without conf")? 
+ .clone(); + let tls_framed = framed + .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config)) + .await?; + // push back ready TLS stream + self.framed = MaybeWriteOnly::Full(tls_framed); + Ok(()) + } + MaybeWriteOnly::WriteOnly(_) => { + anyhow::bail!("TLS upgrade attempt in split state") + } + MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"), + } + } + + /// Split off owned read part from which messages can be read in different + /// task/thread. + pub fn split(&mut self) -> anyhow::Result { + // temporary replace stream with fake to cook split one, Indiana Jones style + match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) { + MaybeWriteOnly::Full(framed) => { + let (reader, writer) = framed.split(); + self.framed = MaybeWriteOnly::WriteOnly(writer); + Ok(PostgresBackendReader(reader)) + } + MaybeWriteOnly::WriteOnly(_) => { + anyhow::bail!("PostgresBackend is already split") + } + MaybeWriteOnly::Broken => panic!("split on framed in invalid state"), + } + } + + /// Join read part back. + pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> { + // temporary replace stream with fake to cook joined one, Indiana Jones style + match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) { + MaybeWriteOnly::Full(_) => { + anyhow::bail!("PostgresBackend is not split") + } + MaybeWriteOnly::WriteOnly(writer) => { + let joined = Framed::unsplit(reader.0, writer); + self.framed = MaybeWriteOnly::Full(joined); + Ok(()) + } + MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"), + } + } + + /// Perform handshake with the client, transitioning to Established. + /// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()). + async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> { + while self.state < ProtoState::Authentication { + match self.framed.read_startup_message().await? 
{ + Some(msg) => { + self.process_startup_message(handler, msg).await?; + } + None => { + trace!( + "postgres backend to {:?} received EOF during handshake", + self.peer_addr + ); + self.state = ProtoState::Closed; + return Ok(()); + } + } + } + + // Perform auth, if needed. + if self.state == ProtoState::Authentication { + match self.framed.read_message().await? { + Some(FeMessage::PasswordMessage(m)) => { + assert!(self.auth_type == AuthType::NeonJWT); + + let (_, jwt_response) = m.split_last().context("protocol violation")?; + + if let Err(e) = handler.check_auth_jwt(self, jwt_response) { + self.write_message_noflush(&BeMessage::ErrorResponse( + &e.to_string(), + Some(e.pg_error_code()), + ))?; + return Err(e); + } + + self.write_message_noflush(&BeMessage::AuthenticationOk)? + .write_message_noflush(&BeMessage::CLIENT_ENCODING)? + .write_message(&BeMessage::ReadyForQuery) + .await?; + self.state = ProtoState::Established; + } + Some(m) => { + return Err(QueryError::Other(anyhow::anyhow!( + "Unexpected message {:?} while waiting for handshake", + m + ))); + } + None => { + trace!( + "postgres backend to {:?} received EOF during auth", + self.peer_addr + ); + self.state = ProtoState::Closed; + return Ok(()); + } + } + } + + Ok(()) + } + + /// Process startup packet: + /// - transition to Established if auth type is trust + /// - transition to Authentication if auth type is NeonJWT. + /// - or perform TLS handshake -- then need to call this again to receive + /// actual startup packet. 
+ async fn process_startup_message( + &mut self, + handler: &mut impl Handler, + msg: FeStartupPacket, + ) -> Result<(), QueryError> { + assert!(self.state < ProtoState::Authentication); + let have_tls = self.tls_config.is_some(); + match msg { + FeStartupPacket::SslRequest => { + debug!("SSL requested"); + + self.write_message(&BeMessage::EncryptionResponse(have_tls)) + .await?; + + if have_tls { + self.start_tls().await?; + self.state = ProtoState::Encrypted; + } + } + FeStartupPacket::GssEncRequest => { + debug!("GSS requested"); + self.write_message(&BeMessage::EncryptionResponse(false)) + .await?; + } + FeStartupPacket::StartupMessage { .. } => { + if have_tls && !matches!(self.state, ProtoState::Encrypted) { + self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None)) + .await?; + return Err(QueryError::Other(anyhow::anyhow!( + "client did not connect with TLS" + ))); + } + + // NB: startup() may change self.auth_type -- we are using that in proxy code + // to bypass auth for new users. + handler.startup(self, &msg)?; + + match self.auth_type { + AuthType::Trust => { + self.write_message_noflush(&BeMessage::AuthenticationOk)? + .write_message_noflush(&BeMessage::CLIENT_ENCODING)? + .write_message_noflush(&BeMessage::INTEGER_DATETIMES)? + // The async python driver requires a valid server_version + .write_message_noflush(&BeMessage::server_version("14.1"))? + .write_message(&BeMessage::ReadyForQuery) + .await?; + self.state = ProtoState::Established; + } + AuthType::NeonJWT => { + self.write_message(&BeMessage::AuthenticationCleartextPassword) + .await?; + self.state = ProtoState::Authentication; + } + } + } + FeStartupPacket::CancelRequest { .. 
} => { + return Err(QueryError::Other(anyhow::anyhow!( + "Unexpected CancelRequest message during handshake" + ))); + } + } + Ok(()) } async fn process_message( @@ -476,10 +638,6 @@ impl PostgresBackend { assert!(self.state == ProtoState::Established); match msg { - FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => { - return Err(QueryError::Other(anyhow::anyhow!("protocol violation"))); - } - FeMessage::Query(body) => { // remove null terminator let query_string = cstr_to_str(&body)?; @@ -540,16 +698,114 @@ impl PostgresBackend { // We prefer explicit pattern matching to wildcards, because // this helps us spot the places where new variants are missing - FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => { + FeMessage::CopyData(_) + | FeMessage::CopyDone + | FeMessage::CopyFail + | FeMessage::PasswordMessage(_) + | FeMessage::StartupPacket(_) => { return Err(QueryError::Other(anyhow::anyhow!( - "unexpected message type: {:?}", - msg + "unexpected message type: {msg:?}", ))); } } Ok(ProcessMsgResult::Continue) } + + /// Log as info/error result of handling COPY stream and send back + /// ErrorResponse if that makes sense. Shutdown the stream if we got + /// Terminate. TODO: transition into waiting for Sync msg if we initiate the + /// close. 
+ pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) { + use CopyStreamHandlerEnd::*; + + let expected_end = match &end { + ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true, + CopyStreamHandlerEnd::Disconnected(ConnectionError::Socket(io_error)) + if is_expected_io_error(io_error) => + { + true + } + _ => false, + }; + if expected_end { + info!("terminated: {:#}", end); + } else { + error!("terminated: {:?}", end); + } + + // Note: no current usages ever send this + if let CopyDone = &end { + if let Err(e) = self.write_message(&BeMessage::CopyDone).await { + error!("failed to send CopyDone: {}", e); + } + } + + if let Terminate = &end { + self.state = ProtoState::Closed; + } + + let err_to_send_and_errcode = match &end { + ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)), + Other(_) => Some((end.to_string(), SQLSTATE_INTERNAL_ERROR)), + // Note: CopyFail in duplex copy is somewhat unexpected (at least to + // PG walsender; evidently and per my docs reading client should + // finish it with CopyDone). It is not a problem to recover from it + // finishing the stream in both directions like we do, but note that + // sync rust-postgres client (which we don't use anymore) hangs if + // socket is not closed here. + // https://github.com/sfackler/rust-postgres/issues/755 + // https://github.com/neondatabase/neon/issues/935 + // + // Currently, the version of tokio_postgres replication patch we use + // sends this when it closes the stream (e.g. pageserver decided to + // switch conn to another safekeeper and client gets dropped). + // Moreover, seems like 'connection' task errors with 'unexpected + // message from server' when it receives ErrorResponse (anything but + // CopyData/CopyDone) back. 
+ CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)), + _ => None, + }; + if let Some((err, errcode)) = err_to_send_and_errcode { + if let Err(ee) = self + .write_message(&BeMessage::ErrorResponse(&err, Some(errcode))) + .await + { + error!("failed to send ErrorResponse: {}", ee); + } + } + } +} + +pub struct PostgresBackendReader(FramedReader); + +impl PostgresBackendReader { + /// Read full message or return None if connection is cleanly closed with no + /// unprocessed data. + pub async fn read_message(&mut self) -> Result, ConnectionError> { + let m = self.0.read_message().await?; + trace!("read msg {:?}", m); + Ok(m) + } + + /// Get CopyData contents of the next message in COPY stream or error + /// closing it. The error type is wider than actual errors which can happen + /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for + /// current callers. + pub async fn read_copy_message(&mut self) -> Result { + match self.read_message().await? { + Some(msg) => match msg { + FeMessage::CopyData(m) => Ok(m), + FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone), + FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail), + FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate), + _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol( + format!("unexpected message in COPY stream {:?}", msg), + ))), + }, + None => Err(CopyStreamHandlerEnd::EOF), + } + } } /// @@ -572,16 +828,19 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> { // It's not strictly required to flush between each message, but makes it easier // to view in wireshark, and usually the messages that the callers write are // decently-sized anyway. - match ready!(this.pgb.poll_write_buf(cx)) { - Ok(()) => {} - Err(err) => return Poll::Ready(Err(err)), + if let Err(err) = ready!(this.pgb.poll_flush(cx)) { + return Poll::Ready(Err(err)); } // CopyData // XXX: if the input is large, we should split it into multiple messages. 
// Not sure what the threshold should be, but the ultimate hard limit is that // the length cannot exceed u32. - this.pgb.write_message_noflush(&BeMessage::CopyData(buf))?; + this.pgb + .write_message_noflush(&BeMessage::CopyData(buf)) + // write_message only writes to the buffer, so it can fail iff the + // message is invalid, but CopyData can't be invalid. + .map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?; Poll::Ready(Ok(buf.len())) } @@ -591,21 +850,14 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> { cx: &mut std::task::Context<'_>, ) -> Poll> { let this = self.get_mut(); - match ready!(this.pgb.poll_write_buf(cx)) { - Ok(()) => {} - Err(err) => return Poll::Ready(Err(err)), - } this.pgb.poll_flush(cx) } + fn poll_shutdown( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { let this = self.get_mut(); - match ready!(this.pgb.poll_write_buf(cx)) { - Ok(()) => {} - Err(err) => return Poll::Ready(Err(err)), - } this.pgb.poll_flush(cx) } } @@ -617,7 +869,7 @@ pub fn short_error(e: &QueryError) -> String { } } -pub fn log_query_error(query: &str, e: &QueryError) { +fn log_query_error(query: &str, e: &QueryError) { match e { QueryError::Disconnected(ConnectionError::Socket(io_error)) => { if is_expected_io_error(io_error) { @@ -634,3 +886,26 @@ pub fn log_query_error(query: &str, e: &QueryError) { } } } + +/// Something finishing handling of COPY stream, see handle_copy_stream_end. +/// This is not always a real error, but it allows to use ? and thiserror impls. +#[derive(thiserror::Error, Debug)] +pub enum CopyStreamHandlerEnd { + /// Handler initiates the end of streaming.
+ #[error("{0}")] + ServerInitiated(String), + #[error("received CopyDone")] + CopyDone, + #[error("received CopyFail")] + CopyFail, + #[error("received Terminate")] + Terminate, + #[error("EOF on COPY stream")] + EOF, + /// The connection was lost + #[error(transparent)] + Disconnected(#[from] ConnectionError), + /// Some other error + #[error(transparent)] + Other(#[from] anyhow::Error), +} diff --git a/libs/utils/tests/cert.pem b/libs/postgres_backend/tests/cert.pem similarity index 100% rename from libs/utils/tests/cert.pem rename to libs/postgres_backend/tests/cert.pem diff --git a/libs/utils/tests/key.pem b/libs/postgres_backend/tests/key.pem similarity index 100% rename from libs/utils/tests/key.pem rename to libs/postgres_backend/tests/key.pem diff --git a/libs/postgres_backend/tests/simple_select.rs b/libs/postgres_backend/tests/simple_select.rs new file mode 100644 index 0000000000..a310171c70 --- /dev/null +++ b/libs/postgres_backend/tests/simple_select.rs @@ -0,0 +1,139 @@ +/// Test postgres_backend_async with tokio_postgres +use once_cell::sync::Lazy; +use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError}; +use pq_proto::{BeMessage, RowDescriptor}; +use std::io::Cursor; +use std::{future, sync::Arc}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_postgres::config::SslMode; +use tokio_postgres::tls::MakeTlsConnect; +use tokio_postgres::{Config, NoTls, SimpleQueryMessage}; +use tokio_postgres_rustls::MakeRustlsConnect; + +// generate client, server test streams +async fn make_tcp_pair() -> (TcpStream, TcpStream) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let client_stream = TcpStream::connect(addr).await.unwrap(); + let (server_stream, _) = listener.accept().await.unwrap(); + (client_stream, server_stream) +} + +struct TestHandler {} + +#[async_trait::async_trait] +impl Handler for TestHandler { + // return single col 'hey' for any query + async fn 
process_query( + &mut self, + pgb: &mut PostgresBackend, + _query_string: &str, + ) -> Result<(), QueryError> { + pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col( + b"hey", + )]))? + .write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))? + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + Ok(()) + } +} + +// test that basic select works +#[tokio::test] +async fn simple_select() { + let (client_sock, server_sock) = make_tcp_pair().await; + + // create and run pgbackend + let pgbackend = + PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation"); + + tokio::spawn(async move { + let mut handler = TestHandler {}; + pgbackend.run(&mut handler, future::pending::<()>).await + }); + + let conf = Config::new(); + let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect"); + // The connection object performs the actual communication with the database, + // so spawn it off to run on its own. 
+ tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0]; + if let SimpleQueryMessage::Row(row) = first_val { + let first_col = row.get(0).expect("first column"); + assert_eq!(first_col, "hey"); + } else { + panic!("expected SimpleQueryMessage::Row"); + } +} + +static KEY: Lazy = Lazy::new(|| { + let mut cursor = Cursor::new(include_bytes!("key.pem")); + rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone()) +}); + +static CERT: Lazy = Lazy::new(|| { + let mut cursor = Cursor::new(include_bytes!("cert.pem")); + rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone()) +}); + +// test that basic select with ssl works +#[tokio::test] +async fn simple_select_ssl() { + let (client_sock, server_sock) = make_tcp_pair().await; + + let server_cfg = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![CERT.clone()], KEY.clone()) + .unwrap(); + let tls_config = Some(Arc::new(server_cfg)); + let pgbackend = + PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation"); + + tokio::spawn(async move { + let mut handler = TestHandler {}; + pgbackend.run(&mut handler, future::pending::<()>).await + }); + + let client_cfg = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates({ + let mut store = rustls::RootCertStore::empty(); + store.add(&CERT).unwrap(); + store + }) + .with_no_client_auth(); + let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg); + let tls_connect = >::make_tls_connect( + &mut make_tls_connect, + "localhost", + ) + .expect("make_tls_connect"); + + let mut conf = Config::new(); + conf.ssl_mode(SslMode::Require); + let (client, connection) = conf + .connect_raw(client_sock, tls_connect) + .await + .expect("connect"); + // The 
connection object performs the actual communication with the database, + // so spawn it off to run on its own. + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0]; + if let SimpleQueryMessage::Row(row) = first_val { + let first_col = row.get(0).expect("first column"); + assert_eq!(first_col, "hey"); + } else { + panic!("expected SimpleQueryMessage::Row"); + } +} diff --git a/libs/pq_proto/src/framed.rs b/libs/pq_proto/src/framed.rs new file mode 100644 index 0000000000..7c33222e6e --- /dev/null +++ b/libs/pq_proto/src/framed.rs @@ -0,0 +1,175 @@ +//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from +//! the async stream. +use bytes::{Buf, BytesMut}; +use std::{ + future::Future, + io::{self, ErrorKind}, +}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; + +use crate::{BeMessage, ConnectionError, FeMessage, FeStartupPacket}; + +const INITIAL_CAPACITY: usize = 8 * 1024; + +/// Wraps async io `stream`, providing messages to write/flush + read Postgres +/// messages. +pub struct Framed { + stream: BufReader, + write_buf: BytesMut, +} + +impl Framed { + pub fn new(stream: S) -> Self { + Self { + stream: BufReader::new(stream), + write_buf: BytesMut::with_capacity(INITIAL_CAPACITY), + } + } + + /// Get a shared reference to the underlying stream. + pub fn get_ref(&self) -> &S { + self.stream.get_ref() + } + + /// Extract the underlying stream. + pub fn into_inner(self) -> S { + self.stream.into_inner() + } + + /// Return new Framed with stream type transformed by async f, for TLS + /// upgrade. 
+ pub async fn map_stream(self, f: F) -> Result, E> + where + F: FnOnce(S) -> Fut, + Fut: Future>, + { + let stream = f(self.stream.into_inner()).await?; + Ok(Framed { + stream: BufReader::new(stream), + write_buf: self.write_buf, + }) + } +} + +impl Framed { + pub async fn read_startup_message( + &mut self, + ) -> Result, ConnectionError> { + let msg = FeStartupPacket::read(&mut self.stream).await?; + + match msg { + Some(FeMessage::StartupPacket(packet)) => Ok(Some(packet)), + None => Ok(None), + _ => panic!("unreachable state"), + } + } + + pub async fn read_message(&mut self) -> Result, ConnectionError> { + FeMessage::read(&mut self.stream).await + } +} + +impl Framed { + /// Write next message to the output buffer; doesn't flush. + pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> { + BeMessage::write(&mut self.write_buf, msg).map_err(|e| e.into()) + } + + /// Flush out the buffer. This function is cancellation safe: it can be + /// interrupted and flushing will be continued in the next call. + pub async fn flush(&mut self) -> Result<(), io::Error> { + flush(&mut self.stream, &mut self.write_buf).await + } + + /// Flush out the buffer and shutdown the stream. + pub async fn shutdown(&mut self) -> Result<(), io::Error> { + shutdown(&mut self.stream, &mut self.write_buf).await + } +} + +impl Framed { + /// Split into owned read and write parts. Beware of potential issues with + /// using halves in different tasks on TLS stream: + /// https://github.com/tokio-rs/tls/issues/40 + pub fn split(self) -> (FramedReader, FramedWriter) { + let (read_half, write_half) = tokio::io::split(self.stream); + let reader = FramedReader { stream: read_half }; + let writer = FramedWriter { + stream: write_half, + write_buf: self.write_buf, + }; + (reader, writer) + } + + /// Join read and write parts back. 
+ pub fn unsplit(reader: FramedReader, writer: FramedWriter) -> Self { + Self { + stream: reader.stream.unsplit(writer.stream), + write_buf: writer.write_buf, + } + } +} + +/// Read-only version of `Framed`. +pub struct FramedReader { + stream: ReadHalf>, +} + +impl FramedReader { + pub async fn read_message(&mut self) -> Result, ConnectionError> { + FeMessage::read(&mut self.stream).await + } +} + +/// Write-only version of `Framed`. +pub struct FramedWriter { + stream: WriteHalf>, + write_buf: BytesMut, +} + +impl FramedWriter { + /// Write next message to the output buffer; doesn't flush. + pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> { + BeMessage::write(&mut self.write_buf, msg).map_err(|e| e.into()) + } + + /// Flush out the buffer. This function is cancellation safe: it can be + /// interrupted and flushing will be continued in the next call. + pub async fn flush(&mut self) -> Result<(), io::Error> { + flush(&mut self.stream, &mut self.write_buf).await + } + + /// Flush out the buffer and shutdown the stream. + pub async fn shutdown(&mut self) -> Result<(), io::Error> { + shutdown(&mut self.stream, &mut self.write_buf).await + } +} + +async fn flush( + stream: &mut S, + write_buf: &mut BytesMut, +) -> Result<(), io::Error> { + while write_buf.has_remaining() { + let bytes_written = stream.write(write_buf.chunk()).await?; + if bytes_written == 0 { + return Err(io::Error::new( + ErrorKind::WriteZero, + "failed to write message", + )); + } + // The advanced part will be garbage collected, likely during shifting + // data left on next attempt to write to buffer when free space is not + // enough. 
+ write_buf.advance(bytes_written); + } + write_buf.clear(); + stream.flush().await +} + +async fn shutdown( + stream: &mut S, + write_buf: &mut BytesMut, +) -> Result<(), io::Error> { + flush(stream, write_buf).await?; + stream.shutdown().await +} diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index b7995c840c..6980c4afae 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -2,8 +2,7 @@ //! //! on message formats. -// Tools for calling certain async methods in sync contexts. -pub mod sync; +pub mod framed; use anyhow::{ensure, Context, Result}; use bytes::{Buf, BufMut, Bytes, BytesMut}; @@ -13,12 +12,10 @@ use std::{ borrow::Cow, collections::HashMap, fmt, - future::Future, io::{self, Cursor}, str, time::{Duration, SystemTime}, }; -use sync::{AsyncishRead, SyncFuture}; use tokio::io::AsyncReadExt; use tracing::{trace, warn}; @@ -211,7 +208,7 @@ macro_rules! retry_read { pub enum ConnectionError { /// IO error during writing to or reading from the connection socket. #[error("Socket IO error: {0}")] - Socket(std::io::Error), + Socket(#[from] std::io::Error), /// Invalid packet was received from client #[error("Protocol error: {0}")] Protocol(String), @@ -238,87 +235,56 @@ impl ConnectionError { impl FeMessage { /// Read one message from the stream. /// This function returns `Ok(None)` in case of EOF. - /// One way to handle this properly: - /// - /// ``` - /// # use std::io; - /// # use pq_proto::FeMessage; - /// # - /// # fn process_message(msg: FeMessage) -> anyhow::Result<()> { - /// # Ok(()) - /// # }; - /// # - /// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> { - /// while let Some(msg) = FeMessage::read(stream)? { - /// process_message(msg)?; - /// } - /// - /// Ok(()) - /// } - /// ``` - #[inline(never)] - pub fn read( - stream: &mut (impl io::Read + Unpin), - ) -> Result, ConnectionError> { - Self::read_fut(&mut AsyncishRead(stream)).wait() - } - - /// Read one message from the stream. 
- /// See documentation for `Self::read`. - pub fn read_fut( - stream: &mut Reader, - ) -> SyncFuture, ConnectionError>> + '_> + pub async fn read(stream: &mut Reader) -> Result, ConnectionError> where Reader: tokio::io::AsyncRead + Unpin, { // We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof. // SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and // AsyncReadExt methods of the stream. - SyncFuture::new(async move { - // Each libpq message begins with a message type byte, followed by message length - // If the client closes the connection, return None. But if the client closes the - // connection in the middle of a message, we will return an error. - let tag = match retry_read!(stream.read_u8().await) { - Ok(b) => b, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ConnectionError::Socket(e)), - }; + // Each libpq message begins with a message type byte, followed by message length + // If the client closes the connection, return None. But if the client closes the + // connection in the middle of a message, we will return an error. + let tag = match retry_read!(stream.read_u8().await) { + Ok(b) => b, + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(ConnectionError::Socket(e)), + }; - // The message length includes itself, so it better be at least 4. - let len = retry_read!(stream.read_u32().await) - .map_err(ConnectionError::Socket)? - .checked_sub(4) - .ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?; + // The message length includes itself, so it better be at least 4. + let len = retry_read!(stream.read_u32().await) + .map_err(ConnectionError::Socket)? 
+ .checked_sub(4) + .ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?; - let body = { - let mut buffer = vec![0u8; len as usize]; - stream - .read_exact(&mut buffer) - .await - .map_err(ConnectionError::Socket)?; - Bytes::from(buffer) - }; + let body = { + let mut buffer = vec![0u8; len as usize]; + stream + .read_exact(&mut buffer) + .await + .map_err(ConnectionError::Socket)?; + Bytes::from(buffer) + }; - match tag { - b'Q' => Ok(Some(FeMessage::Query(body))), - b'P' => Ok(Some(FeParseMessage::parse(body)?)), - b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), - b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), - b'B' => Ok(Some(FeBindMessage::parse(body)?)), - b'C' => Ok(Some(FeCloseMessage::parse(body)?)), - b'S' => Ok(Some(FeMessage::Sync)), - b'X' => Ok(Some(FeMessage::Terminate)), - b'd' => Ok(Some(FeMessage::CopyData(body))), - b'c' => Ok(Some(FeMessage::CopyDone)), - b'f' => Ok(Some(FeMessage::CopyFail)), - b'p' => Ok(Some(FeMessage::PasswordMessage(body))), - tag => { - return Err(ConnectionError::Protocol(format!( - "unknown message tag: {tag},'{body:?}'" - ))) - } + match tag { + b'Q' => Ok(Some(FeMessage::Query(body))), + b'P' => Ok(Some(FeParseMessage::parse(body)?)), + b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), + b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), + b'B' => Ok(Some(FeBindMessage::parse(body)?)), + b'C' => Ok(Some(FeCloseMessage::parse(body)?)), + b'S' => Ok(Some(FeMessage::Sync)), + b'X' => Ok(Some(FeMessage::Terminate)), + b'd' => Ok(Some(FeMessage::CopyData(body))), + b'c' => Ok(Some(FeMessage::CopyDone)), + b'f' => Ok(Some(FeMessage::CopyFail)), + b'p' => Ok(Some(FeMessage::PasswordMessage(body))), + tag => { + return Err(ConnectionError::Protocol(format!( + "unknown message tag: {tag},'{body:?}'" + ))) } - }) + } } } @@ -326,18 +292,7 @@ impl FeStartupPacket { /// Read startup message from the stream. 
// XXX: It's tempting yet undesirable to accept `stream` by value, // since such a change will cause user-supplied &mut references to be consumed - pub fn read( - stream: &mut (impl io::Read + Unpin), - ) -> Result, ConnectionError> { - Self::read_fut(&mut AsyncishRead(stream)).wait() - } - - /// Read startup message from the stream. - // XXX: It's tempting yet undesirable to accept `stream` by value, - // since such a change will cause user-supplied &mut references to be consumed - pub fn read_fut( - stream: &mut Reader, - ) -> SyncFuture, ConnectionError>> + '_> + pub async fn read(stream: &mut Reader) -> Result, ConnectionError> where Reader: tokio::io::AsyncRead + Unpin, { @@ -347,99 +302,96 @@ impl FeStartupPacket { const NEGOTIATE_SSL_CODE: u32 = 5679; const NEGOTIATE_GSS_CODE: u32 = 5680; - SyncFuture::new(async move { - // Read length. If the connection is closed before reading anything (or before - // reading 4 bytes, to be precise), return None to indicate that the connection - // was closed. This matches the PostgreSQL server's behavior, which avoids noise - // in the log if the client opens connection but closes it immediately. - let len = match retry_read!(stream.read_u32().await) { - Ok(len) => len as usize, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ConnectionError::Socket(e)), - }; + // Read length. If the connection is closed before reading anything (or before + // reading 4 bytes, to be precise), return None to indicate that the connection + // was closed. This matches the PostgreSQL server's behavior, which avoids noise + // in the log if the client opens connection but closes it immediately. 
+ let len = match retry_read!(stream.read_u32().await) { + Ok(len) => len as usize, + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(ConnectionError::Socket(e)), + }; - #[allow(clippy::manual_range_contains)] - if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { + #[allow(clippy::manual_range_contains)] + if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { + return Err(ConnectionError::Protocol(format!( + "invalid message length {len}" + ))); + } + + let request_code = retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?; + + // the rest of startup packet are params + let params_len = len - 8; + let mut params_bytes = vec![0u8; params_len]; + stream + .read_exact(params_bytes.as_mut()) + .await + .map_err(ConnectionError::Socket)?; + + // Parse params depending on request code + let req_hi = request_code >> 16; + let req_lo = request_code & ((1 << 16) - 1); + let message = match (req_hi, req_lo) { + (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { + if params_len != 8 { + return Err(ConnectionError::Protocol( + "expected 8 bytes for CancelRequest params".to_string(), + )); + } + let mut cursor = Cursor::new(params_bytes); + FeStartupPacket::CancelRequest(CancelKeyData { + backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?, + cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?, + }) + } + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { + // Requested upgrade to SSL (aka TLS) + FeStartupPacket::SslRequest + } + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => { + // Requested upgrade to GSSAPI + FeStartupPacket::GssEncRequest + } + (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { return Err(ConnectionError::Protocol(format!( - "invalid message length {len}" + "Unrecognized request code {unrecognized_code}" ))); } + // TODO bail if protocol major_version is not 3? 
+ (major_version, minor_version) => { + // Parse pairs of null-terminated strings (key, value). + // See `postgres: ProcessStartupPacket, build_startup_packet`. + let mut tokens = str::from_utf8(&params_bytes) + .context("StartupMessage params: invalid utf-8")? + .strip_suffix('\0') // drop packet's own null + .ok_or_else(|| { + ConnectionError::Protocol( + "StartupMessage params: missing null terminator".to_string(), + ) + })? + .split_terminator('\0'); - let request_code = - retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?; + let mut params = HashMap::new(); + while let Some(name) = tokens.next() { + let value = tokens.next().ok_or_else(|| { + ConnectionError::Protocol( + "StartupMessage params: key without value".to_string(), + ) + })?; - // the rest of startup packet are params - let params_len = len - 8; - let mut params_bytes = vec![0u8; params_len]; - stream - .read_exact(params_bytes.as_mut()) - .await - .map_err(ConnectionError::Socket)?; - - // Parse params depending on request code - let req_hi = request_code >> 16; - let req_lo = request_code & ((1 << 16) - 1); - let message = match (req_hi, req_lo) { - (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { - if params_len != 8 { - return Err(ConnectionError::Protocol( - "expected 8 bytes for CancelRequest params".to_string(), - )); - } - let mut cursor = Cursor::new(params_bytes); - FeStartupPacket::CancelRequest(CancelKeyData { - backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?, - cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?, - }) + params.insert(name.to_owned(), value.to_owned()); } - (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { - // Requested upgrade to SSL (aka TLS) - FeStartupPacket::SslRequest - } - (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => { - // Requested upgrade to GSSAPI - FeStartupPacket::GssEncRequest - } - (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { - return
Err(ConnectionError::Protocol(format!( - "Unrecognized request code {unrecognized_code}" - ))); - } - // TODO bail if protocol major_version is not 3? - (major_version, minor_version) => { - // Parse pairs of null-terminated strings (key, value). - // See `postgres: ProcessStartupPacket, build_startup_packet`. - let mut tokens = str::from_utf8(&params_bytes) - .context("StartupMessage params: invalid utf-8")? - .strip_suffix('\0') // drop packet's own null - .ok_or_else(|| { - ConnectionError::Protocol( - "StartupMessage params: missing null terminator".to_string(), - ) - })? - .split_terminator('\0'); - let mut params = HashMap::new(); - while let Some(name) = tokens.next() { - let value = tokens.next().ok_or_else(|| { - ConnectionError::Protocol( - "StartupMessage params: key without value".to_string(), - ) - })?; - - params.insert(name.to_owned(), value.to_owned()); - } - - FeStartupPacket::StartupMessage { - major_version, - minor_version, - params: StartupMessageParams { params }, - } + FeStartupPacket::StartupMessage { + major_version, + minor_version, + params: StartupMessageParams { params }, } - }; + } + }; - Ok(Some(FeMessage::StartupPacket(message))) - }) + Ok(Some(FeMessage::StartupPacket(message))) } } @@ -559,6 +511,11 @@ impl<'a> BeMessage<'a> { value: b"UTF8", }; + pub const INTEGER_DATETIMES: Self = Self::ParameterStatus { + name: b"integer_datetimes", + value: b"on", + }; + /// Build a [`BeMessage::ParameterStatus`] holding the server version. pub fn server_version(version: &'a str) -> Self { Self::ParameterStatus { @@ -698,6 +655,7 @@ fn read_cstr(buf: &mut Bytes) -> anyhow::Result { } pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000"; +pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000"; impl<'a> BeMessage<'a> { /// Write message to the given buf.
@@ -1149,15 +1107,6 @@ mod tests { let params = make_params("foo\\ bar \\ \\\\ baz\\ lol"); assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]); } - - // Make sure that `read` is sync/async callable - async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) { - let _ = FeMessage::read(&mut [].as_ref()); - let _ = FeMessage::read_fut(stream).await; - - let _ = FeStartupPacket::read(&mut [].as_ref()); - let _ = FeStartupPacket::read_fut(stream).await; - } } fn terminate_code(code: &[u8; 5]) -> [u8; 6] { diff --git a/libs/pq_proto/src/sync.rs b/libs/pq_proto/src/sync.rs deleted file mode 100644 index b7ff1fb70b..0000000000 --- a/libs/pq_proto/src/sync.rs +++ /dev/null @@ -1,179 +0,0 @@ -use pin_project_lite::pin_project; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::{io, task}; - -pin_project! { - /// We use this future to mark certain methods - /// as callable in both sync and async modes. - #[repr(transparent)] - pub struct SyncFuture { - #[pin] - inner: T, - _marker: PhantomData, - } -} - -/// This wrapper lets us synchronously wait for inner future's completion -/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**. -/// For instance, `S` may be substituted with types implementing -/// [`tokio::io::AsyncRead`], but it's not the only viable option. -impl SyncFuture { - /// NOTE: caller should carefully pick a type for `S`, - /// because we don't want to enable [`SyncFuture::wait`] when - /// it's in fact impossible to run the future synchronously. - /// Violation of this contract will not cause UB, but - /// panics and async event loop freezes won't please you.
- /// - /// Example: - /// - /// ``` - /// # use pq_proto::sync::SyncFuture; - /// # use std::future::Future; - /// # use tokio::io::AsyncReadExt; - /// # - /// // Parse a pair of numbers from a stream - /// pub fn parse_pair( - /// stream: &mut Reader, - /// ) -> SyncFuture> + '_> - /// where - /// Reader: tokio::io::AsyncRead + Unpin, - /// { - /// // If `Reader` is a `SyncProof`, this will give caller - /// // an opportunity to use `SyncFuture::wait`, because - /// // `.await` will always result in `Poll::Ready`. - /// SyncFuture::new(async move { - /// let x = stream.read_u32().await?; - /// let y = stream.read_u64().await?; - /// Ok((x, y)) - /// }) - /// } - /// ``` - pub fn new(inner: T) -> Self { - Self { - inner, - _marker: PhantomData, - } - } -} - -impl Future for SyncFuture { - type Output = T::Output; - - /// In async code, [`SyncFuture`] behaves like a regular wrapper. - #[inline(always)] - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { - self.project().inner.poll(cx) - } -} - -/// Postulates that we can call [`SyncFuture::wait`]. -/// If implementer is also a [`Future`], it should always -/// return [`task::Poll::Ready`] from [`Future::poll`]. -/// -/// Each implementation should document which futures -/// specifically are being declared sync-proof. -pub trait SyncPostulate {} - -impl SyncPostulate for &T {} -impl SyncPostulate for &mut T {} - -impl SyncFuture { - /// Synchronously wait for future completion. 
- pub fn wait(mut self) -> T::Output { - const RAW_WAKER: task::RawWaker = task::RawWaker::new( - std::ptr::null(), - &task::RawWakerVTable::new( - |_| RAW_WAKER, - |_| panic!("SyncFuture: failed to wake"), - |_| panic!("SyncFuture: failed to wake by ref"), - |_| { /* drop is no-op */ }, - ), - ); - - // SAFETY: We never move `self` during this call; - // furthermore, it will be dropped in the end regardless of panics - let this = unsafe { Pin::new_unchecked(&mut self) }; - - // SAFETY: This waker doesn't do anything apart from panicking - let waker = unsafe { task::Waker::from_raw(RAW_WAKER) }; - let context = &mut task::Context::from_waker(&waker); - - match this.poll(context) { - task::Poll::Ready(res) => res, - _ => panic!("SyncFuture: unexpected pending!"), - } - } -} - -/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`], -/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`]. -/// NOTE: you **should not** use this in async code. -#[repr(transparent)] -pub struct AsyncishRead(pub T); - -/// This lets us call [`SyncFuture, _>::wait`], -/// and allows the future to await on any of the [`AsyncRead`] -/// and [`AsyncReadExt`] methods on `AsyncishRead`. -impl SyncPostulate for AsyncishRead {} - -impl tokio::io::AsyncRead for AsyncishRead { - #[inline(always)] - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> task::Poll> { - task::Poll::Ready( - // `Read::read` will block, meaning we don't need a real event loop! 
- self.0 - .read(buf.initialize_unfilled()) - .map(|sz| buf.advance(sz)), - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - // async helper(stream: &mut impl AsyncRead) -> io::Result - fn bytes_add( - stream: &mut Reader, - ) -> SyncFuture> + '_> - where - Reader: tokio::io::AsyncRead + Unpin, - { - SyncFuture::new(async move { - let a = stream.read_u32().await?; - let b = stream.read_u32().await?; - Ok(a + b) - }) - } - - #[test] - fn test_sync() { - let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat(); - let res = bytes_add(&mut AsyncishRead(&mut &bytes[..])) - .wait() - .unwrap(); - assert_eq!(res, 300); - } - - // We need a single-threaded executor for this test - #[tokio::test(flavor = "current_thread")] - async fn test_async() { - let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap(); - - let write = async move { - tx.write_u32(100).await?; - tx.write_u32(200).await?; - Ok(()) - }; - - let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap(); - assert_eq!(res, 300); - } -} diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index 1091a8bd5c..901f849801 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static { } pub struct Download { - pub download_stream: Pin>, + pub download_stream: Pin>, /// Extra key-value data, associated with the current remote file. 
pub metadata: Option, } diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index 206e40fce9..b24de57f99 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -19,8 +19,6 @@ jsonwebtoken.workspace = true nix.workspace = true once_cell.workspace = true routerify.workspace = true -rustls.workspace = true -rustls-split.workspace = true serde.workspace = true serde_json.workspace = true signal-hook.workspace = true @@ -36,7 +34,6 @@ url.workspace = true uuid = { version = "1.2", features = ["v4", "serde"] } metrics.workspace = true -pq_proto.workspace = true workspace_hack.workspace = true [dev-dependencies] @@ -44,7 +41,6 @@ byteorder.workspace = true bytes.workspace = true criterion.workspace = true hex-literal.workspace = true -rustls-pemfile.workspace = true tempfile.workspace = true [[bench]] diff --git a/libs/utils/src/lib.rs b/libs/utils/src/lib.rs index 7408eb66cd..acb5273943 100644 --- a/libs/utils/src/lib.rs +++ b/libs/utils/src/lib.rs @@ -13,7 +13,6 @@ pub mod simple_rcu; pub mod vec_map; pub mod bin_ser; -pub mod postgres_backend; // helper functions for creating and fsyncing pub mod crashsafe; @@ -26,9 +25,6 @@ pub mod id; // http endpoint utils pub mod http; -// socket splitting utils -pub mod sock_split; - // common log initialisation routine pub mod logging; diff --git a/libs/utils/src/postgres_backend.rs b/libs/utils/src/postgres_backend.rs deleted file mode 100644 index fc49aa6696..0000000000 --- a/libs/utils/src/postgres_backend.rs +++ /dev/null @@ -1,544 +0,0 @@ -//! Server-side synchronous Postgres connection, as limited as we need. -//! To use, create PostgresBackend and run() it, passing the Handler -//! implementation determining how to process the queries. Currently its API -//! is rather narrow, but we can extend it once required. 
- -use crate::sock_split::{BidiStream, ReadStream, WriteStream}; -use anyhow::Context; -use bytes::{Bytes, BytesMut}; -use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR}; -use serde::{Deserialize, Serialize}; -use std::fmt; -use std::io::{self, Write}; -use std::net::{Shutdown, SocketAddr, TcpStream}; -use std::str::FromStr; -use std::sync::Arc; -use std::time::Duration; -use tracing::*; - -pub fn is_expected_io_error(e: &io::Error) -> bool { - use io::ErrorKind::*; - matches!( - e.kind(), - ConnectionRefused | ConnectionAborted | ConnectionReset - ) -} - -/// An error, occurred during query processing: -/// either during the connection ([`ConnectionError`]) or before/after it. -#[derive(thiserror::Error, Debug)] -pub enum QueryError { - /// The connection was lost while processing the query. - #[error(transparent)] - Disconnected(#[from] ConnectionError), - /// Some other error - #[error(transparent)] - Other(#[from] anyhow::Error), -} - -impl From for QueryError { - fn from(e: io::Error) -> Self { - Self::Disconnected(ConnectionError::Socket(e)) - } -} - -impl QueryError { - pub fn pg_error_code(&self) -> &'static [u8; 5] { - match self { - Self::Disconnected(_) => b"08006", // connection failure - Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error - } - } -} - -pub trait Handler { - /// Handle single query. - /// postgres_backend will issue ReadyForQuery after calling this (this - /// might be not what we want after CopyData streaming, but currently we don't - /// care). - fn process_query( - &mut self, - pgb: &mut PostgresBackend, - query_string: &str, - ) -> Result<(), QueryError>; - - /// Called on startup packet receival, allows to process params. - /// - /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users - /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow - /// to override whole init logic in implementations. 
- fn startup( - &mut self, - _pgb: &mut PostgresBackend, - _sm: &FeStartupPacket, - ) -> Result<(), QueryError> { - Ok(()) - } - - /// Check auth jwt - fn check_auth_jwt( - &mut self, - _pgb: &mut PostgresBackend, - _jwt_response: &[u8], - ) -> Result<(), QueryError> { - Err(QueryError::Other(anyhow::anyhow!("JWT auth failed"))) - } - - fn is_shutdown_requested(&self) -> bool { - false - } -} - -/// PostgresBackend protocol state. -/// XXX: The order of the constructors matters. -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)] -pub enum ProtoState { - Initialization, - Encrypted, - Authentication, - Established, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] -pub enum AuthType { - Trust, - // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT - NeonJWT, -} - -impl FromStr for AuthType { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s { - "Trust" => Ok(Self::Trust), - "NeonJWT" => Ok(Self::NeonJWT), - _ => anyhow::bail!("invalid value \"{s}\" for auth type"), - } - } -} - -impl fmt::Display for AuthType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - AuthType::Trust => "Trust", - AuthType::NeonJWT => "NeonJWT", - }) - } -} - -#[derive(Clone, Copy)] -pub enum ProcessMsgResult { - Continue, - Break, -} - -/// Always-writeable sock_split stream. -/// May not be readable. 
See [`PostgresBackend::take_stream_in`] -pub enum Stream { - Bidirectional(BidiStream), - WriteOnly(WriteStream), -} - -impl Stream { - fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how), - Self::WriteOnly(write_stream) => write_stream.shutdown(how), - } - } -} - -impl io::Write for Stream { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Self::Bidirectional(bidi_stream) => bidi_stream.write(buf), - Self::WriteOnly(write_stream) => write_stream.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - Self::Bidirectional(bidi_stream) => bidi_stream.flush(), - Self::WriteOnly(write_stream) => write_stream.flush(), - } - } -} - -pub struct PostgresBackend { - stream: Option, - // Output buffer. c.f. BeMessage::write why we are using BytesMut here. - buf_out: BytesMut, - - pub state: ProtoState, - - auth_type: AuthType, - - peer_addr: SocketAddr, - pub tls_config: Option>, -} - -pub fn query_from_cstring(query_string: Bytes) -> Vec { - let mut query_string = query_string.to_vec(); - if let Some(ch) = query_string.last() { - if *ch == 0 { - query_string.pop(); - } - } - query_string -} - -// Helper function for socket read loops -pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool { - for cause in error.chain() { - if let Some(io_error) = cause.downcast_ref::() { - if io_error.kind() == std::io::ErrorKind::WouldBlock { - return true; - } - } - } - false -} - -// Cast a byte slice to a string slice, dropping null terminator if there's one. 
-fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> { - let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes); - std::str::from_utf8(without_null).map_err(|e| e.into()) -} - -impl PostgresBackend { - pub fn new( - socket: TcpStream, - auth_type: AuthType, - tls_config: Option>, - set_read_timeout: bool, - ) -> io::Result { - let peer_addr = socket.peer_addr()?; - if set_read_timeout { - socket - .set_read_timeout(Some(Duration::from_secs(5))) - .unwrap(); - } - - Ok(Self { - stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))), - buf_out: BytesMut::with_capacity(10 * 1024), - state: ProtoState::Initialization, - auth_type, - tls_config, - peer_addr, - }) - } - - pub fn into_stream(self) -> Stream { - self.stream.unwrap() - } - - /// Get direct reference (into the Option) to the read stream. - fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> { - match &mut self.stream { - Some(Stream::Bidirectional(stream)) => Ok(stream), - _ => anyhow::bail!("reader taken"), - } - } - - pub fn get_peer_addr(&self) -> &SocketAddr { - &self.peer_addr - } - - pub fn take_stream_in(&mut self) -> Option { - let stream = self.stream.take(); - match stream { - Some(Stream::Bidirectional(bidi_stream)) => { - let (read, write) = bidi_stream.split(); - self.stream = Some(Stream::WriteOnly(write)); - Some(read) - } - stream => { - self.stream = stream; - None - } - } - } - - /// Read full message or return None if connection is closed. - pub fn read_message(&mut self) -> Result, QueryError> { - let (state, stream) = (self.state, self.get_stream_in()?); - - use ProtoState::*; - match state { - Initialization | Encrypted => FeStartupPacket::read(stream), - Authentication | Established => FeMessage::read(stream), - } - .map_err(QueryError::from) - } - - /// Write message into internal output buffer. 
- pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> { - BeMessage::write(&mut self.buf_out, message)?; - Ok(self) - } - - /// Flush output buffer into the socket. - pub fn flush(&mut self) -> io::Result<&mut Self> { - let stream = self.stream.as_mut().unwrap(); - stream.write_all(&self.buf_out)?; - self.buf_out.clear(); - Ok(self) - } - - /// Write message into internal buffer and flush it. - pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> { - self.write_message_noflush(message)?; - self.flush() - } - - // Wrapper for run_message_loop() that shuts down socket when we are done - pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> { - let ret = self.run_message_loop(handler); - if let Some(stream) = self.stream.as_mut() { - let _ = stream.shutdown(Shutdown::Both); - } - ret - } - - fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> { - trace!("postgres backend to {:?} started", self.peer_addr); - - let mut unnamed_query_string = Bytes::new(); - - while !handler.is_shutdown_requested() { - match self.read_message() { - Ok(message) => { - if let Some(msg) = message { - trace!("got message {msg:?}"); - - match self.process_message(handler, msg, &mut unnamed_query_string)? 
{ - ProcessMsgResult::Continue => continue, - ProcessMsgResult::Break => break, - } - } else { - break; - } - } - Err(e) => { - if let QueryError::Other(e) = &e { - if is_socket_read_timed_out(e) { - continue; - } - } - return Err(e); - } - } - } - - trace!("postgres backend to {:?} exited", self.peer_addr); - Ok(()) - } - - pub fn start_tls(&mut self) -> anyhow::Result<()> { - match self.stream.take() { - Some(Stream::Bidirectional(bidi_stream)) => { - let conn = rustls::ServerConnection::new(self.tls_config.clone().unwrap())?; - self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(conn)?)); - Ok(()) - } - stream => { - self.stream = stream; - anyhow::bail!("can't start TLs without bidi stream"); - } - } - } - - fn process_message( - &mut self, - handler: &mut impl Handler, - msg: FeMessage, - unnamed_query_string: &mut Bytes, - ) -> Result { - // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth - // TODO: change that to proper top-level match of protocol state with separate message handling for each state - if self.state < ProtoState::Established - && !matches!( - msg, - FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_) - ) - { - return Err(QueryError::Other(anyhow::anyhow!("protocol violation"))); - } - - let have_tls = self.tls_config.is_some(); - match msg { - FeMessage::StartupPacket(m) => { - trace!("got startup message {m:?}"); - - match m { - FeStartupPacket::SslRequest => { - debug!("SSL requested"); - - self.write_message(&BeMessage::EncryptionResponse(have_tls))?; - if have_tls { - self.start_tls()?; - self.state = ProtoState::Encrypted; - } - } - FeStartupPacket::GssEncRequest => { - debug!("GSS requested"); - self.write_message(&BeMessage::EncryptionResponse(false))?; - } - FeStartupPacket::StartupMessage { .. 
} => { - if have_tls && !matches!(self.state, ProtoState::Encrypted) { - self.write_message(&BeMessage::ErrorResponse( - "must connect with TLS", - None, - ))?; - return Err(QueryError::Other(anyhow::anyhow!( - "client did not connect with TLS" - ))); - } - - // NB: startup() may change self.auth_type -- we are using that in proxy code - // to bypass auth for new users. - handler.startup(self, &m)?; - - match self.auth_type { - AuthType::Trust => { - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? - // The async python driver requires a valid server_version - .write_message_noflush(&BeMessage::server_version("14.1"))? - .write_message(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - AuthType::NeonJWT => { - self.write_message(&BeMessage::AuthenticationCleartextPassword)?; - self.state = ProtoState::Authentication; - } - } - } - FeStartupPacket::CancelRequest { .. } => { - return Ok(ProcessMsgResult::Break); - } - } - } - - FeMessage::PasswordMessage(m) => { - trace!("got password message '{:?}'", m); - - assert!(self.state == ProtoState::Authentication); - - match self.auth_type { - AuthType::Trust => unreachable!(), - AuthType::NeonJWT => { - let (_, jwt_response) = m.split_last().context("protocol violation")?; - - if let Err(e) = handler.check_auth_jwt(self, jwt_response) { - self.write_message(&BeMessage::ErrorResponse( - &e.to_string(), - Some(e.pg_error_code()), - ))?; - return Err(e); - } - } - } - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? 
- .write_message(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - - FeMessage::Query(body) => { - // remove null terminator - let query_string = cstr_to_str(&body)?; - - trace!("got query {query_string:?}"); - if let Err(e) = handler.process_query(self, query_string) { - log_query_error(query_string, &e); - let short_error = short_error(&e); - self.write_message_noflush(&BeMessage::ErrorResponse( - &short_error, - Some(e.pg_error_code()), - ))?; - } - self.write_message(&BeMessage::ReadyForQuery)?; - } - - FeMessage::Parse(m) => { - *unnamed_query_string = m.query_string; - self.write_message(&BeMessage::ParseComplete)?; - } - - FeMessage::Describe(_) => { - self.write_message_noflush(&BeMessage::ParameterDescription)? - .write_message(&BeMessage::NoData)?; - } - - FeMessage::Bind(_) => { - self.write_message(&BeMessage::BindComplete)?; - } - - FeMessage::Close(_) => { - self.write_message(&BeMessage::CloseComplete)?; - } - - FeMessage::Execute(_) => { - let query_string = cstr_to_str(unnamed_query_string)?; - trace!("got execute {query_string:?}"); - if let Err(e) = handler.process_query(self, query_string) { - log_query_error(query_string, &e); - self.write_message(&BeMessage::ErrorResponse( - &e.to_string(), - Some(e.pg_error_code()), - ))?; - } - // NOTE there is no ReadyForQuery message. This handler is used - // for basebackup and it uses CopyOut which doesn't require - // ReadyForQuery message and backend just switches back to - // processing mode after sending CopyDone or ErrorResponse. 
- } - - FeMessage::Sync => { - self.write_message(&BeMessage::ReadyForQuery)?; - } - - FeMessage::Terminate => { - return Ok(ProcessMsgResult::Break); - } - - // We prefer explicit pattern matching to wildcards, because - // this helps us spot the places where new variants are missing - FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => { - return Err(QueryError::Other(anyhow::anyhow!( - "unexpected message type: {msg:?}" - ))); - } - } - - Ok(ProcessMsgResult::Continue) - } -} - -pub fn short_error(e: &QueryError) -> String { - match e { - QueryError::Disconnected(connection_error) => connection_error.to_string(), - QueryError::Other(e) => format!("{e:#}"), - } -} - -pub(super) fn log_query_error(query: &str, e: &QueryError) { - match e { - QueryError::Disconnected(ConnectionError::Socket(io_error)) => { - if is_expected_io_error(io_error) { - info!("query handler for '{query}' failed with expected io error: {io_error}"); - } else { - error!("query handler for '{query}' failed with io error: {io_error}"); - } - } - QueryError::Disconnected(other_connection_error) => { - error!("query handler for '{query}' failed with connection error: {other_connection_error:?}") - } - QueryError::Other(e) => { - error!("query handler for '{query}' failed: {e:?}"); - } - } -} diff --git a/libs/utils/src/sock_split.rs b/libs/utils/src/sock_split.rs deleted file mode 100644 index b0e5a0bf6a..0000000000 --- a/libs/utils/src/sock_split.rs +++ /dev/null @@ -1,206 +0,0 @@ -use std::{ - io::{self, BufReader, Write}, - net::{Shutdown, TcpStream}, - sync::Arc, -}; - -use rustls::Connection; - -/// Wrapper supporting reads of a shared TcpStream. 
-pub struct ArcTcpRead(Arc); - -impl io::Read for ArcTcpRead { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - (&*self.0).read(buf) - } -} - -impl std::ops::Deref for ArcTcpRead { - type Target = TcpStream; - - fn deref(&self) -> &Self::Target { - self.0.deref() - } -} - -/// Wrapper around a TCP Stream supporting buffered reads. -pub struct BufStream(BufReader); - -impl io::Read for BufStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) - } -} - -impl io::Write for BufStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.get_ref().write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.get_ref().flush() - } -} - -impl BufStream { - /// Unwrap into the internal BufReader. - fn into_reader(self) -> BufReader { - self.0 - } - - /// Returns a reference to the underlying TcpStream. - fn get_ref(&self) -> &TcpStream { - &self.0.get_ref().0 - } -} - -pub enum ReadStream { - Tcp(BufReader), - Tls(rustls_split::ReadHalf), -} - -impl io::Read for ReadStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - Self::Tcp(reader) => reader.read(buf), - Self::Tls(read_half) => read_half.read(buf), - } - } -} - -impl ReadStream { - pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.get_ref().shutdown(how), - Self::Tls(write_half) => write_half.shutdown(how), - } - } -} - -pub enum WriteStream { - Tcp(Arc), - Tls(rustls_split::WriteHalf), -} - -impl WriteStream { - pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.shutdown(how), - Self::Tls(write_half) => write_half.shutdown(how), - } - } -} - -impl io::Write for WriteStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Self::Tcp(stream) => stream.as_ref().write(buf), - Self::Tls(write_half) => write_half.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - Self::Tcp(stream) => 
stream.as_ref().flush(), - Self::Tls(write_half) => write_half.flush(), - } - } -} - -type TlsStream = rustls::StreamOwned; - -pub enum BidiStream { - Tcp(BufStream), - /// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`]. - Tls(Box>), -} - -impl BidiStream { - pub fn from_tcp(stream: TcpStream) -> Self { - Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream))))) - } - - pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.get_ref().shutdown(how), - Self::Tls(tls_boxed) => { - if how == Shutdown::Read { - tls_boxed.sock.get_ref().shutdown(how) - } else { - tls_boxed.conn.send_close_notify(); - let res = tls_boxed.flush(); - tls_boxed.sock.get_ref().shutdown(how)?; - res - } - } - } - } - - /// Split the bi-directional stream into two owned read and write halves. - pub fn split(self) -> (ReadStream, WriteStream) { - match self { - Self::Tcp(stream) => { - let reader = stream.into_reader(); - let stream: Arc = reader.get_ref().0.clone(); - - (ReadStream::Tcp(reader), WriteStream::Tcp(stream)) - } - Self::Tls(tls_boxed) => { - let reader = tls_boxed.sock.into_reader(); - let buffer_data = reader.buffer().to_owned(); - let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192); - let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192); - - // TODO would be nice to avoid the Arc here - let socket = Arc::try_unwrap(reader.into_inner().0).unwrap(); - - let (read_half, write_half) = rustls_split::split( - socket, - Connection::Server(tls_boxed.conn), - read_buf_cfg, - write_buf_cfg, - ); - (ReadStream::Tls(read_half), WriteStream::Tls(write_half)) - } - } - } - - pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result { - match self { - Self::Tcp(mut stream) => { - conn.complete_io(&mut stream)?; - assert!(!conn.is_handshaking()); - Ok(Self::Tls(Box::new(TlsStream::new(conn, stream)))) - } - Self::Tls { .. 
} => Err(io::Error::new( - io::ErrorKind::InvalidInput, - "TLS is already started on this stream", - )), - } - } -} - -impl io::Read for BidiStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - Self::Tcp(stream) => stream.read(buf), - Self::Tls(tls_boxed) => tls_boxed.read(buf), - } - } -} - -impl io::Write for BidiStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Self::Tcp(stream) => stream.write(buf), - Self::Tls(tls_boxed) => tls_boxed.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.flush(), - Self::Tls(tls_boxed) => tls_boxed.flush(), - } - } -} diff --git a/libs/utils/tests/ssl_test.rs b/libs/utils/tests/ssl_test.rs deleted file mode 100644 index bf09f1b37d..0000000000 --- a/libs/utils/tests/ssl_test.rs +++ /dev/null @@ -1,238 +0,0 @@ -use std::{ - collections::HashMap, - io::{Cursor, Read, Write}, - net::{TcpListener, TcpStream}, - sync::Arc, -}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use once_cell::sync::Lazy; - -use utils::{ - postgres_backend::QueryError, - postgres_backend::{AuthType, Handler, PostgresBackend}, -}; - -fn make_tcp_pair() -> (TcpStream, TcpStream) { - let listener = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - let client_stream = TcpStream::connect(addr).unwrap(); - let (server_stream, _) = listener.accept().unwrap(); - (server_stream, client_stream) -} - -static KEY: Lazy = Lazy::new(|| { - let mut cursor = Cursor::new(include_bytes!("key.pem")); - rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone()) -}); - -static CERT: Lazy = Lazy::new(|| { - let mut cursor = Cursor::new(include_bytes!("cert.pem")); - rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone()) -}); - -#[test] -// [false-positive](https://github.com/rust-lang/rust-clippy/issues/9274), -// we resize the 
vector so doing some modifications after all -#[allow(clippy::read_zero_byte_vec)] -fn ssl() { - let (mut client_sock, server_sock) = make_tcp_pair(); - - const QUERY: &str = "hello world"; - - let client_jh = std::thread::spawn(move || { - // SSLRequest - client_sock.write_u32::(8).unwrap(); - client_sock.write_u32::(80877103).unwrap(); - - let ssl_response = client_sock.read_u8().unwrap(); - assert_eq!(b'S', ssl_response); - - let cfg = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates({ - let mut store = rustls::RootCertStore::empty(); - store.add(&CERT).unwrap(); - store - }) - .with_no_client_auth(); - let client_config = Arc::new(cfg); - - let dns_name = "localhost".try_into().unwrap(); - let mut conn = rustls::ClientConnection::new(client_config, dns_name).unwrap(); - - conn.complete_io(&mut client_sock).unwrap(); - assert!(!conn.is_handshaking()); - - let mut stream = rustls::Stream::new(&mut conn, &mut client_sock); - - // StartupMessage - stream.write_u32::(9).unwrap(); - stream.write_u32::(196608).unwrap(); - stream.write_u8(0).unwrap(); - stream.flush().unwrap(); - - // wait for ReadyForQuery - let mut msg_buf = Vec::new(); - loop { - let msg = stream.read_u8().unwrap(); - let size = stream.read_u32::().unwrap() - 4; - msg_buf.resize(size as usize, 0); - stream.read_exact(&mut msg_buf).unwrap(); - - if msg == b'Z' { - // ReadyForQuery - break; - } - } - - // Query - stream.write_u8(b'Q').unwrap(); - stream - .write_u32::(4u32 + QUERY.len() as u32) - .unwrap(); - stream.write_all(QUERY.as_ref()).unwrap(); - stream.flush().unwrap(); - - // ReadyForQuery - let msg = stream.read_u8().unwrap(); - assert_eq!(msg, b'Z'); - }); - - struct TestHandler { - got_query: bool, - } - impl Handler for TestHandler { - fn process_query( - &mut self, - _pgb: &mut PostgresBackend, - query_string: &str, - ) -> Result<(), QueryError> { - self.got_query = query_string == QUERY; - Ok(()) - } - } - let mut handler = TestHandler { got_query: false 
}; - - let cfg = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(vec![CERT.clone()], KEY.clone()) - .unwrap(); - let tls_config = Some(Arc::new(cfg)); - - let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap(); - pgb.run(&mut handler).unwrap(); - assert!(handler.got_query); - - client_jh.join().unwrap(); - - // TODO consider shutdown behavior -} - -#[test] -fn no_ssl() { - let (mut client_sock, server_sock) = make_tcp_pair(); - - let client_jh = std::thread::spawn(move || { - let mut buf = BytesMut::new(); - - // SSLRequest - buf.put_u32(8); - buf.put_u32(80877103); - client_sock.write_all(&buf).unwrap(); - buf.clear(); - - let ssl_response = client_sock.read_u8().unwrap(); - assert_eq!(b'N', ssl_response); - }); - - struct TestHandler; - - impl Handler for TestHandler { - fn process_query( - &mut self, - _pgb: &mut PostgresBackend, - _query_string: &str, - ) -> Result<(), QueryError> { - panic!() - } - } - - let mut handler = TestHandler; - - let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None, true).unwrap(); - pgb.run(&mut handler).unwrap(); - - client_jh.join().unwrap(); -} - -#[test] -fn server_forces_ssl() { - let (mut client_sock, server_sock) = make_tcp_pair(); - - let client_jh = std::thread::spawn(move || { - // StartupMessage - client_sock.write_u32::(9).unwrap(); - client_sock.write_u32::(196608).unwrap(); - client_sock.write_u8(0).unwrap(); - client_sock.flush().unwrap(); - - // ErrorResponse - assert_eq!(client_sock.read_u8().unwrap(), b'E'); - let len = client_sock.read_u32::().unwrap() - 4; - - let mut body = vec![0; len as usize]; - client_sock.read_exact(&mut body).unwrap(); - let mut body = Bytes::from(body); - - let mut errors = HashMap::new(); - loop { - let field_type = body.get_u8(); - if field_type == 0u8 { - break; - } - - let end_idx = body.iter().position(|&b| b == 0u8).unwrap(); - let mut value = body.split_to(end_idx + 1); - 
assert_eq!(value[end_idx], 0u8); - value.truncate(end_idx); - let old = errors.insert(field_type, value); - assert!(old.is_none()); - } - - assert!(!body.has_remaining()); - - assert_eq!("must connect with TLS", errors.get(&b'M').unwrap()); - - // TODO read failure - }); - - struct TestHandler; - impl Handler for TestHandler { - fn process_query( - &mut self, - _pgb: &mut PostgresBackend, - _query_string: &str, - ) -> Result<(), QueryError> { - panic!() - } - } - let mut handler = TestHandler; - - let cfg = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(vec![CERT.clone()], KEY.clone()) - .unwrap(); - let tls_config = Some(Arc::new(cfg)); - - let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap(); - let res = pgb.run(&mut handler).unwrap_err(); - assert_eq!("client did not connect with TLS", format!("{}", res)); - - client_jh.join().unwrap(); - - // TODO consider shutdown behavior -} diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 9caab7955b..564a3de82c 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -23,11 +23,10 @@ use pageserver::{ tenant::mgr, virtual_file, }; +use postgres_backend::AuthType; use utils::{ auth::JwtAuth, - logging, - postgres_backend::AuthType, - project_git_version, + logging, project_git_version, sentry_init::init_sentry, signals::{self, Signal}, tcp_listener, diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 7442814c43..fde889d01a 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -21,10 +21,10 @@ use std::time::Duration; use toml_edit; use toml_edit::{Document, Item}; +use postgres_backend::AuthType; use utils::{ id::{NodeId, TenantId, TimelineId}, logging::LogFormat, - postgres_backend::AuthType, }; use crate::tenant::config::TenantConf; diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 
dc4be9dd65..bdcd71a20f 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -20,7 +20,7 @@ use pageserver_api::models::{ PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamNblocksRequest, PagestreamNblocksResponse, }; -use postgres_backend::{self, is_expected_io_error, PostgresBackend, QueryError}; +use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError}; use pq_proto::ConnectionError; use pq_proto::FeStartupPacket; use pq_proto::{BeMessage, FeMessage, RowDescriptor}; @@ -36,7 +36,6 @@ use utils::{ auth::{Claims, JwtAuth, Scope}, id::{TenantId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, simple_rcu::RcuReadGuard, }; @@ -68,7 +67,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { msg } + msg = pgb.read_message() => { msg.map_err(QueryError::from)} }; match msg { @@ -79,14 +78,16 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream continue, FeMessage::Terminate => { let msg = "client terminated connection with Terminate message during COPY"; - let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); - pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; + let query_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; break; } m => { let msg = format!("unexpected message {m:?}"); - pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None))?; + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| 
e.into_io_error())?; Err(io::Error::new(io::ErrorKind::Other, msg))?; break; } @@ -96,8 +97,9 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { let msg = "client closed connection during COPY"; - let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); - pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; + let query_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?; pgb.flush().await?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; } @@ -105,7 +107,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { - Err(io::Error::new(io::ErrorKind::Other, other))?; + Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?; } }; } diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index f9d1e819a1..41ac61b7b6 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -435,8 +435,8 @@ fn ignore_expected_errors(pg_error: postgres::Error) -> anyhow::Result> = Lazy::new(Default::default); @@ -33,7 +31,7 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N /// Console management API listener task. /// It spawns console response handlers needed for the link auth. -pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> { +pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> { scopeguard::defer! 
{ info!("mgmt has shut down"); } @@ -42,18 +40,12 @@ pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> let (socket, peer_addr) = listener.accept().await?; info!("accepted connection from {peer_addr}"); - let socket = socket.into_std()?; socket .set_nodelay(true) .context("failed to set client socket option")?; - socket - .set_nonblocking(false) - .context("failed to set client socket option")?; - // TODO: replace with async tasks. - thread::spawn(move || { - let tid = std::thread::current().id(); - let span = info_span!("mgmt", thread = format_args!("{tid:?}")); + tokio::task::spawn(async move { + let span = info_span!("mgmt", peer = %peer_addr); let _enter = span.enter(); info!("started a new console management API thread"); @@ -61,16 +53,16 @@ pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> info!("console management API thread is about to finish"); } - if let Err(e) = handle_connection(socket) { + if let Err(e) = handle_connection(socket).await { error!("thread failed with an error: {e}"); } }); } } -fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { - let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?; - pgbackend.run(&mut MgmtHandler) +async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { + let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?; + pgbackend.run(&mut MgmtHandler, future::pending::<()>).await } /// A message received by `mgmt` when a compute node is ready. @@ -78,16 +70,21 @@ pub type ComputeReady = Result; // TODO: replace with an http-based protocol. 
struct MgmtHandler; +#[async_trait::async_trait] impl postgres_backend::Handler for MgmtHandler { - fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> { - try_process_query(pgb, query).map_err(|e| { + async fn process_query( + &mut self, + pgb: &mut PostgresBackend, + query: &str, + ) -> Result<(), QueryError> { + try_process_query(pgb, query).await.map_err(|e| { error!("failed to process response: {e:?}"); e }) } } -fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> { +async fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> { let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?; let span = info_span!("event", session_id = resp.session_id); @@ -98,11 +95,11 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), Query Ok(()) => { pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? - .write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; } Err(e) => { error!("failed to deliver response to per-client task"); - pgb.write_message(&BeMessage::ErrorResponse(&e.to_string(), None))?; + pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?; } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 02a0fabe9a..e0cf1326b9 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -51,7 +51,7 @@ impl PqStream { /// Receive [`FeStartupPacket`], which is a first packet sent by a client. pub async fn read_startup_packet(&mut self) -> io::Result { // TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket` - let msg = FeStartupPacket::read_fut(&mut self.stream) + let msg = FeStartupPacket::read(&mut self.stream) .await .map_err(ConnectionError::into_io_error)? 
.ok_or_else(err_connection)?; @@ -73,7 +73,7 @@ impl PqStream { } async fn read_message(&mut self) -> io::Result { - FeMessage::read_fut(&mut self.stream) + FeMessage::read(&mut self.stream) .await .map_err(ConnectionError::into_io_error)? .ok_or_else(err_connection) diff --git a/run_clippy.sh b/run_clippy.sh index fe0e745d7d..0558541089 100755 --- a/run_clippy.sh +++ b/run_clippy.sh @@ -11,12 +11,18 @@ # Not every feature is supported in macOS builds. Avoid running regular linting # script that checks every feature. +# +# manual-range-contains wants +# !(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len) +# instead of +# len < 4 || len > MAX_STARTUP_PACKET_LENGTH +# , let's disagree. if [[ "$OSTYPE" == "darwin"* ]]; then # no extra features to test currently, add more here when needed - cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -D warnings + cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -A clippy::manual-range-contains -D warnings else # * `-A unknown_lints` – do not warn about unknown lint suppressions # that people with newer toolchains might use # * `-D warnings` - fail on any warnings (`cargo` returns non-zero exit status) - cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -D warnings + cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -A clippy::manual-range-contains -D warnings fi diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 2424509477..36ee15347d 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -36,6 +36,7 @@ toml_edit.workspace = true tracing.workspace = true url.workspace = true metrics.workspace = true +postgres_backend.workspace = true postgres_ffi.workspace = true pq_proto.workspace = true remote_storage.workspace = true diff --git a/safekeeper/src/bin/safekeeper.rs b/safekeeper/src/bin/safekeeper.rs index 683050e9cd..d2cb9f79b9 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ 
b/safekeeper/src/bin/safekeeper.rs @@ -236,7 +236,7 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> { let conf_cloned = conf.clone(); let safekeeper_thread = thread::Builder::new() - .name("safekeeper thread".into()) + .name("WAL service thread".into()) .spawn(|| wal_service::thread_main(conf_cloned, pg_listener)) .unwrap(); diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index d1cd76459b..3e7bafbd2f 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -1,27 +1,23 @@ //! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres //! protocol commands. +use anyhow::Context; +use std::str; +use tracing::{info, info_span, Instrument}; + use crate::auth::check_permission; use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage}; -use crate::receive_wal::ReceiveWalConn; - -use crate::send_wal::ReplicationConn; use crate::{GlobalTimelines, SafeKeeperConf}; -use anyhow::Context; - +use postgres_backend::QueryError; +use postgres_backend::{self, PostgresBackend}; use postgres_ffi::PG_TLI; -use regex::Regex; - use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID}; -use std::str; -use tracing::info; +use regex::Regex; use utils::auth::{Claims, Scope}; -use utils::postgres_backend::QueryError; use utils::{ id::{TenantId, TenantTimelineId, TimelineId}, lsn::Lsn, - postgres_backend::{self, PostgresBackend}, }; /// Safekeeper handler of postgres commands @@ -53,7 +49,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result { let start_lsn = caps .next() .map(|cap| cap[1].parse::()) - .context("failed to parse start LSN from START_REPLICATION command")??; + .context("parse start LSN from START_REPLICATION command")??; Ok(SafekeeperPostgresCommand::StartReplication { start_lsn }) } else if cmd.starts_with("IDENTIFY_SYSTEM") { Ok(SafekeeperPostgresCommand::IdentifySystem) @@ -67,6 +63,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result { } } +#[async_trait::async_trait] impl postgres_backend::Handler for 
SafekeeperPostgresHandler { // tenant_id and timeline_id are passed in connection string params fn startup( @@ -137,7 +134,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { Ok(()) } - fn process_query( + async fn process_query( &mut self, pgb: &mut PostgresBackend, query_string: &str, @@ -147,9 +144,10 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { .starts_with("set datestyle to ") { // important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect - pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; + pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; return Ok(()); } + let cmd = parse_cmd(query_string)?; info!( @@ -161,26 +159,23 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { let timeline_id = self.timeline_id.context("timelineid is required")?; self.check_permission(Some(tenant_id))?; self.ttid = TenantTimelineId::new(tenant_id, timeline_id); + let span_ttid = self.ttid; // satisfy borrow checker - let res = match cmd { - SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self), + match cmd { + SafekeeperPostgresCommand::StartWalPush => { + self.handle_start_wal_push(pgb) + .instrument(info_span!("WAL receiver", ttid = %span_ttid)) + .await + } SafekeeperPostgresCommand::StartReplication { start_lsn } => { - ReplicationConn::new(pgb).run(self, pgb, start_lsn) + self.handle_start_replication(pgb, start_lsn) + .instrument(info_span!("WAL sender", ttid = %span_ttid)) + .await } - SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb), - SafekeeperPostgresCommand::JSONCtrl { ref cmd } => handle_json_ctrl(self, pgb, cmd), - }; - - match res { - Ok(()) => Ok(()), - Err(QueryError::Disconnected(connection_error)) => { - info!("Timeline {tenant_id}/{timeline_id} query failed with connection error: {connection_error}"); - Err(QueryError::Disconnected(connection_error)) + SafekeeperPostgresCommand::IdentifySystem => 
self.handle_identify_system(pgb).await, + SafekeeperPostgresCommand::JSONCtrl { ref cmd } => { + handle_json_ctrl(self, pgb, cmd).await } - Err(QueryError::Other(e)) => Err(QueryError::Other(e.context(format!( - "Failed to process query for timeline {}", - self.ttid - )))), } } } @@ -217,7 +212,10 @@ impl SafekeeperPostgresHandler { /// /// Handle IDENTIFY_SYSTEM replication command /// - fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<(), QueryError> { + async fn handle_identify_system( + &mut self, + pgb: &mut PostgresBackend, + ) -> Result<(), QueryError> { let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?; let lsn = if self.is_walproposer_recovery() { @@ -267,7 +265,7 @@ impl SafekeeperPostgresHandler { Some(lsn_bytes), None, ]))? - .write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?; + .write_message_noflush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?; Ok(()) } diff --git a/safekeeper/src/http/routes.rs b/safekeeper/src/http/routes.rs index b157fcb076..14badebd95 100644 --- a/safekeeper/src/http/routes.rs +++ b/safekeeper/src/http/routes.rs @@ -168,12 +168,9 @@ async fn timeline_create_handler(mut request: Request) -> Result anyhow::Result> { +async fn prepare_safekeeper( + ttid: TenantTimelineId, + pg_version: u32, +) -> anyhow::Result> { GlobalTimelines::create( ttid, ServerInfo { @@ -106,6 +110,7 @@ fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> anyhow::Result Lsn::INVALID, Lsn::INVALID, ) + .await } fn send_proposer_elected(tli: &Arc, term: Term, lsn: Lsn) -> anyhow::Result<()> { @@ -128,15 +133,15 @@ fn send_proposer_elected(tli: &Arc, term: Term, lsn: Lsn) -> anyhow::R } #[derive(Debug, Serialize, Deserialize)] -struct InsertedWAL { +pub struct InsertedWAL { begin_lsn: Lsn, - end_lsn: Lsn, + pub end_lsn: Lsn, append_response: AppendResponse, } /// Extend local WAL with new LogicalMessage record. 
To do that, /// create AppendRequest with new WAL and pass it to safekeeper. -fn append_logical_message( +pub fn append_logical_message( tli: &Arc, msg: &AppendLogicalMessage, ) -> anyhow::Result { diff --git a/safekeeper/src/lib.rs b/safekeeper/src/lib.rs index 6ab108ceb0..03df546a4d 100644 --- a/safekeeper/src/lib.rs +++ b/safekeeper/src/lib.rs @@ -1,8 +1,7 @@ -use storage_broker::Uri; -// use remote_storage::RemoteStorageConfig; use std::path::PathBuf; use std::time::Duration; +use storage_broker::Uri; use utils::id::{NodeId, TenantId, TenantTimelineId}; diff --git a/safekeeper/src/receive_wal.rs b/safekeeper/src/receive_wal.rs index 0cf921d97a..22c9871026 100644 --- a/safekeeper/src/receive_wal.rs +++ b/safekeeper/src/receive_wal.rs @@ -2,204 +2,284 @@ //! Gets messages from the network, passes them down to consensus module and //! sends replies back. -use anyhow::anyhow; -use anyhow::Context; - -use bytes::BytesMut; -use tracing::*; -use utils::lsn::Lsn; -use utils::postgres_backend::QueryError; - +use crate::handler::SafekeeperPostgresHandler; +use crate::safekeeper::AcceptorProposerMessage; +use crate::safekeeper::ProposerAcceptorMessage; use crate::safekeeper::ServerInfo; use crate::timeline::Timeline; use crate::GlobalTimelines; - +use anyhow::{anyhow, Context}; +use bytes::BytesMut; +use nix::unistd::gettid; +use postgres_backend::CopyStreamHandlerEnd; +use postgres_backend::PostgresBackend; +use postgres_backend::PostgresBackendReader; +use postgres_backend::QueryError; +use pq_proto::BeMessage; use std::net::SocketAddr; -use std::sync::mpsc::channel; -use std::sync::mpsc::Receiver; - use std::sync::Arc; use std::thread; +use std::thread::JoinHandle; +use tokio::sync::mpsc::channel; +use tokio::sync::mpsc::error::TryRecvError; +use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc::Sender; +use tokio::task::spawn_blocking; +use tracing::*; +use utils::id::TenantTimelineId; +use utils::lsn::Lsn; -use crate::safekeeper::AcceptorProposerMessage; -use 
crate::safekeeper::ProposerAcceptorMessage; +const MSG_QUEUE_SIZE: usize = 256; +const REPLY_QUEUE_SIZE: usize = 16; -use crate::handler::SafekeeperPostgresHandler; -use pq_proto::{BeMessage, FeMessage}; -use utils::{postgres_backend::PostgresBackend, sock_split::ReadStream}; - -pub struct ReceiveWalConn<'pg> { - /// Postgres connection - pg_backend: &'pg mut PostgresBackend, - /// The cached result of `pg_backend.socket().peer_addr()` (roughly) - peer_addr: SocketAddr, -} - -impl<'pg> ReceiveWalConn<'pg> { - pub fn new(pg: &'pg mut PostgresBackend) -> ReceiveWalConn<'pg> { - let peer_addr = *pg.get_peer_addr(); - ReceiveWalConn { - pg_backend: pg, - peer_addr, +impl SafekeeperPostgresHandler { + /// Wrapper around handle_start_wal_push_guts handling result. Error is + /// handled here while we're still in walreceiver ttid span; with API + /// extension, this can probably be moved into postgres_backend. + pub async fn handle_start_wal_push( + &mut self, + pgb: &mut PostgresBackend, + ) -> Result<(), QueryError> { + if let Err(end) = self.handle_start_wal_push_guts(pgb).await { + // Log the result and probably send it to the client, closing the stream. 
+ pgb.handle_copy_stream_end(end).await; } - } - - // Send message to the postgres - fn write_msg(&mut self, msg: &AcceptorProposerMessage) -> anyhow::Result<()> { - let mut buf = BytesMut::with_capacity(128); - msg.serialize(&mut buf)?; - self.pg_backend.write_message(&BeMessage::CopyData(&buf))?; Ok(()) } - /// Receive WAL from wal_proposer - pub fn run(&mut self, spg: &mut SafekeeperPostgresHandler) -> Result<(), QueryError> { - let _enter = info_span!("WAL acceptor", ttid = %spg.ttid).entered(); - + pub async fn handle_start_wal_push_guts( + &mut self, + pgb: &mut PostgresBackend, + ) -> Result<(), CopyStreamHandlerEnd> { // Notify the libpq client that it's allowed to send `CopyData` messages - self.pg_backend - .write_message(&BeMessage::CopyBothResponse)?; + pgb.write_message(&BeMessage::CopyBothResponse).await?; - let r = self - .pg_backend - .take_stream_in() - .ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?; - let mut poll_reader = ProposerPollStream::new(r)?; + // Experiments [1] confirm that doing network IO in one (this) thread and + // processing with disc IO in another significantly improves + // performance; we spawn off WalAcceptor thread for message processing + // to this end. + // + // [1] https://github.com/neondatabase/neon/pull/1318 + let (msg_tx, msg_rx) = channel(MSG_QUEUE_SIZE); + let (reply_tx, reply_rx) = channel(REPLY_QUEUE_SIZE); + let mut acceptor_handle: Option>> = None; - // Receive information about server - let next_msg = poll_reader.recv_msg()?; - let tli = match next_msg { - ProposerAcceptorMessage::Greeting(ref greeting) => { - info!( - "start handshake with walproposer {} sysid {} timeline {}", - self.peer_addr, greeting.system_id, greeting.tli, - ); - let server_info = ServerInfo { - pg_version: greeting.pg_version, - system_id: greeting.system_id, - wal_seg_size: greeting.wal_seg_size, - }; - GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)? 
- } - _ => { - return Err(QueryError::Other(anyhow::anyhow!( - "unexpected message {next_msg:?} instead of greeting" - ))) - } + // Concurrently receive and send data; replies are not synchronized with + // sends, so this avoids deadlocks. + let mut pgb_reader = pgb.split().context("START_WAL_PUSH split")?; + let peer_addr = *pgb.get_peer_addr(); + let res = tokio::select! { + // todo: add read|write .context to these errors + r = read_network(self.ttid, &mut pgb_reader, peer_addr, msg_tx, &mut acceptor_handle, msg_rx, reply_tx) => r, + r = write_network(pgb, reply_rx) => r, }; - let mut next_msg = Some(next_msg); + // Join pg backend back. + pgb.unsplit(pgb_reader)?; - let mut first_time_through = true; - let mut _guard: Option = None; - loop { - if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) { - // poll AppendRequest's without blocking and write WAL to disk without flushing, - // while it's readily available - while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg { - let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request); - - let reply = tli.process_msg(&msg)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } - - next_msg = poll_reader.poll_msg(); - } - - // flush all written WAL to the disk - let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } - } else if let Some(msg) = next_msg.take() { - // process other message - let reply = tli.process_msg(&msg)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } - } - if first_time_through { - // Register the connection and defer unregister. Do that only - // after processing first message, as it sets wal_seg_size, - // wanted by many. - tli.on_compute_connect()?; - _guard = Some(ComputeConnectionGuard { - timeline: Arc::clone(&tli), - }); - first_time_through = false; + // Join the spawned WalAcceptor. 
At this point chans to/from it passed + // to network routines are dropped, so it will exit as soon as it + // touches them. + match acceptor_handle { + None => { + // failed even before spawning; read_network should have error + Err(res.expect_err("no error with WalAcceptor not spawn")) } + Some(handle) => { + let wal_acceptor_res = handle.join(); - // blocking wait for the next message - if next_msg.is_none() { - next_msg = Some(poll_reader.recv_msg()?); + // If there was any network error, return it. + res?; + + // Otherwise, WalAcceptor thread must have errored. + match wal_acceptor_res { + Ok(Ok(_)) => Ok(()), // can't happen currently; would be if we add graceful termination + Ok(Err(e)) => Err(CopyStreamHandlerEnd::Other(e.context("WAL acceptor"))), + Err(_) => Err(CopyStreamHandlerEnd::Other(anyhow!( + "WalAcceptor thread panicked", + ))), + } } } } } -struct ProposerPollStream { - msg_rx: Receiver, - read_thread: Option>>, +/// Read next message from walproposer. +/// TODO: Return Ok(None) on graceful termination. +async fn read_message( + pgb_reader: &mut PostgresBackendReader, +) -> Result { + let copy_data = pgb_reader.read_copy_message().await?; + let msg = ProposerAcceptorMessage::parse(copy_data)?; + Ok(msg) } -impl ProposerPollStream { - fn new(mut r: ReadStream) -> anyhow::Result { - let (msg_tx, msg_rx) = channel(); - - let read_thread = thread::Builder::new() - .name("Read WAL thread".into()) - .spawn(move || -> Result<(), QueryError> { - loop { - let copy_data = match FeMessage::read(&mut r)? 
{ - Some(FeMessage::CopyData(bytes)) => Ok(bytes), - Some(msg) => Err(QueryError::Other(anyhow::anyhow!( - "expected `CopyData` message, found {msg:?}" - ))), - None => Err(QueryError::from(std::io::Error::new( - std::io::ErrorKind::ConnectionAborted, - "walproposer closed the connection", - ))), - }?; - - let msg = ProposerAcceptorMessage::parse(copy_data)?; - msg_tx - .send(msg) - .context("Failed to send the proposer message")?; - } - // msg_tx will be dropped here, this will also close msg_rx - })?; - - Ok(Self { - msg_rx, - read_thread: Some(read_thread), - }) - } - - fn recv_msg(&mut self) -> Result { - self.msg_rx.recv().map_err(|_| { - // return error from the read thread - let res = match self.read_thread.take() { - Some(thread) => thread.join(), - None => return QueryError::Other(anyhow::anyhow!("read thread is gone")), +/// Read messages from socket and pass it to WalAcceptor thread. Returns Ok(()) +/// if msg_tx closed; it must mean WalAcceptor terminated, joining it should +/// tell the error. +async fn read_network( + ttid: TenantTimelineId, + pgb_reader: &mut PostgresBackendReader, + peer_addr: SocketAddr, + msg_tx: Sender, + // WalAcceptor is spawned when we learn server info from walproposer and + // create timeline; handle is put here. + acceptor_handle: &mut Option>>, + msg_rx: Receiver, + reply_tx: Sender, +) -> Result<(), CopyStreamHandlerEnd> { + // Receive information about server to create timeline, if not yet. + let next_msg = read_message(pgb_reader).await?; + let tli = match next_msg { + ProposerAcceptorMessage::Greeting(ref greeting) => { + info!( + "start handshake with walproposer {} sysid {} timeline {}", + peer_addr, greeting.system_id, greeting.tli, + ); + let server_info = ServerInfo { + pg_version: greeting.pg_version, + system_id: greeting.system_id, + wal_seg_size: greeting.wal_seg_size, }; + GlobalTimelines::create(ttid, server_info, Lsn::INVALID, Lsn::INVALID).await? 
+ } + _ => { + return Err(CopyStreamHandlerEnd::Other(anyhow::anyhow!( + "unexpected message {next_msg:?} instead of greeting" + ))) + } + }; - match res { - Ok(Ok(())) => { - QueryError::Other(anyhow::anyhow!("unexpected result from read thread")) - } - Err(err) => QueryError::Other(anyhow::anyhow!("read thread panicked: {err:?}")), - Ok(Err(err)) => err, + *acceptor_handle = Some( + WalAcceptor::spawn(tli.clone(), msg_rx, reply_tx).context("spawn WalAcceptor thread")?, + ); + + // Forward all messages to WalAcceptor + read_network_loop(pgb_reader, msg_tx, next_msg).await +} + +async fn read_network_loop( + pgb_reader: &mut PostgresBackendReader, + msg_tx: Sender, + mut next_msg: ProposerAcceptorMessage, +) -> Result<(), CopyStreamHandlerEnd> { + loop { + if msg_tx.send(next_msg).await.is_err() { + return Ok(()); // chan closed, WalAcceptor terminated + } + next_msg = read_message(pgb_reader).await?; + } +} + +/// Read replies from WalAcceptor and pass them back to socket. Returns Ok(()) +/// if reply_rx closed; it must mean WalAcceptor terminated, joining it should +/// tell the error. +async fn write_network( + pgb_writer: &mut PostgresBackend, + mut reply_rx: Receiver, +) -> Result<(), CopyStreamHandlerEnd> { + let mut buf = BytesMut::with_capacity(128); + + loop { + match reply_rx.recv().await { + Some(msg) => { + buf.clear(); + msg.serialize(&mut buf)?; + pgb_writer.write_message(&BeMessage::CopyData(&buf)).await?; } - }) + None => return Ok(()), // chan closed, WalAcceptor terminated + } + } +} + +/// Takes messages from msg_rx, processes and pushes replies to reply_tx. +struct WalAcceptor { + tli: Arc, + msg_rx: Receiver, + reply_tx: Sender, +} + +impl WalAcceptor { + /// Spawn thread with WalAcceptor running, return handle to it. 
+ fn spawn( + tli: Arc, + msg_rx: Receiver, + reply_tx: Sender, + ) -> anyhow::Result>> { + let thread_name = format!("WAL acceptor {}", tli.ttid); + thread::Builder::new() + .name(thread_name) + .spawn(move || -> anyhow::Result<()> { + let mut wa = WalAcceptor { + tli, + msg_rx, + reply_tx, + }; + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + let span_ttid = wa.tli.ttid; // satisfy borrow checker + runtime.block_on( + wa.run() + .instrument(info_span!("WAL acceptor", tid = %gettid(), ttid = %span_ttid)), + ) + }) + .map_err(anyhow::Error::from) } - fn poll_msg(&mut self) -> Option { - let res = self.msg_rx.try_recv(); + /// The main loop. Returns Ok(()) if either msg_rx or reply_tx got closed; + /// it must mean that network thread terminated. + async fn run(&mut self) -> anyhow::Result<()> { + // Register the connection and defer unregister. + self.tli.on_compute_connect().await?; + let _guard = ComputeConnectionGuard { + timeline: Arc::clone(&self.tli), + }; - match res { - Err(_) => None, - Ok(msg) => Some(msg), + let mut next_msg: ProposerAcceptorMessage; + + loop { + let opt_msg = self.msg_rx.recv().await; + if opt_msg.is_none() { + return Ok(()); // chan closed, streaming terminated + } + next_msg = opt_msg.unwrap(); + + if matches!(next_msg, ProposerAcceptorMessage::AppendRequest(_)) { + // loop through AppendRequest's while it's readily available to + // write as many WAL as possible without fsyncing + while let ProposerAcceptorMessage::AppendRequest(append_request) = next_msg { + let noflush_msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request); + + if let Some(reply) = self.tli.process_msg(&noflush_msg)? 
{ + if self.reply_tx.send(reply).await.is_err() { + return Ok(()); // chan closed, streaming terminated + } + } + + match self.msg_rx.try_recv() { + Ok(msg) => next_msg = msg, + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => return Ok(()), // chan closed, streaming terminated + } + } + + // flush all written WAL to the disk + if let Some(reply) = self.tli.process_msg(&ProposerAcceptorMessage::FlushWAL)? { + if self.reply_tx.send(reply).await.is_err() { + return Ok(()); // chan closed, streaming terminated + } + } + } else { + // process message other than AppendRequest + if let Some(reply) = self.tli.process_msg(&next_msg)? { + if self.reply_tx.send(reply).await.is_err() { + return Ok(()); // chan closed, streaming terminated + } + } + } } } } @@ -210,8 +290,13 @@ struct ComputeConnectionGuard { impl Drop for ComputeConnectionGuard { fn drop(&mut self) { - if let Err(e) = self.timeline.on_compute_disconnect() { - error!("failed to unregister compute connection: {}", e); - } + let tli = self.timeline.clone(); + // tokio forbids to call blocking_send inside the runtime, and see + // comments in on_compute_disconnect why we call blocking_send. + spawn_blocking(move || { + if let Err(e) = tli.on_compute_disconnect() { + error!("failed to unregister compute connection: {}", e); + } + }); } } diff --git a/safekeeper/src/safekeeper.rs b/safekeeper/src/safekeeper.rs index 7df347427e..4a046cb048 100644 --- a/safekeeper/src/safekeeper.rs +++ b/safekeeper/src/safekeeper.rs @@ -488,7 +488,7 @@ impl AcceptorProposerMessage { buf.put_u64_le(msg.hs_feedback.xmin); buf.put_u64_le(msg.hs_feedback.catalog_xmin); - msg.pageserver_feedback.serialize(buf)? 
+ msg.pageserver_feedback.serialize(buf)?; } } diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 169ab03f0a..e8c1b4c02e 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -5,24 +5,22 @@ use crate::handler::SafekeeperPostgresHandler; use crate::timeline::{ReplicaState, Timeline}; use crate::wal_storage::WalReader; use crate::GlobalTimelines; -use anyhow::Context; - +use anyhow::Context as AnyhowContext; use bytes::Bytes; +use postgres_backend::PostgresBackend; +use postgres_backend::{CopyStreamHandlerEnd, PostgresBackendReader, QueryError}; use postgres_ffi::get_current_timestamp; use postgres_ffi::{TimestampTz, MAX_SEND_SIZE}; +use pq_proto::{BeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody}; use serde::{Deserialize, Serialize}; use std::cmp::min; -use std::net::Shutdown; +use std::str; use std::sync::Arc; use std::time::Duration; -use std::{io, str, thread}; -use utils::postgres_backend::QueryError; - -use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody}; use tokio::sync::watch::Receiver; use tokio::time::timeout; use tracing::*; -use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend::PostgresBackend, sock_split::ReadStream}; +use utils::{bin_ser::BeSer, lsn::Lsn}; // See: https://www.postgresql.org/docs/13/protocol-replication.html const HOT_STANDBY_FEEDBACK_TAG_BYTE: u8 = b'h'; @@ -60,13 +58,6 @@ pub struct StandbyReply { pub reply_requested: bool, } -/// A network connection that's speaking the replication protocol. -pub struct ReplicationConn { - /// This is an `Option` because we will spawn a background thread that will - /// `take` it from us. 
- stream_in: Option, -} - /// Scope guard to unregister replication connection from timeline struct ReplicationConnGuard { replica: usize, // replica internal ID assigned by timeline @@ -79,230 +70,275 @@ impl Drop for ReplicationConnGuard { } } -impl ReplicationConn { - /// Create a new `ReplicationConn` - pub fn new(pgb: &mut PostgresBackend) -> Self { - Self { - stream_in: pgb.take_stream_in(), +impl SafekeeperPostgresHandler { + /// Wrapper around handle_start_replication_guts handling result. Error is + /// handled here while we're still in walsender ttid span; with API + /// extension, this can probably be moved into postgres_backend. + pub async fn handle_start_replication( + &mut self, + pgb: &mut PostgresBackend, + start_pos: Lsn, + ) -> Result<(), QueryError> { + if let Err(end) = self.handle_start_replication_guts(pgb, start_pos).await { + // Log the result and probably send it to the client, closing the stream. + pgb.handle_copy_stream_end(end).await; } - } - - /// Handle incoming messages from the network. - /// This is spawned into the background by `handle_start_replication`. - fn background_thread( - mut stream_in: ReadStream, - replica_guard: Arc, - ) -> anyhow::Result<()> { - let replica_id = replica_guard.replica; - let timeline = &replica_guard.timeline; - - let mut state = ReplicaState::new(); - // Wait for replica's feedback. - while let Some(msg) = FeMessage::read(&mut stream_in)? { - match &msg { - FeMessage::CopyData(m) => { - // There's three possible data messages that the client is supposed to send here: - // `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`. - - match m.first().cloned() { - Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => { - // Note: deserializing is on m[1..] because we skip the tag byte. 
- state.hs_feedback = HotStandbyFeedback::des(&m[1..]) - .context("failed to deserialize HotStandbyFeedback")?; - timeline.update_replica_state(replica_id, state); - } - Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => { - let _reply = StandbyReply::des(&m[1..]) - .context("failed to deserialize StandbyReply")?; - // This must be a regular postgres replica, - // because pageserver doesn't send this type of messages to safekeeper. - // Currently this is not implemented, so this message is ignored. - - warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet."); - // timeline.update_replica_state(replica_id, Some(state)); - } - Some(NEON_STATUS_UPDATE_TAG_BYTE) => { - // Note: deserializing is on m[9..] because we skip the tag byte and len bytes. - let buf = Bytes::copy_from_slice(&m[9..]); - let reply = ReplicationFeedback::parse(buf); - - trace!("ReplicationFeedback is {:?}", reply); - // Only pageserver sends ReplicationFeedback, so set the flag. - // This replica is the source of information to resend to compute. - state.pageserver_feedback = Some(reply); - - timeline.update_replica_state(replica_id, state); - } - _ => warn!("unexpected message {:?}", msg), - } - } - FeMessage::Sync => {} - FeMessage::CopyFail => { - // Shutdown the connection, because rust-postgres client cannot be dropped - // when connection is alive. - let _ = stream_in.shutdown(Shutdown::Both); - anyhow::bail!("Copy failed"); - } - _ => { - // We only handle `CopyData`, 'Sync', 'CopyFail' messages. Anything else is ignored. 
- info!("unexpected message {:?}", msg); - } - } - } - Ok(()) } - /// - /// Handle START_REPLICATION replication command - /// - pub fn run( + pub async fn handle_start_replication_guts( &mut self, - spg: &mut SafekeeperPostgresHandler, pgb: &mut PostgresBackend, - mut start_pos: Lsn, - ) -> Result<(), QueryError> { - let _enter = info_span!("WAL sender", ttid = %spg.ttid).entered(); - - let tli = GlobalTimelines::get(spg.ttid).map_err(|e| QueryError::Other(e.into()))?; - - // spawn the background thread which receives HotStandbyFeedback messages. - let bg_timeline = Arc::clone(&tli); - let bg_stream_in = self.stream_in.take().unwrap(); - let bg_timeline_id = spg.timeline_id.unwrap(); + start_pos: Lsn, + ) -> Result<(), CopyStreamHandlerEnd> { + let appname = self.appname.clone(); + let tli = + GlobalTimelines::get(self.ttid).map_err(|e| CopyStreamHandlerEnd::Other(e.into()))?; let state = ReplicaState::new(); // This replica_id is used below to check if it's time to stop replication. - let replica_id = bg_timeline.add_replica(state); + let replica_id = tli.add_replica(state); // Use a guard object to remove our entry from the timeline, when the background // thread and us have both finished using it. - let replica_guard = Arc::new(ReplicationConnGuard { + let _guard = Arc::new(ReplicationConnGuard { replica: replica_id, - timeline: bg_timeline, + timeline: tli.clone(), }); - let bg_replica_guard = Arc::clone(&replica_guard); - // TODO: here we got two threads, one for writing WAL and one for receiving - // feedback. If one of them fails, we should shutdown the other one too. 
- let _ = thread::Builder::new() - .name("HotStandbyFeedback thread".into()) - .spawn(move || { - let _enter = - info_span!("HotStandbyFeedback thread", timeline = %bg_timeline_id).entered(); - if let Err(err) = Self::background_thread(bg_stream_in, bg_replica_guard) { - error!("Replication background thread failed: {}", err); + // Walproposer gets special handling: safekeeper must give proposer all + // local WAL till the end, whether committed or not (walproposer will + // hang otherwise). That's because walproposer runs the consensus and + // synchronizes safekeepers on the most advanced one. + // + // There is a small risk of this WAL getting concurrently garbaged if + // another compute rises which collects majority and starts fixing log + // on this safekeeper itself. That's ok as (old) proposer will never be + // able to commit such WAL. + let stop_pos: Option = if self.is_walproposer_recovery() { + let wal_end = tli.get_flush_lsn(); + Some(wal_end) + } else { + None + }; + let end_pos = stop_pos.unwrap_or(Lsn::INVALID); + + info!( + "starting streaming from {:?} till {:?}", + start_pos, stop_pos + ); + + // switch to copy + pgb.write_message(&BeMessage::CopyBothResponse).await?; + + let (_, persisted_state) = tli.get_state(); + let wal_reader = WalReader::new( + self.conf.workdir.clone(), + self.conf.timeline_dir(&tli.ttid), + &persisted_state, + start_pos, + self.conf.wal_backup_enabled, + )?; + + // Split to concurrently receive and send data; replies are generally + // not synchronized with sends, so this avoids deadlocks. + let reader = pgb.split().context("START_REPLICATION split")?; + + let mut sender = WalSender { + pgb, + tli: tli.clone(), + appname, + start_pos, + end_pos, + stop_pos, + commit_lsn_watch_rx: tli.get_commit_lsn_watch_rx(), + replica_id, + wal_reader, + send_buf: [0; MAX_SEND_SIZE], + }; + let mut reply_reader = ReplyReader { + reader, + tli, + replica_id, + feedback: ReplicaState::new(), + }; + + let res = tokio::select! 
{ + // todo: add read|write .context to these errors + r = sender.run() => r, + r = reply_reader.run() => r, + }; + // Join pg backend back. + pgb.unsplit(reply_reader.reader)?; + + res + } +} + +/// A half driving sending WAL. +struct WalSender<'a> { + pgb: &'a mut PostgresBackend, + tli: Arc, + appname: Option, + // Position since which we are sending next chunk. + start_pos: Lsn, + // WAL up to this position is known to be locally available. + end_pos: Lsn, + // If present, terminate after reaching this position; used by walproposer + // in recovery. + stop_pos: Option, + commit_lsn_watch_rx: Receiver, + replica_id: usize, + wal_reader: WalReader, + // buffer for readling WAL into to send it + send_buf: [u8; MAX_SEND_SIZE], +} + +impl WalSender<'_> { + /// Send WAL until + /// - an error occurs + /// - if we are streaming to walproposer, we've streamed until stop_pos + /// (recovery finished) + /// - receiver is caughtup and there is no computes + /// + /// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ? + /// convenience. + async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> { + loop { + // If we are streaming to walproposer, check it is time to stop. + if let Some(stop_pos) = self.stop_pos { + if self.start_pos >= stop_pos { + // recovery finished + return Err(CopyStreamHandlerEnd::ServerInitiated(format!( + "ending streaming to walproposer at {}, recovery finished", + self.start_pos + ))); } - })?; - - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; - - runtime.block_on(async move { - let (inmem_state, persisted_state) = tli.get_state(); - // add persisted_state.timeline_start_lsn == Lsn(0) check - - // Walproposer gets special handling: safekeeper must give proposer all - // local WAL till the end, whether committed or not (walproposer will - // hang otherwise). That's because walproposer runs the consensus and - // synchronizes safekeepers on the most advanced one. 
- // - // There is a small risk of this WAL getting concurrently garbaged if - // another compute rises which collects majority and starts fixing log - // on this safekeeper itself. That's ok as (old) proposer will never be - // able to commit such WAL. - let stop_pos: Option = if spg.is_walproposer_recovery() { - let wal_end = tli.get_flush_lsn(); - Some(wal_end) } else { - None - }; + // Wait for the next portion if it is not there yet, or just + // update our end of WAL available for sending value, we + // communicate it to the receiver. + self.wait_wal().await?; + } - info!("Start replication from {:?} till {:?}", start_pos, stop_pos); + // try to send as much as available, capped by MAX_SEND_SIZE + let mut send_size = self + .end_pos + .checked_sub(self.start_pos) + .context("reading wal without waiting for it first")? + .0 as usize; + send_size = min(send_size, self.send_buf.len()); + let send_buf = &mut self.send_buf[..send_size]; + // read wal into buffer + send_size = self.wal_reader.read(send_buf).await?; + let send_buf = &send_buf[..send_size]; - // switch to copy - pgb.write_message(&BeMessage::CopyBothResponse)?; - - let mut end_pos = stop_pos.unwrap_or(inmem_state.commit_lsn); - - let mut wal_reader = WalReader::new( - spg.conf.workdir.clone(), - spg.conf.timeline_dir(&tli.ttid), - &persisted_state, - start_pos, - spg.conf.wal_backup_enabled, - )?; - - // buffer for wal sending, limited by MAX_SEND_SIZE - let mut send_buf = vec![0u8; MAX_SEND_SIZE]; - - // watcher for commit_lsn updates - let mut commit_lsn_watch_rx = tli.get_commit_lsn_watch_rx(); - - loop { - if let Some(stop_pos) = stop_pos { - if start_pos >= stop_pos { - break; /* recovery finished */ - } - end_pos = stop_pos; - } else { - /* Wait until we have some data to stream */ - let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?; - - if let Some(lsn) = lsn { - end_pos = lsn; - } else { - // TODO: also check once in a while whether we are walsender - // to right pageserver. 
- if tli.should_walsender_stop(replica_id) { - // Shut down, timeline is suspended. - return Err(QueryError::from(io::Error::new( - io::ErrorKind::ConnectionAborted, - format!("end streaming to {:?}", spg.appname), - ))); - } - - // timeout expired: request pageserver status - pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive { - sent_ptr: end_pos.0, - timestamp: get_current_timestamp(), - request_reply: true, - }))?; - continue; - } - } - - let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize; - let send_size = min(send_size, send_buf.len()); - - let send_buf = &mut send_buf[..send_size]; - - // read wal into buffer - let send_size = wal_reader.read(send_buf).await?; - let send_buf = &send_buf[..send_size]; - - // Write some data to the network socket. - pgb.write_message(&BeMessage::XLogData(XLogDataBody { - wal_start: start_pos.0, - wal_end: end_pos.0, + // and send it + self.pgb + .write_message(&BeMessage::XLogData(XLogDataBody { + wal_start: self.start_pos.0, + wal_end: self.end_pos.0, timestamp: get_current_timestamp(), data: send_buf, })) - .context("Failed to send XLogData")?; + .await?; - start_pos += send_size as u64; - trace!("sent WAL up to {}", start_pos); + trace!( + "sent {} bytes of WAL {}-{}", + send_size, + self.start_pos, + self.start_pos + send_size as u64 + ); + self.start_pos += send_size as u64; + } + } + + /// wait until we have WAL to stream, sending keepalives and checking for + /// exit in the meanwhile + async fn wait_wal(&mut self) -> Result<(), CopyStreamHandlerEnd> { + loop { + if let Some(lsn) = wait_for_lsn(&mut self.commit_lsn_watch_rx, self.start_pos).await? { + self.end_pos = lsn; + return Ok(()); } + // Timed out waiting for WAL, check for termination and send KA + if self.tli.should_walsender_stop(self.replica_id) { + // Terminate if there is nothing more to send. 
+ // TODO close the stream properly + return Err(CopyStreamHandlerEnd::ServerInitiated(format!( + "ending streaming to {:?} at {}, receiver is caughtup and there is no computes", + self.appname, self.start_pos, + ))); + } + self.pgb + .write_message(&BeMessage::KeepAlive(WalSndKeepAlive { + sent_ptr: self.end_pos.0, + timestamp: get_current_timestamp(), + request_reply: true, + })) + .await?; + } + } +} - Ok(()) - }) +/// A half driving receiving replies. +struct ReplyReader { + reader: PostgresBackendReader, + tli: Arc, + replica_id: usize, + feedback: ReplicaState, +} + +impl ReplyReader { + async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> { + loop { + let msg = self.reader.read_copy_message().await?; + self.handle_feedback(&msg)? + } + } + + fn handle_feedback(&mut self, msg: &Bytes) -> anyhow::Result<()> { + match msg.first().cloned() { + Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => { + // Note: deserializing is on m[1..] because we skip the tag byte. + self.feedback.hs_feedback = HotStandbyFeedback::des(&msg[1..]) + .context("failed to deserialize HotStandbyFeedback")?; + self.tli + .update_replica_state(self.replica_id, self.feedback); + } + Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => { + let _reply = + StandbyReply::des(&msg[1..]).context("failed to deserialize StandbyReply")?; + // This must be a regular postgres replica, + // because pageserver doesn't send this type of messages to safekeeper. + // Currently we just ignore this, tracking progress for them is not supported. + } + Some(NEON_STATUS_UPDATE_TAG_BYTE) => { + // pageserver sends this. + // Note: deserializing is on m[9..] because we skip the tag byte and len bytes. + let buf = Bytes::copy_from_slice(&msg[9..]); + let reply = ReplicationFeedback::parse(buf); + + trace!("ReplicationFeedback is {:?}", reply); + // Only pageserver sends ReplicationFeedback, so set the flag. + // This replica is the source of information to resend to compute. 
+ self.feedback.pageserver_feedback = Some(reply); + + self.tli + .update_replica_state(self.replica_id, self.feedback); + } + _ => warn!("unexpected message {:?}", msg), + } + Ok(()) } } const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1); -// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn. +/// Wait until we have commit_lsn > lsn or timeout expires. Returns +/// - Ok(Some(commit_lsn)) if needed lsn is successfully observed; +/// - Ok(None) if timeout expired; +/// - Err in case of error (if watch channel is in trouble, shouldn't happen). async fn wait_for_lsn(rx: &mut Receiver, lsn: Lsn) -> anyhow::Result> { let commit_lsn: Lsn = *rx.borrow(); if commit_lsn > lsn { diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index 98c565cde4..fca460d998 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -1,4 +1,4 @@ -//! This module implements Timeline lifecycle management and has all neccessary code +//! This module implements Timeline lifecycle management and has all necessary code //! to glue together SafeKeeper and all other background services. use anyhow::{anyhow, bail, Result}; @@ -532,7 +532,7 @@ impl Timeline { /// Register compute connection, starting timeline-related activity if it is /// not running yet. - pub fn on_compute_connect(&self) -> Result<()> { + pub async fn on_compute_connect(&self) -> Result<()> { if self.is_cancelled() { bail!(TimelineError::Cancelled(self.ttid)); } @@ -546,7 +546,7 @@ impl Timeline { // Wake up wal backup launcher, if offloading not started yet. if is_wal_backup_action_pending { // Can fail only if channel to a static thread got closed, which is not normal at all. - self.wal_backup_launcher_tx.blocking_send(self.ttid)?; + self.wal_backup_launcher_tx.send(self.ttid).await?; } Ok(()) } @@ -563,6 +563,11 @@ impl Timeline { // Wake up wal backup launcher, if it is time to stop the offloading. 
if is_wal_backup_action_pending { // Can fail only if channel to a static thread got closed, which is not normal at all. + // + // Note: this is blocking_send because on_compute_disconnect is called in Drop, there is + // no async Drop and we use current thread runtimes. With current thread rt spawning + // task in drop impl is racy, as thread along with runtime might finish before the task. + // This should be switched send.await when/if we go to full async. self.wal_backup_launcher_tx.blocking_send(self.ttid)?; } Ok(()) diff --git a/safekeeper/src/timelines_global_map.rs b/safekeeper/src/timelines_global_map.rs index c99ca0a51a..868ee97645 100644 --- a/safekeeper/src/timelines_global_map.rs +++ b/safekeeper/src/timelines_global_map.rs @@ -171,7 +171,7 @@ impl GlobalTimelines { /// Create a new timeline with the given id. If the timeline already exists, returns /// an existing timeline. - pub fn create( + pub async fn create( ttid: TenantTimelineId, server_info: ServerInfo, commit_lsn: Lsn, @@ -199,28 +199,20 @@ impl GlobalTimelines { // Take a lock and finish the initialization holding this mutex. No other threads // can interfere with creation after we will insert timeline into the map. - let mut shared_state = timeline.write_shared_state(); + { + let mut shared_state = timeline.write_shared_state(); - // We can get a race condition here in case of concurrent create calls, but only - // in theory. create() will return valid timeline on the next try. - TIMELINES_STATE - .lock() - .unwrap() - .try_insert(timeline.clone())?; + // We can get a race condition here in case of concurrent create calls, but only + // in theory. create() will return valid timeline on the next try. + TIMELINES_STATE + .lock() + .unwrap() + .try_insert(timeline.clone())?; - // Write the new timeline to the disk and start background workers. - // Bootstrap is transactional, so if it fails, the timeline will be deleted, - // and the state on disk should remain unchanged. 
- match timeline.bootstrap(&mut shared_state) { - Ok(_) => { - // We are done with bootstrap, release the lock, return the timeline. - drop(shared_state); - timeline - .wal_backup_launcher_tx - .blocking_send(timeline.ttid)?; - Ok(timeline) - } - Err(e) => { + // Write the new timeline to the disk and start background workers. + // Bootstrap is transactional, so if it fails, the timeline will be deleted, + // and the state on disk should remain unchanged. + if let Err(e) = timeline.bootstrap(&mut shared_state) { // Note: the most likely reason for bootstrap failure is that the timeline // directory already exists on disk. This happens when timeline is corrupted // and wasn't loaded from disk on startup because of that. We want to preserve @@ -232,9 +224,13 @@ impl GlobalTimelines { // Timeline failed to bootstrap, it cannot be used. Remove it from the map. TIMELINES_STATE.lock().unwrap().timelines.remove(&ttid); - Err(e) + return Err(e); } + // We are done with bootstrap, release the lock, return the timeline. + // {} block forces release before .await } + timeline.wal_backup_launcher_tx.send(timeline.ttid).await?; + Ok(timeline) } /// Get a timeline from the global map. If it's not present, it doesn't exist on disk, @@ -254,7 +250,7 @@ impl GlobalTimelines { } } - /// Returns all timelines. This is used for background timeline proccesses. + /// Returns all timelines. This is used for background timeline processes. 
pub fn get_all() -> Vec> { let global_lock = TIMELINES_STATE.lock().unwrap(); global_lock diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index fc971ca753..798b9abaf3 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -191,7 +191,7 @@ async fn wal_backup_launcher_main_loop( .map(|c| GenericRemoteStorage::from_config(c).expect("failed to create remote storage")) }); - // Presense in this map means launcher is aware s3 offloading is needed for + // Presence in this map means launcher is aware s3 offloading is needed for // the timeline, but task is started only if it makes sense for to offload // from this safekeeper. let mut tasks: HashMap = HashMap::new(); @@ -467,7 +467,7 @@ async fn backup_object(source_file: &Path, target_file: &RemotePath, size: usize pub async fn read_object( file_path: &RemotePath, offset: u64, -) -> anyhow::Result>> { +) -> anyhow::Result>> { let storage = REMOTE_STORAGE .get() .context("Failed to get remote storage")? diff --git a/safekeeper/src/wal_service.rs b/safekeeper/src/wal_service.rs index 40448be949..8d63d604ad 100644 --- a/safekeeper/src/wal_service.rs +++ b/safekeeper/src/wal_service.rs @@ -2,50 +2,65 @@ //! WAL service listens for client connections and //! receive WAL from wal_proposer and send it to WAL receivers //! -use regex::Regex; -use std::net::{TcpListener, TcpStream}; -use std::thread; +use anyhow::{Context, Result}; +use nix::unistd::gettid; +use postgres_backend::QueryError; +use std::{future, thread}; +use tokio::net::TcpStream; use tracing::*; -use utils::postgres_backend::QueryError; use crate::handler::SafekeeperPostgresHandler; use crate::SafeKeeperConf; -use utils::postgres_backend::{AuthType, PostgresBackend}; +use postgres_backend::{AuthType, PostgresBackend}; /// Accept incoming TCP connections and spawn them into a background thread. -pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> ! 
{ - loop { - match listener.accept() { - Ok((socket, peer_addr)) => { - debug!("accepted connection from {}", peer_addr); - let conf = conf.clone(); +pub fn thread_main(conf: SafeKeeperConf, pg_listener: std::net::TcpListener) { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .context("create runtime") + // todo catch error in main thread + .expect("failed to create runtime"); - let _ = thread::Builder::new() - .name("WAL service thread".into()) - .spawn(move || { - if let Err(err) = handle_socket(socket, conf) { - error!("connection handler exited: {}", err); - } - }) - .unwrap(); + runtime + .block_on(async move { + // Tokio's from_std won't do this for us, per its comment. + pg_listener.set_nonblocking(true)?; + let listener = tokio::net::TcpListener::from_std(pg_listener)?; + + loop { + match listener.accept().await { + Ok((socket, peer_addr)) => { + debug!("accepted connection from {}", peer_addr); + let conf = conf.clone(); + + let _ = thread::Builder::new() + .name("WAL service thread".into()) + .spawn(move || { + if let Err(err) = handle_socket(socket, conf) { + error!("connection handler exited: {}", err); + } + }) + .unwrap(); + } + Err(e) => error!("Failed to accept connection: {}", e), + } } - Err(e) => error!("Failed to accept connection: {}", e), - } - } -} - -// Get unique thread id (Rust internal), with ThreadId removed for shorter printing -fn get_tid() -> u64 { - let tids = format!("{:?}", thread::current().id()); - let r = Regex::new(r"ThreadId\((\d+)\)").unwrap(); - let caps = r.captures(&tids).unwrap(); - caps.get(1).unwrap().as_str().parse().unwrap() + #[allow(unreachable_code)] // hint compiler the closure return type + Ok::<(), anyhow::Error>(()) + }) + .expect("listener failed") } /// This is run by `thread_main` above, inside a background thread. 
/// fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> { - let _enter = info_span!("", tid = ?get_tid()).entered(); + let _enter = info_span!("", tid = %gettid()).entered(); + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + let local = tokio::task::LocalSet::new(); socket.set_nodelay(true)?; @@ -54,9 +69,13 @@ fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryErr Some(_) => AuthType::NeonJWT, }; let mut conn_handler = SafekeeperPostgresHandler::new(conf); - let pgbackend = PostgresBackend::new(socket, auth_type, None, false)?; - // libpq replication protocol between safekeeper and replicas/pagers - pgbackend.run(&mut conn_handler)?; + let pgbackend = PostgresBackend::new(socket, auth_type, None)?; + // libpq protocol between safekeeper and walproposer / pageserver + // We don't use shutdown. + local.block_on( + &runtime, + pgbackend.run(&mut conn_handler, future::pending::<()>), + )?; Ok(()) } diff --git a/safekeeper/src/wal_storage.rs b/safekeeper/src/wal_storage.rs index ae02b3c7bc..9b385630c2 100644 --- a/safekeeper/src/wal_storage.rs +++ b/safekeeper/src/wal_storage.rs @@ -471,7 +471,7 @@ pub struct WalReader { timeline_dir: PathBuf, wal_seg_size: usize, pos: Lsn, - wal_segment: Option>>, + wal_segment: Option>>, // S3 will be used to read WAL if LSN is not available locally enable_remote_read: bool, @@ -538,7 +538,7 @@ impl WalReader { } /// Open WAL segment at the current position of the reader. 
- async fn open_segment(&self) -> Result>> { + async fn open_segment(&self) -> Result>> { let xlogoff = self.pos.segment_offset(self.wal_seg_size); let segno = self.pos.segment_number(self.wal_seg_size); let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size); diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index ba98563693..70a6f1809e 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2068,8 +2068,10 @@ class NeonPageserver(PgProtocol): ".*Connection aborted: connection error: error communicating with the server: Broken pipe.*", ".*Connection aborted: connection error: error communicating with the server: Transport endpoint is not connected.*", ".*Connection aborted: connection error: error communicating with the server: Connection reset by peer.*", + # FIXME: replication patch for tokio_postgres regards any but CopyDone/CopyData message in CopyBoth stream as unexpected + ".*Connection aborted: connection error: unexpected message from server*", ".*kill_and_wait_impl.*: wait successful.*", - ".*Replication stream finished: db error: ERROR: Socket IO error: end streaming to Some.*", + ".*Replication stream finished: db error:.*ending streaming to Some*", ".*query handler for 'pagestream.*failed: Broken pipe.*", # pageserver notices compute shut down ".*query handler for 'pagestream.*failed: Connection reset by peer.*", # pageserver notices compute shut down # safekeeper connection can fail with this, in the window between timeline creation diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index 0ac9127c6b..489afb7b93 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -1138,8 +1138,8 @@ def test_delete_force(neon_env_builder: NeonEnvBuilder, auth_enabled: bool): # FIXME: are these expected? 
env.pageserver.allowed_errors.extend( [ - ".*Failed to process query for timeline .*: Timeline .* was not found in global map.*", - ".*Failed to process query for timeline .*: Timeline .* was cancelled and cannot be used anymore.*", + ".*Timeline .* was not found in global map.*", + ".*Timeline .* was cancelled and cannot be used anymore.*", ] ) diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index bd21095fff..f885f4a94d 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -14,14 +14,19 @@ publish = false ### BEGIN HAKARI SECTION [dependencies] anyhow = { version = "1", features = ["backtrace"] } +byteorder = { version = "1" } bytes = { version = "1", features = ["serde"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde"] } clap = { version = "4", features = ["derive", "string"] } crossbeam-utils = { version = "0.8" } +digest = { version = "0.10", features = ["mac", "std"] } either = { version = "1" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } futures = { version = "0.3" } +futures-channel = { version = "0.3", features = ["sink"] } +futures-core = { version = "0.3" } futures-executor = { version = "0.3" } +futures-sink = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } hashbrown = { version = "0.12", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] } @@ -45,6 +50,7 @@ serde = { version = "1", features = ["alloc", "derive"] } serde_json = { version = "1", features = ["raw_value"] } socket2 = { version = "0.4", default-features = false, features = ["all"] } tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "sync", "time"] } +tokio-rustls = { version = "0.23" } tokio-util = { version = "0.7", features = ["codec", "io"] } tonic = { version = "0.8", features = ["tls-roots"] } tower = { version = "0.4", 
features = ["balance", "buffer", "limit", "retry", "timeout", "util"] }