diff --git a/Cargo.lock b/Cargo.lock index 02b03e02fb..0a505573e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -913,6 +913,7 @@ dependencies = [ "once_cell", "pageserver_api", "postgres", + "postgres_backend", "postgres_connection", "regex", "reqwest", @@ -2454,6 +2455,7 @@ dependencies = [ "postgres", "postgres-protocol", "postgres-types", + "postgres_backend", "postgres_connection", "postgres_ffi", "pq_proto", @@ -2676,6 +2678,28 @@ dependencies = [ "postgres-protocol", ] +[[package]] +name = "postgres_backend" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "bytes", + "futures", + "once_cell", + "pq_proto", + "rustls", + "rustls-pemfile", + "serde", + "thiserror", + "tokio", + "tokio-postgres", + "tokio-postgres-rustls", + "tokio-rustls", + "tracing", + "workspace_hack", +] + [[package]] name = "postgres_connection" version = "0.1.0" @@ -2723,7 +2747,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" name = "pq_proto" version = "0.1.0" dependencies = [ - "anyhow", + "byteorder", "bytes", "pin-project-lite", "postgres-protocol", @@ -2898,6 +2922,7 @@ dependencies = [ "opentelemetry", "parking_lot", "pin-project-lite", + "postgres_backend", "pq_proto", "prometheus", "rand", @@ -3277,15 +3302,6 @@ dependencies = [ "base64 0.21.0", ] -[[package]] -name = "rustls-split" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78802c9612b4689d207acff746f38132ca1b12dadb55d471aa5f10fd580f47d3" -dependencies = [ - "rustls", -] - [[package]] name = "rustversion" version = "1.0.11" @@ -3321,6 +3337,7 @@ dependencies = [ "parking_lot", "postgres", "postgres-protocol", + "postgres_backend", "postgres_ffi", "pq_proto", "regex", @@ -4506,7 +4523,6 @@ dependencies = [ "bytes", "criterion", "futures", - "git-version", "heapless", "hex", "hex-literal", @@ -4515,12 +4531,9 @@ dependencies = [ "metrics", "nix", "once_cell", - "pq_proto", "rand", "routerify", "rustls", - "rustls-pemfile", - "rustls-split", "sentry", "serde", "serde_json", @@ -4835,14 +4848,19 @@ name = "workspace_hack" version = "0.1.0" dependencies = [ "anyhow", + "byteorder", "bytes", "chrono", "clap 4.1.4", "crossbeam-utils", + "digest", "either", "fail", "futures", + "futures-channel", + "futures-core", "futures-executor", + "futures-sink", "futures-util", "hashbrown 0.12.3", "indexmap", diff --git a/Cargo.toml b/Cargo.toml index ea22b04124..bbd4975603 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -133,6 +133,7 @@ heapless = { default-features=false, features=[], git = "https://github.com/japa consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" } metrics = { version = "0.1", path = "./libs/metrics/" } pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" } +postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" } postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" } postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" } pq_proto = { version = "0.1", path = "./libs/pq_proto/" } diff --git a/control_plane/Cargo.toml b/control_plane/Cargo.toml index 309887e1fa..ba39747e03 100644 --- a/control_plane/Cargo.toml +++ b/control_plane/Cargo.toml @@ -24,6 +24,7 @@ url.workspace = true # Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api # instead, so that recompile times are better. 
pageserver_api.workspace = true +postgres_backend.workspace = true safekeeper_api.workspace = true postgres_connection.workspace = true storage_broker.workspace = true diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 4b2aa3c957..49b1d31dbc 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -17,6 +17,7 @@ use pageserver_api::{ DEFAULT_HTTP_LISTEN_ADDR as DEFAULT_PAGESERVER_HTTP_ADDR, DEFAULT_PG_LISTEN_ADDR as DEFAULT_PAGESERVER_PG_ADDR, }; +use postgres_backend::AuthType; use safekeeper_api::{ DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT, DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT, @@ -30,7 +31,6 @@ use utils::{ auth::{Claims, Scope}, id::{NodeId, TenantId, TenantTimelineId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, project_git_version, }; diff --git a/control_plane/src/compute.rs b/control_plane/src/compute.rs index 8731cf2583..b7029aabc5 100644 --- a/control_plane/src/compute.rs +++ b/control_plane/src/compute.rs @@ -11,10 +11,10 @@ use std::sync::Arc; use std::time::Duration; use anyhow::{Context, Result}; +use postgres_backend::AuthType; use utils::{ id::{TenantId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, }; use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION}; diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index 003152c578..09180d96c4 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -5,6 +5,7 @@ use anyhow::{bail, ensure, Context}; +use postgres_backend::AuthType; use reqwest::Url; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; @@ -19,7 +20,6 @@ use std::process::{Command, Stdio}; use utils::{ auth::{encode_from_key_file, Claims, Scope}, id::{NodeId, TenantId, TenantTimelineId, TimelineId}, - postgres_backend::AuthType, }; use crate::safekeeper::SafekeeperNode; diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index c49bd39f09..4b7180c250 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -11,6 +11,7 @@ use anyhow::{bail, Context}; use pageserver_api::models::{ TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo, }; +use postgres_backend::AuthType; use postgres_connection::{parse_host_port, PgConnectionConfig}; use reqwest::blocking::{Client, RequestBuilder, Response}; use reqwest::{IntoUrl, Method}; @@ -20,7 +21,6 @@ use utils::{ http::error::HttpErrorBody, id::{TenantId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, }; use crate::{background_process, local_env::LocalEnv}; diff --git a/libs/postgres_backend/Cargo.toml b/libs/postgres_backend/Cargo.toml new file mode 100644 index 0000000000..8e249c09f7 --- /dev/null +++ b/libs/postgres_backend/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "postgres_backend" +version = "0.1.0" +edition.workspace = true +license.workspace = true + +[dependencies] +async-trait.workspace = true +anyhow.workspace = true +bytes.workspace = true +futures.workspace = true +rustls.workspace = true +serde.workspace = true +thiserror.workspace = true +tokio.workspace = true +tokio-rustls.workspace = true +tracing.workspace = true + +pq_proto.workspace = true +workspace_hack.workspace = true + +[dev-dependencies] +once_cell.workspace = true +rustls-pemfile.workspace = true +tokio-postgres.workspace = true +tokio-postgres-rustls.workspace = true \ No newline at end of file diff --git a/libs/postgres_backend/src/lib.rs 
b/libs/postgres_backend/src/lib.rs new file mode 100644 index 0000000000..15a9fc6dd0 --- /dev/null +++ b/libs/postgres_backend/src/lib.rs @@ -0,0 +1,911 @@ +//! Server-side asynchronous Postgres connection, as limited as we need. +//! To use, create PostgresBackend and run() it, passing the Handler +//! implementation determining how to process the queries. Currently its API +//! is rather narrow, but we can extend it once required. +use anyhow::Context; +use bytes::Bytes; +use futures::pin_mut; +use serde::{Deserialize, Serialize}; +use std::io::ErrorKind; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{ready, Poll}; +use std::{fmt, io}; +use std::{future::Future, str::FromStr}; +use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; +use tokio_rustls::TlsAcceptor; + +use tracing::{debug, error, info, trace}; + +use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter}; +use pq_proto::{ + BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR, + SQLSTATE_SUCCESSFUL_COMPLETION, +}; + +/// An error, occurred during query processing: +/// either during the connection ([`ConnectionError`]) or before/after it. +#[derive(thiserror::Error, Debug)] +pub enum QueryError { + /// The connection was lost while processing the query. + #[error(transparent)] + Disconnected(#[from] ConnectionError), + /// Some other error + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl From for QueryError { + fn from(e: io::Error) -> Self { + Self::Disconnected(ConnectionError::Io(e)) + } +} + +impl QueryError { + pub fn pg_error_code(&self) -> &'static [u8; 5] { + match self { + Self::Disconnected(_) => b"08006", // connection failure + Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error + } + } +} + +pub fn is_expected_io_error(e: &io::Error) -> bool { + use io::ErrorKind::*; + matches!( + e.kind(), + ConnectionRefused | ConnectionAborted | ConnectionReset + ) +} + +#[async_trait::async_trait] +pub trait Handler { + /// Handle single query. + /// postgres_backend will issue ReadyForQuery after calling this (this + /// might be not what we want after CopyData streaming, but currently we don't + /// care). It will also flush out the output buffer. + async fn process_query( + &mut self, + pgb: &mut PostgresBackend, + query_string: &str, + ) -> Result<(), QueryError>; + + /// Called on startup packet receival, allows to process params. + /// + /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users + /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow + /// to override whole init logic in implementations. + fn startup( + &mut self, + _pgb: &mut PostgresBackend, + _sm: &FeStartupPacket, + ) -> Result<(), QueryError> { + Ok(()) + } + + /// Check auth jwt + fn check_auth_jwt( + &mut self, + _pgb: &mut PostgresBackend, + _jwt_response: &[u8], + ) -> Result<(), QueryError> { + Err(QueryError::Other(anyhow::anyhow!("JWT auth failed"))) + } +} + +/// PostgresBackend protocol state. +/// XXX: The order of the constructors matters. +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum ProtoState { + /// Nothing happened yet. + Initialization, + /// Encryption handshake is done; waiting for encrypted Startup message. + Encrypted, + /// Waiting for password (auth token). + Authentication, + /// Performed handshake and auth, ReadyForQuery is issued. 
+    Established,
+    Closed,
+}
+
+#[derive(Clone, Copy)]
+pub enum ProcessMsgResult {
+    Continue,
+    Break,
+}
+
+/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
+pub enum MaybeTlsStream {
+    Unencrypted(tokio::net::TcpStream),
+    Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
+}
+
+impl AsyncWrite for MaybeTlsStream {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        match self.get_mut() {
+            Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
+            Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
+        }
+    }
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
+        match self.get_mut() {
+            Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
+            Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
+        }
+    }
+    fn poll_shutdown(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<io::Result<()>> {
+        match self.get_mut() {
+            Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
+            Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
+        }
+    }
+}
+impl AsyncRead for MaybeTlsStream {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        match self.get_mut() {
+            Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
+            Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
+pub enum AuthType {
+    Trust,
+    // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
+    NeonJWT,
+}
+
+impl FromStr for AuthType {
+    type Err = anyhow::Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "Trust" => Ok(Self::Trust),
+            "NeonJWT" => Ok(Self::NeonJWT),
+            _ => anyhow::bail!("invalid value \"{s}\" for auth type"),
+        }
+    }
+}
+
+impl fmt::Display for AuthType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(match self {
+            AuthType::Trust => "Trust",
+            AuthType::NeonJWT => "NeonJWT",
+        })
+    }
+}
+
+/// Either the full duplex Framed or only the write half; the latter is left in
+/// PostgresBackend after a call to `split`. In principle we could always store a
+/// pair of split handles, but that would force us to pay the splitting price
+/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
+enum MaybeWriteOnly {
+    Full(Framed<MaybeTlsStream>),
+    WriteOnly(FramedWriter<WriteHalf<MaybeTlsStream>>),
+    Broken, // temporary value palmed off during the split
+}
+
+impl MaybeWriteOnly {
+    async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
+        match self {
+            MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
+            MaybeWriteOnly::WriteOnly(_) => {
+                Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
+            }
+            MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
+        }
+    }
+
+    async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
+        match self {
+            MaybeWriteOnly::Full(framed) => framed.read_message().await,
+            MaybeWriteOnly::WriteOnly(_) => {
+                Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
+            }
+            MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
+        }
+    }
+
+    fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
+        match self {
+            MaybeWriteOnly::Full(framed) => framed.write_message(msg),
+            MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
+            MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
+        }
+    }
+
+    async fn flush(&mut self) -> io::Result<()> {
+        match self {
+            MaybeWriteOnly::Full(framed) => framed.flush().await,
+            MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
+            MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
+        }
+    }
+
+    async fn shutdown(&mut self) -> io::Result<()> {
+        match self {
+            MaybeWriteOnly::Full(framed) => framed.shutdown().await,
+            MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
+            MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
+        }
+    }
+}
+
+pub struct PostgresBackend {
+    framed: MaybeWriteOnly,
+
+    pub state: ProtoState,
+
+    auth_type: AuthType,
+
+    peer_addr: SocketAddr,
+    pub tls_config: Option<Arc<rustls::ServerConfig>>,
+}
+
+pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
+    let mut query_string = query_string.to_vec();
+    if let Some(ch) = query_string.last() {
+        if *ch == 0 {
+            query_string.pop();
+        }
+    }
+    query_string
+}
+
+/// Cast a byte slice to a string slice, dropping null terminator if there's one.
+fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
+    let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
+    std::str::from_utf8(without_null).map_err(|e| e.into())
+}
+
+impl PostgresBackend {
+    pub fn new(
+        socket: tokio::net::TcpStream,
+        auth_type: AuthType,
+        tls_config: Option<Arc<rustls::ServerConfig>>,
+    ) -> io::Result<Self> {
+        let peer_addr = socket.peer_addr()?;
+        let stream = MaybeTlsStream::Unencrypted(socket);
+
+        Ok(Self {
+            framed: MaybeWriteOnly::Full(Framed::new(stream)),
+            state: ProtoState::Initialization,
+            auth_type,
+            tls_config,
+            peer_addr,
+        })
+    }
+
+    pub fn get_peer_addr(&self) -> &SocketAddr {
+        &self.peer_addr
+    }
+
+    /// Read full message or return None if connection is cleanly closed with no
+    /// unprocessed data.
+    pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
+        if let ProtoState::Closed = self.state {
+            Ok(None)
+        } else {
+            let m = self.framed.read_message().await?;
+            trace!("read msg {:?}", m);
+            Ok(m)
+        }
+    }
+
+    /// Write message into internal output buffer, doesn't flush it. Technically
+    /// error type can be only ProtocolError here (if, unlikely, serialization
+    /// fails), but callers typically wrap it anyway.
+ pub fn write_message_noflush( + &mut self, + message: &BeMessage<'_>, + ) -> Result<&mut Self, ConnectionError> { + self.framed.write_message_noflush(message)?; + trace!("wrote msg {:?}", message); + Ok(self) + } + + /// Flush output buffer into the socket. + pub async fn flush(&mut self) -> io::Result<()> { + self.framed.flush().await + } + + /// Polling version of `flush()`, saves the caller need to pin. + pub fn poll_flush( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let flush_fut = self.flush(); + pin_mut!(flush_fut); + flush_fut.poll(cx) + } + + /// Write message into internal output buffer and flush it to the stream. + pub async fn write_message( + &mut self, + message: &BeMessage<'_>, + ) -> Result<&mut Self, ConnectionError> { + self.write_message_noflush(message)?; + self.flush().await?; + Ok(self) + } + + /// Returns an AsyncWrite implementation that wraps all the data written + /// to it in CopyData messages, and writes them to the connection + /// + /// The caller is responsible for sending CopyOutResponse and CopyDone messages. + pub fn copyout_writer(&mut self) -> CopyDataWriter { + CopyDataWriter { pgb: self } + } + + /// Wrapper for run_message_loop() that shuts down socket when we are done + pub async fn run( + mut self, + handler: &mut impl Handler, + shutdown_watcher: F, + ) -> Result<(), QueryError> + where + F: Fn() -> S, + S: Future, + { + let ret = self.run_message_loop(handler, shutdown_watcher).await; + // socket might be already closed, e.g. if previously received error, + // so ignore result. + self.framed.shutdown().await.ok(); + ret + } + + async fn run_message_loop( + &mut self, + handler: &mut impl Handler, + shutdown_watcher: F, + ) -> Result<(), QueryError> + where + F: Fn() -> S, + S: Future, + { + trace!("postgres backend to {:?} started", self.peer_addr); + + tokio::select!( + biased; + + _ = shutdown_watcher() => { + // We were requested to shut down. + tracing::info!("shutdown request received during handshake"); + return Ok(()) + }, + + result = self.handshake(handler) => { + // Handshake complete. + result?; + if self.state == ProtoState::Closed { + return Ok(()); // EOF during handshake + } + } + ); + + // Authentication completed + let mut query_string = Bytes::new(); + while let Some(msg) = tokio::select!( + biased; + _ = shutdown_watcher() => { + // We were requested to shut down. + tracing::info!("shutdown request received in run_message_loop"); + Ok(None) + }, + msg = self.read_message() => { msg }, + )? { + trace!("got message {:?}", msg); + + let result = self.process_message(handler, msg, &mut query_string).await; + self.flush().await?; + match result? { + ProcessMsgResult::Continue => { + self.flush().await?; + continue; + } + ProcessMsgResult::Break => break, + } + } + + trace!("postgres backend to {:?} exited", self.peer_addr); + Ok(()) + } + + /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake. 
+ async fn tls_upgrade( + src: MaybeTlsStream, + tls_config: Arc, + ) -> anyhow::Result { + match src { + MaybeTlsStream::Unencrypted(s) => { + let acceptor = TlsAcceptor::from(tls_config); + let tls_stream = acceptor.accept(s).await?; + Ok(MaybeTlsStream::Tls(Box::new(tls_stream))) + } + MaybeTlsStream::Tls(_) => { + anyhow::bail!("TLS already started"); + } + } + } + + async fn start_tls(&mut self) -> anyhow::Result<()> { + // temporary replace stream with fake to cook TLS one, Indiana Jones style + match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) { + MaybeWriteOnly::Full(framed) => { + let tls_config = self + .tls_config + .as_ref() + .context("start_tls called without conf")? + .clone(); + let tls_framed = framed + .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config)) + .await?; + // push back ready TLS stream + self.framed = MaybeWriteOnly::Full(tls_framed); + Ok(()) + } + MaybeWriteOnly::WriteOnly(_) => { + anyhow::bail!("TLS upgrade attempt in split state") + } + MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"), + } + } + + /// Split off owned read part from which messages can be read in different + /// task/thread. + pub fn split(&mut self) -> anyhow::Result { + // temporary replace stream with fake to cook split one, Indiana Jones style + match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) { + MaybeWriteOnly::Full(framed) => { + let (reader, writer) = framed.split(); + self.framed = MaybeWriteOnly::WriteOnly(writer); + Ok(PostgresBackendReader(reader)) + } + MaybeWriteOnly::WriteOnly(_) => { + anyhow::bail!("PostgresBackend is already split") + } + MaybeWriteOnly::Broken => panic!("split on framed in invalid state"), + } + } + + /// Join read part back. + pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> { + // temporary replace stream with fake to cook joined one, Indiana Jones style + match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) { + MaybeWriteOnly::Full(_) => { + anyhow::bail!("PostgresBackend is not split") + } + MaybeWriteOnly::WriteOnly(writer) => { + let joined = Framed::unsplit(reader.0, writer); + self.framed = MaybeWriteOnly::Full(joined); + Ok(()) + } + MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"), + } + } + + /// Perform handshake with the client, transitioning to Established. + /// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()). + async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> { + while self.state < ProtoState::Authentication { + match self.framed.read_startup_message().await? { + Some(msg) => { + self.process_startup_message(handler, msg).await?; + } + None => { + trace!( + "postgres backend to {:?} received EOF during handshake", + self.peer_addr + ); + self.state = ProtoState::Closed; + return Ok(()); + } + } + } + + // Perform auth, if needed. + if self.state == ProtoState::Authentication { + match self.framed.read_message().await? { + Some(FeMessage::PasswordMessage(m)) => { + assert!(self.auth_type == AuthType::NeonJWT); + + let (_, jwt_response) = m.split_last().context("protocol violation")?; + + if let Err(e) = handler.check_auth_jwt(self, jwt_response) { + self.write_message_noflush(&BeMessage::ErrorResponse( + &e.to_string(), + Some(e.pg_error_code()), + ))?; + return Err(e); + } + + self.write_message_noflush(&BeMessage::AuthenticationOk)? + .write_message_noflush(&BeMessage::CLIENT_ENCODING)? 
+ .write_message(&BeMessage::ReadyForQuery) + .await?; + self.state = ProtoState::Established; + } + Some(m) => { + return Err(QueryError::Other(anyhow::anyhow!( + "Unexpected message {:?} while waiting for handshake", + m + ))); + } + None => { + trace!( + "postgres backend to {:?} received EOF during auth", + self.peer_addr + ); + self.state = ProtoState::Closed; + return Ok(()); + } + } + } + + Ok(()) + } + + /// Process startup packet: + /// - transition to Established if auth type is trust + /// - transition to Authentication if auth type is NeonJWT. + /// - or perform TLS handshake -- then need to call this again to receive + /// actual startup packet. + async fn process_startup_message( + &mut self, + handler: &mut impl Handler, + msg: FeStartupPacket, + ) -> Result<(), QueryError> { + assert!(self.state < ProtoState::Authentication); + let have_tls = self.tls_config.is_some(); + match msg { + FeStartupPacket::SslRequest => { + debug!("SSL requested"); + + self.write_message(&BeMessage::EncryptionResponse(have_tls)) + .await?; + + if have_tls { + self.start_tls().await?; + self.state = ProtoState::Encrypted; + } + } + FeStartupPacket::GssEncRequest => { + debug!("GSS requested"); + self.write_message(&BeMessage::EncryptionResponse(false)) + .await?; + } + FeStartupPacket::StartupMessage { .. } => { + if have_tls && !matches!(self.state, ProtoState::Encrypted) { + self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None)) + .await?; + return Err(QueryError::Other(anyhow::anyhow!( + "client did not connect with TLS" + ))); + } + + // NB: startup() may change self.auth_type -- we are using that in proxy code + // to bypass auth for new users. + handler.startup(self, &msg)?; + + match self.auth_type { + AuthType::Trust => { + self.write_message_noflush(&BeMessage::AuthenticationOk)? + .write_message_noflush(&BeMessage::CLIENT_ENCODING)? + .write_message_noflush(&BeMessage::INTEGER_DATETIMES)? + // The async python driver requires a valid server_version + .write_message_noflush(&BeMessage::server_version("14.1"))? + .write_message(&BeMessage::ReadyForQuery) + .await?; + self.state = ProtoState::Established; + } + AuthType::NeonJWT => { + self.write_message(&BeMessage::AuthenticationCleartextPassword) + .await?; + self.state = ProtoState::Authentication; + } + } + } + FeStartupPacket::CancelRequest { .. } => { + return Err(QueryError::Other(anyhow::anyhow!( + "Unexpected CancelRequest message during handshake" + ))); + } + } + Ok(()) + } + + async fn process_message( + &mut self, + handler: &mut impl Handler, + msg: FeMessage, + unnamed_query_string: &mut Bytes, + ) -> Result { + // Allow only startup and password messages during auth. 
Otherwise client would be able to bypass auth + // TODO: change that to proper top-level match of protocol state with separate message handling for each state + assert!(self.state == ProtoState::Established); + + match msg { + FeMessage::Query(body) => { + // remove null terminator + let query_string = cstr_to_str(&body)?; + + trace!("got query {query_string:?}"); + if let Err(e) = handler.process_query(self, query_string).await { + log_query_error(query_string, &e); + let short_error = short_error(&e); + self.write_message_noflush(&BeMessage::ErrorResponse( + &short_error, + Some(e.pg_error_code()), + ))?; + } + self.write_message_noflush(&BeMessage::ReadyForQuery)?; + } + + FeMessage::Parse(m) => { + *unnamed_query_string = m.query_string; + self.write_message_noflush(&BeMessage::ParseComplete)?; + } + + FeMessage::Describe(_) => { + self.write_message_noflush(&BeMessage::ParameterDescription)? + .write_message_noflush(&BeMessage::NoData)?; + } + + FeMessage::Bind(_) => { + self.write_message_noflush(&BeMessage::BindComplete)?; + } + + FeMessage::Close(_) => { + self.write_message_noflush(&BeMessage::CloseComplete)?; + } + + FeMessage::Execute(_) => { + let query_string = cstr_to_str(unnamed_query_string)?; + trace!("got execute {query_string:?}"); + if let Err(e) = handler.process_query(self, query_string).await { + log_query_error(query_string, &e); + self.write_message_noflush(&BeMessage::ErrorResponse( + &e.to_string(), + Some(e.pg_error_code()), + ))?; + } + // NOTE there is no ReadyForQuery message. This handler is used + // for basebackup and it uses CopyOut which doesn't require + // ReadyForQuery message and backend just switches back to + // processing mode after sending CopyDone or ErrorResponse. + } + + FeMessage::Sync => { + self.write_message_noflush(&BeMessage::ReadyForQuery)?; + } + + FeMessage::Terminate => { + return Ok(ProcessMsgResult::Break); + } + + // We prefer explicit pattern matching to wildcards, because + // this helps us spot the places where new variants are missing + FeMessage::CopyData(_) + | FeMessage::CopyDone + | FeMessage::CopyFail + | FeMessage::PasswordMessage(_) => { + return Err(QueryError::Other(anyhow::anyhow!( + "unexpected message type: {msg:?}", + ))); + } + } + + Ok(ProcessMsgResult::Continue) + } + + /// Log as info/error result of handling COPY stream and send back + /// ErrorResponse if that makes sense. Shutdown the stream if we got + /// Terminate. TODO: transition into waiting for Sync msg if we initiate the + /// close. 
+ pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) { + use CopyStreamHandlerEnd::*; + + let expected_end = match &end { + ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true, + CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error)) + if is_expected_io_error(io_error) => + { + true + } + _ => false, + }; + if expected_end { + info!("terminated: {:#}", end); + } else { + error!("terminated: {:?}", end); + } + + // Note: no current usages ever send this + if let CopyDone = &end { + if let Err(e) = self.write_message(&BeMessage::CopyDone).await { + error!("failed to send CopyDone: {}", e); + } + } + + if let Terminate = &end { + self.state = ProtoState::Closed; + } + + let err_to_send_and_errcode = match &end { + ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)), + Other(_) => Some((end.to_string(), SQLSTATE_INTERNAL_ERROR)), + // Note: CopyFail in duplex copy is somewhat unexpected (at least to + // PG walsender; evidently and per my docs reading client should + // finish it with CopyDone). It is not a problem to recover from it + // finishing the stream in both directions like we do, but note that + // sync rust-postgres client (which we don't use anymore) hangs if + // socket is not closed here. + // https://github.com/sfackler/rust-postgres/issues/755 + // https://github.com/neondatabase/neon/issues/935 + // + // Currently, the version of tokio_postgres replication patch we use + // sends this when it closes the stream (e.g. pageserver decided to + // switch conn to another safekeeper and client gets dropped). + // Moreover, seems like 'connection' task errors with 'unexpected + // message from server' when it receives ErrorResponse (anything but + // CopyData/CopyDone) back. + CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)), + _ => None, + }; + if let Some((err, errcode)) = err_to_send_and_errcode { + if let Err(ee) = self + .write_message(&BeMessage::ErrorResponse(&err, Some(errcode))) + .await + { + error!("failed to send ErrorResponse: {}", ee); + } + } + } +} + +pub struct PostgresBackendReader(FramedReader>); + +impl PostgresBackendReader { + /// Read full message or return None if connection is cleanly closed with no + /// unprocessed data. + pub async fn read_message(&mut self) -> Result, ConnectionError> { + let m = self.0.read_message().await?; + trace!("read msg {:?}", m); + Ok(m) + } + + /// Get CopyData contents of the next message in COPY stream or error + /// closing it. The error type is wider than actual errors which can happen + /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for + /// current callers. + pub async fn read_copy_message(&mut self) -> Result { + match self.read_message().await? { + Some(msg) => match msg { + FeMessage::CopyData(m) => Ok(m), + FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone), + FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail), + FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate), + _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol( + ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)), + ))), + }, + None => Err(CopyStreamHandlerEnd::EOF), + } + } +} + +/// +/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData +/// messages. 
+///
+
+pub struct CopyDataWriter<'a> {
+    pgb: &'a mut PostgresBackend,
+}
+
+impl<'a> AsyncWrite for CopyDataWriter<'a> {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        let this = self.get_mut();
+
+        // It's not strictly required to flush between each message, but makes it easier
+        // to view in wireshark, and usually the messages that the callers write are
+        // decently-sized anyway.
+        if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
+            return Poll::Ready(Err(err));
+        }
+
+        // CopyData
+        // XXX: if the input is large, we should split it into multiple messages.
+        // Not sure what the threshold should be, but the ultimate hard limit is that
+        // the length cannot exceed u32.
+        this.pgb
+            .write_message_noflush(&BeMessage::CopyData(buf))
+            // write_message only writes to the buffer, so it can fail only if the
+            // message is invalid, but CopyData can't be invalid.
+            .map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
+
+        Poll::Ready(Ok(buf.len()))
+    }
+
+    fn poll_flush(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<io::Result<()>> {
+        let this = self.get_mut();
+        this.pgb.poll_flush(cx)
+    }
+
+    fn poll_shutdown(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<io::Result<()>> {
+        let this = self.get_mut();
+        this.pgb.poll_flush(cx)
+    }
+}
+
+pub fn short_error(e: &QueryError) -> String {
+    match e {
+        QueryError::Disconnected(connection_error) => connection_error.to_string(),
+        QueryError::Other(e) => format!("{e:#}"),
+    }
+}
+
+fn log_query_error(query: &str, e: &QueryError) {
+    match e {
+        QueryError::Disconnected(ConnectionError::Io(io_error)) => {
+            if is_expected_io_error(io_error) {
+                info!("query handler for '{query}' failed with expected io error: {io_error}");
+            } else {
+                error!("query handler for '{query}' failed with io error: {io_error}");
+            }
+        }
+        QueryError::Disconnected(other_connection_error) => {
+            error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
+        }
+        QueryError::Other(e) => {
+            error!("query handler for '{query}' failed: {e:?}");
+        }
+    }
+}
+
+/// Something that finishes handling of a COPY stream, see handle_copy_stream_end.
+/// This is not always a real error, but it allows using ? and thiserror impls.
+#[derive(thiserror::Error, Debug)]
+pub enum CopyStreamHandlerEnd {
+    /// Handler initiates the end of streaming.
+ #[error("{0}")] + ServerInitiated(String), + #[error("received CopyDone")] + CopyDone, + #[error("received CopyFail")] + CopyFail, + #[error("received Terminate")] + Terminate, + #[error("EOF on COPY stream")] + EOF, + /// The connection was lost + #[error(transparent)] + Disconnected(#[from] ConnectionError), + /// Some other error + #[error(transparent)] + Other(#[from] anyhow::Error), +} diff --git a/libs/utils/tests/cert.pem b/libs/postgres_backend/tests/cert.pem similarity index 100% rename from libs/utils/tests/cert.pem rename to libs/postgres_backend/tests/cert.pem diff --git a/libs/utils/tests/key.pem b/libs/postgres_backend/tests/key.pem similarity index 100% rename from libs/utils/tests/key.pem rename to libs/postgres_backend/tests/key.pem diff --git a/libs/postgres_backend/tests/simple_select.rs b/libs/postgres_backend/tests/simple_select.rs new file mode 100644 index 0000000000..a310171c70 --- /dev/null +++ b/libs/postgres_backend/tests/simple_select.rs @@ -0,0 +1,139 @@ +/// Test postgres_backend_async with tokio_postgres +use once_cell::sync::Lazy; +use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError}; +use pq_proto::{BeMessage, RowDescriptor}; +use std::io::Cursor; +use std::{future, sync::Arc}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_postgres::config::SslMode; +use tokio_postgres::tls::MakeTlsConnect; +use tokio_postgres::{Config, NoTls, SimpleQueryMessage}; +use tokio_postgres_rustls::MakeRustlsConnect; + +// generate client, server test streams +async fn make_tcp_pair() -> (TcpStream, TcpStream) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let client_stream = TcpStream::connect(addr).await.unwrap(); + let (server_stream, _) = listener.accept().await.unwrap(); + (client_stream, server_stream) +} + +struct TestHandler {} + +#[async_trait::async_trait] +impl Handler for TestHandler { + // return single col 'hey' for any query + async fn process_query( + &mut self, + pgb: &mut PostgresBackend, + _query_string: &str, + ) -> Result<(), QueryError> { + pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col( + b"hey", + )]))? + .write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))? + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + Ok(()) + } +} + +// test that basic select works +#[tokio::test] +async fn simple_select() { + let (client_sock, server_sock) = make_tcp_pair().await; + + // create and run pgbackend + let pgbackend = + PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation"); + + tokio::spawn(async move { + let mut handler = TestHandler {}; + pgbackend.run(&mut handler, future::pending::<()>).await + }); + + let conf = Config::new(); + let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect"); + // The connection object performs the actual communication with the database, + // so spawn it off to run on its own. 
+ tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0]; + if let SimpleQueryMessage::Row(row) = first_val { + let first_col = row.get(0).expect("first column"); + assert_eq!(first_col, "hey"); + } else { + panic!("expected SimpleQueryMessage::Row"); + } +} + +static KEY: Lazy = Lazy::new(|| { + let mut cursor = Cursor::new(include_bytes!("key.pem")); + rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone()) +}); + +static CERT: Lazy = Lazy::new(|| { + let mut cursor = Cursor::new(include_bytes!("cert.pem")); + rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone()) +}); + +// test that basic select with ssl works +#[tokio::test] +async fn simple_select_ssl() { + let (client_sock, server_sock) = make_tcp_pair().await; + + let server_cfg = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![CERT.clone()], KEY.clone()) + .unwrap(); + let tls_config = Some(Arc::new(server_cfg)); + let pgbackend = + PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation"); + + tokio::spawn(async move { + let mut handler = TestHandler {}; + pgbackend.run(&mut handler, future::pending::<()>).await + }); + + let client_cfg = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates({ + let mut store = rustls::RootCertStore::empty(); + store.add(&CERT).unwrap(); + store + }) + .with_no_client_auth(); + let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg); + let tls_connect = >::make_tls_connect( + &mut make_tls_connect, + "localhost", + ) + .expect("make_tls_connect"); + + let mut conf = Config::new(); + conf.ssl_mode(SslMode::Require); + let (client, connection) = conf + .connect_raw(client_sock, tls_connect) + .await + .expect("connect"); + // The connection object performs the actual communication with the database, + // so spawn it off to run on its own. + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0]; + if let SimpleQueryMessage::Row(row) = first_val { + let first_col = row.get(0).expect("first column"); + assert_eq!(first_col, "hey"); + } else { + panic!("expected SimpleQueryMessage::Row"); + } +} diff --git a/libs/pq_proto/Cargo.toml b/libs/pq_proto/Cargo.toml index bc90a7a2c1..76b71729ed 100644 --- a/libs/pq_proto/Cargo.toml +++ b/libs/pq_proto/Cargo.toml @@ -5,8 +5,8 @@ edition.workspace = true license.workspace = true [dependencies] -anyhow.workspace = true bytes.workspace = true +byteorder.workspace = true pin-project-lite.workspace = true postgres-protocol.workspace = true rand.workspace = true diff --git a/libs/pq_proto/src/framed.rs b/libs/pq_proto/src/framed.rs new file mode 100644 index 0000000000..4b8a03fefc --- /dev/null +++ b/libs/pq_proto/src/framed.rs @@ -0,0 +1,251 @@ +//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from +//! the async stream based on (and buffered with) BytesMut. All functions are +//! cancellation safe. +//! +//! It is similar to what tokio_util::codec::Framed with appropriate codec +//! provides, but `FramedReader` and `FramedWriter` read/write parts can be used +//! separately without using split from futures::stream::StreamExt (which +//! 
allocates box[1] in polling internally). tokio::io::split is used for splitting +//! instead. Plus we customize error messages more than a single type for all io +//! calls. +//! +//! [1] https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107 +use bytes::{Buf, BytesMut}; +use std::{ + future::Future, + io::{self, ErrorKind}, +}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; + +use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; + +const INITIAL_CAPACITY: usize = 8 * 1024; + +/// Error on postgres connection: either IO (physical transport error) or +/// protocol violation. +#[derive(thiserror::Error, Debug)] +pub enum ConnectionError { + #[error(transparent)] + Io(#[from] io::Error), + #[error(transparent)] + Protocol(#[from] ProtocolError), +} + +impl ConnectionError { + /// Proxy stream.rs uses only io::Error; provide it. + pub fn into_io_error(self) -> io::Error { + match self { + ConnectionError::Io(io) => io, + ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()), + } + } +} + +/// Wraps async io `stream`, providing messages to write/flush + read Postgres +/// messages. +pub struct Framed { + stream: S, + read_buf: BytesMut, + write_buf: BytesMut, +} + +impl Framed { + pub fn new(stream: S) -> Self { + Self { + stream, + read_buf: BytesMut::with_capacity(INITIAL_CAPACITY), + write_buf: BytesMut::with_capacity(INITIAL_CAPACITY), + } + } + + /// Get a shared reference to the underlying stream. + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Extract the underlying stream. + pub fn into_inner(self) -> S { + self.stream + } + + /// Return new Framed with stream type transformed by async f, for TLS + /// upgrade. + pub async fn map_stream(self, f: F) -> Result, E> + where + F: FnOnce(S) -> Fut, + Fut: Future>, + { + let stream = f(self.stream).await?; + Ok(Framed { + stream, + read_buf: self.read_buf, + write_buf: self.write_buf, + }) + } +} + +impl Framed { + pub async fn read_startup_message( + &mut self, + ) -> Result, ConnectionError> { + read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await + } + + pub async fn read_message(&mut self) -> Result, ConnectionError> { + read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await + } +} + +impl Framed { + /// Write next message to the output buffer; doesn't flush. + pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> { + BeMessage::write(&mut self.write_buf, msg) + } + + /// Flush out the buffer. This function is cancellation safe: it can be + /// interrupted and flushing will be continued in the next call. + pub async fn flush(&mut self) -> Result<(), io::Error> { + flush(&mut self.stream, &mut self.write_buf).await + } + + /// Flush out the buffer and shutdown the stream. + pub async fn shutdown(&mut self) -> Result<(), io::Error> { + shutdown(&mut self.stream, &mut self.write_buf).await + } +} + +impl Framed { + /// Split into owned read and write parts. Beware of potential issues with + /// using halves in different tasks on TLS stream: + /// https://github.com/tokio-rs/tls/issues/40 + pub fn split(self) -> (FramedReader>, FramedWriter>) { + let (read_half, write_half) = tokio::io::split(self.stream); + let reader = FramedReader { + stream: read_half, + read_buf: self.read_buf, + }; + let writer = FramedWriter { + stream: write_half, + write_buf: self.write_buf, + }; + (reader, writer) + } + + /// Join read and write parts back. 
+ pub fn unsplit(reader: FramedReader>, writer: FramedWriter>) -> Self { + Self { + stream: reader.stream.unsplit(writer.stream), + read_buf: reader.read_buf, + write_buf: writer.write_buf, + } + } +} + +/// Read-only version of `Framed`. +pub struct FramedReader { + stream: S, + read_buf: BytesMut, +} + +impl FramedReader { + pub async fn read_message(&mut self) -> Result, ConnectionError> { + read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await + } +} + +/// Write-only version of `Framed`. +pub struct FramedWriter { + stream: S, + write_buf: BytesMut, +} + +impl FramedWriter { + /// Get a mut reference to the underlying stream. + pub fn get_mut(&mut self) -> &mut S { + &mut self.stream + } +} + +impl FramedWriter { + /// Write next message to the output buffer; doesn't flush. + pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> { + BeMessage::write(&mut self.write_buf, msg) + } + + /// Flush out the buffer. This function is cancellation safe: it can be + /// interrupted and flushing will be continued in the next call. + pub async fn flush(&mut self) -> Result<(), io::Error> { + flush(&mut self.stream, &mut self.write_buf).await + } + + /// Flush out the buffer and shutdown the stream. + pub async fn shutdown(&mut self) -> Result<(), io::Error> { + shutdown(&mut self.stream, &mut self.write_buf).await + } +} + +/// Read next message from the stream. Returns Ok(None), if EOF happened and we +/// don't have remaining data in the buffer. This function is cancellation safe: +/// you can drop future which is not yet complete and finalize reading message +/// with the next call. +/// +/// Parametrized to allow reading startup or usual message, having different +/// format. +async fn read_message( + stream: &mut S, + read_buf: &mut BytesMut, + parse: P, +) -> Result, ConnectionError> +where + P: Fn(&mut BytesMut) -> Result, ProtocolError>, +{ + loop { + if let Some(msg) = parse(read_buf)? { + return Ok(Some(msg)); + } + // If we can't build a frame yet, try to read more data and try again. + // Make sure we've got room for at least one byte to read to ensure + // that we don't get a spurious 0 that looks like EOF. + read_buf.reserve(1); + if stream.read_buf(read_buf).await? == 0 { + if read_buf.has_remaining() { + return Err(io::Error::new( + ErrorKind::UnexpectedEof, + "EOF with unprocessed data in the buffer", + ) + .into()); + } else { + return Ok(None); // clean EOF + } + } + } +} + +async fn flush( + stream: &mut S, + write_buf: &mut BytesMut, +) -> Result<(), io::Error> { + while write_buf.has_remaining() { + let bytes_written = stream.write(write_buf.chunk()).await?; + if bytes_written == 0 { + return Err(io::Error::new( + ErrorKind::WriteZero, + "failed to write message", + )); + } + // The advanced part will be garbage collected, likely during shifting + // data left on next attempt to write to buffer when free space is not + // enough. + write_buf.advance(bytes_written); + } + write_buf.clear(); + stream.flush().await +} + +async fn shutdown( + stream: &mut S, + write_buf: &mut BytesMut, +) -> Result<(), io::Error> { + flush(stream, write_buf).await?; + stream.shutdown().await +} diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index b7995c840c..3ebb14de5a 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -2,24 +2,18 @@ //! //! on message formats. -// Tools for calling certain async methods in sync contexts. 
-pub mod sync; +pub mod framed; -use anyhow::{ensure, Context, Result}; +use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use postgres_protocol::PG_EPOCH; use serde::{Deserialize, Serialize}; use std::{ borrow::Cow, collections::HashMap, - fmt, - future::Future, - io::{self, Cursor}, - str, + fmt, io, str, time::{Duration, SystemTime}, }; -use sync::{AsyncishRead, SyncFuture}; -use tokio::io::AsyncReadExt; use tracing::{trace, warn}; pub type Oid = u32; @@ -31,7 +25,6 @@ pub const TEXT_OID: Oid = 25; #[derive(Debug)] pub enum FeMessage { - StartupPacket(FeStartupPacket), // Simple query. Query(Bytes), // Extended query protocol. @@ -191,260 +184,207 @@ pub struct FeExecuteMessage { #[derive(Debug)] pub struct FeCloseMessage; -/// Retry a read on EINTR -/// -/// This runs the enclosed expression, and if it returns -/// Err(io::ErrorKind::Interrupted), retries it. -macro_rules! retry_read { - ( $x:expr ) => { - loop { - match $x { - Err(e) if e.kind() == io::ErrorKind::Interrupted => continue, - res => break res, - } - } - }; -} - -/// An error occured during connection being open. +/// An error occured while parsing or serializing raw stream into Postgres +/// messages. #[derive(thiserror::Error, Debug)] -pub enum ConnectionError { - /// IO error during writing to or reading from the connection socket. - #[error("Socket IO error: {0}")] - Socket(std::io::Error), - /// Invalid packet was received from client +pub enum ProtocolError { + /// Invalid packet was received from the client (e.g. unexpected message + /// type or broken len). #[error("Protocol error: {0}")] Protocol(String), - /// Failed to parse a protocol mesage + /// Failed to parse or, (unlikely), serialize a protocol message. #[error("Message parse error: {0}")] - MessageParse(anyhow::Error), + BadMessage(String), } -impl From for ConnectionError { - fn from(e: anyhow::Error) -> Self { - Self::MessageParse(e) - } -} - -impl ConnectionError { +impl ProtocolError { + /// Proxy stream.rs uses only io::Error; provide it. pub fn into_io_error(self) -> io::Error { - match self { - ConnectionError::Socket(io) => io, - other => io::Error::new(io::ErrorKind::Other, other.to_string()), - } + io::Error::new(io::ErrorKind::Other, self.to_string()) } } impl FeMessage { - /// Read one message from the stream. - /// This function returns `Ok(None)` in case of EOF. - /// One way to handle this properly: + /// Read and parse one message from the `buf` input buffer. If there is at + /// least one valid message, returns it, advancing `buf`; redundant copies + /// are avoided, as thanks to `bytes` crate ptrs in parsed message point + /// directly into the `buf` (processed data is garbage collected after + /// parsed message is dropped). /// - /// ``` - /// # use std::io; - /// # use pq_proto::FeMessage; - /// # - /// # fn process_message(msg: FeMessage) -> anyhow::Result<()> { - /// # Ok(()) - /// # }; - /// # - /// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> { - /// while let Some(msg) = FeMessage::read(stream)? { - /// process_message(msg)?; - /// } + /// Returns None if `buf` doesn't contain enough data for a single message. + /// For efficiency, tries to reserve large enough space in `buf` for the + /// next message in this case to save the repeated calls. 
/// - /// Ok(()) - /// } - /// ``` - #[inline(never)] - pub fn read( - stream: &mut (impl io::Read + Unpin), - ) -> Result, ConnectionError> { - Self::read_fut(&mut AsyncishRead(stream)).wait() - } + /// Returns Error if message is malformed, the only possible ErrorKind is + /// InvalidInput. + // + // Inspired by rust-postgres Message::parse. + pub fn parse(buf: &mut BytesMut) -> Result, ProtocolError> { + // Every message contains message type byte and 4 bytes len; can't do + // much without them. + if buf.len() < 5 { + let to_read = 5 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } - /// Read one message from the stream. - /// See documentation for `Self::read`. - pub fn read_fut( - stream: &mut Reader, - ) -> SyncFuture, ConnectionError>> + '_> - where - Reader: tokio::io::AsyncRead + Unpin, - { - // We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof. - // SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and - // AsyncReadExt methods of the stream. - SyncFuture::new(async move { - // Each libpq message begins with a message type byte, followed by message length - // If the client closes the connection, return None. But if the client closes the - // connection in the middle of a message, we will return an error. - let tag = match retry_read!(stream.read_u8().await) { - Ok(b) => b, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ConnectionError::Socket(e)), - }; + // We shouldn't advance `buf` as probably full message is not there yet, + // so can't directly use Bytes::get_u32 etc. + let tag = buf[0]; + let len = (&buf[1..5]).read_u32::().unwrap(); + if len < 4 { + return Err(ProtocolError::Protocol(format!( + "invalid message length {}", + len + ))); + } - // The message length includes itself, so it better be at least 4. - let len = retry_read!(stream.read_u32().await) - .map_err(ConnectionError::Socket)? - .checked_sub(4) - .ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?; + // length field includes itself, but not message type. + let total_len = len as usize + 1; + if buf.len() < total_len { + // Don't have full message yet. 
+ let to_read = total_len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } - let body = { - let mut buffer = vec![0u8; len as usize]; - stream - .read_exact(&mut buffer) - .await - .map_err(ConnectionError::Socket)?; - Bytes::from(buffer) - }; + // got the message, advance buffer + let mut msg = buf.split_to(total_len).freeze(); + msg.advance(5); // consume message type and len - match tag { - b'Q' => Ok(Some(FeMessage::Query(body))), - b'P' => Ok(Some(FeParseMessage::parse(body)?)), - b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), - b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), - b'B' => Ok(Some(FeBindMessage::parse(body)?)), - b'C' => Ok(Some(FeCloseMessage::parse(body)?)), - b'S' => Ok(Some(FeMessage::Sync)), - b'X' => Ok(Some(FeMessage::Terminate)), - b'd' => Ok(Some(FeMessage::CopyData(body))), - b'c' => Ok(Some(FeMessage::CopyDone)), - b'f' => Ok(Some(FeMessage::CopyFail)), - b'p' => Ok(Some(FeMessage::PasswordMessage(body))), - tag => { - return Err(ConnectionError::Protocol(format!( - "unknown message tag: {tag},'{body:?}'" - ))) - } + match tag { + b'Q' => Ok(Some(FeMessage::Query(msg))), + b'P' => Ok(Some(FeParseMessage::parse(msg)?)), + b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)), + b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)), + b'B' => Ok(Some(FeBindMessage::parse(msg)?)), + b'C' => Ok(Some(FeCloseMessage::parse(msg)?)), + b'S' => Ok(Some(FeMessage::Sync)), + b'X' => Ok(Some(FeMessage::Terminate)), + b'd' => Ok(Some(FeMessage::CopyData(msg))), + b'c' => Ok(Some(FeMessage::CopyDone)), + b'f' => Ok(Some(FeMessage::CopyFail)), + b'p' => Ok(Some(FeMessage::PasswordMessage(msg))), + tag => { + return Err(ProtocolError::Protocol(format!( + "unknown message tag: {tag},'{msg:?}'" + ))) } - }) + } } } impl FeStartupPacket { - /// Read startup message from the stream. - // XXX: It's tempting yet undesirable to accept `stream` by value, - // since such a change will cause user-supplied &mut references to be consumed - pub fn read( - stream: &mut (impl io::Read + Unpin), - ) -> Result, ConnectionError> { - Self::read_fut(&mut AsyncishRead(stream)).wait() - } - - /// Read startup message from the stream. - // XXX: It's tempting yet undesirable to accept `stream` by value, - // since such a change will cause user-supplied &mut references to be consumed - pub fn read_fut( - stream: &mut Reader, - ) -> SyncFuture, ConnectionError>> + '_> - where - Reader: tokio::io::AsyncRead + Unpin, - { + /// Read and parse startup message from the `buf` input buffer. It is + /// different from [`FeMessage::parse`] because startup messages don't have + /// message type byte; otherwise, its comments apply. + pub fn parse(buf: &mut BytesMut) -> Result, ProtocolError> { const MAX_STARTUP_PACKET_LENGTH: usize = 10000; const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234; const CANCEL_REQUEST_CODE: u32 = 5678; const NEGOTIATE_SSL_CODE: u32 = 5679; const NEGOTIATE_GSS_CODE: u32 = 5680; - SyncFuture::new(async move { - // Read length. If the connection is closed before reading anything (or before - // reading 4 bytes, to be precise), return None to indicate that the connection - // was closed. This matches the PostgreSQL server's behavior, which avoids noise - // in the log if the client opens connection but closes it immediately. 
- let len = match retry_read!(stream.read_u32().await) { - Ok(len) => len as usize, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ConnectionError::Socket(e)), - }; + // need at least 4 bytes with packet len + if buf.len() < 4 { + let to_read = 4 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } - #[allow(clippy::manual_range_contains)] - if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { - return Err(ConnectionError::Protocol(format!( - "invalid message length {len}" + // We shouldn't advance `buf` as probably full message is not there yet, + // so can't directly use Bytes::get_u32 etc. + let len = (&buf[0..4]).read_u32::().unwrap() as usize; + if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { + return Err(ProtocolError::Protocol(format!( + "invalid startup packet message length {}", + len + ))); + } + + if buf.len() < len { + // Don't have full message yet. + let to_read = len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + // got the message, advance buffer + let mut msg = buf.split_to(len).freeze(); + msg.advance(4); // consume len + + let request_code = msg.get_u32(); + let req_hi = request_code >> 16; + let req_lo = request_code & ((1 << 16) - 1); + // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code. + let message = match (req_hi, req_lo) { + (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { + if msg.remaining() != 8 { + return Err(ProtocolError::BadMessage( + "CancelRequest message is malformed, backend PID / secret key missing" + .to_owned(), + )); + } + FeStartupPacket::CancelRequest(CancelKeyData { + backend_pid: msg.get_i32(), + cancel_key: msg.get_i32(), + }) + } + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { + // Requested upgrade to SSL (aka TLS) + FeStartupPacket::SslRequest + } + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => { + // Requested upgrade to GSSAPI + FeStartupPacket::GssEncRequest + } + (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { + return Err(ProtocolError::Protocol(format!( + "Unrecognized request code {unrecognized_code}" ))); } + // TODO bail if protocol major_version is not 3? + (major_version, minor_version) => { + // StartupMessage - let request_code = - retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?; + // Parse pairs of null-terminated strings (key, value). + // See `postgres: ProcessStartupPacket, build_startup_packet`. + let mut tokens = str::from_utf8(&msg) + .map_err(|_e| { + ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned()) + })? + .strip_suffix('\0') // drop packet's own null + .ok_or_else(|| { + ProtocolError::Protocol( + "StartupMessage params: missing null terminator".to_string(), + ) + })? 
+ .split_terminator('\0'); - // the rest of startup packet are params - let params_len = len - 8; - let mut params_bytes = vec![0u8; params_len]; - stream - .read_exact(params_bytes.as_mut()) - .await - .map_err(ConnectionError::Socket)?; + let mut params = HashMap::new(); + while let Some(name) = tokens.next() { + let value = tokens.next().ok_or_else(|| { + ProtocolError::Protocol( + "StartupMessage params: key without value".to_string(), + ) + })?; - // Parse params depending on request code - let req_hi = request_code >> 16; - let req_lo = request_code & ((1 << 16) - 1); - let message = match (req_hi, req_lo) { - (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { - if params_len != 8 { - return Err(ConnectionError::Protocol( - "expected 8 bytes for CancelRequest params".to_string(), - )); - } - let mut cursor = Cursor::new(params_bytes); - FeStartupPacket::CancelRequest(CancelKeyData { - backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?, - cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?, - }) + params.insert(name.to_owned(), value.to_owned()); } - (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { - // Requested upgrade to SSL (aka TLS) - FeStartupPacket::SslRequest - } - (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => { - // Requested upgrade to GSSAPI - FeStartupPacket::GssEncRequest - } - (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { - return Err(ConnectionError::Protocol(format!( - "Unrecognized request code {unrecognized_code}" - ))); - } - // TODO bail if protocol major_version is not 3? - (major_version, minor_version) => { - // Parse pairs of null-terminated strings (key, value). - // See `postgres: ProcessStartupPacket, build_startup_packet`. - let mut tokens = str::from_utf8(¶ms_bytes) - .context("StartupMessage params: invalid utf-8")? - .strip_suffix('\0') // drop packet's own null - .ok_or_else(|| { - ConnectionError::Protocol( - "StartupMessage params: missing null terminator".to_string(), - ) - })? - .split_terminator('\0'); - let mut params = HashMap::new(); - while let Some(name) = tokens.next() { - let value = tokens.next().ok_or_else(|| { - ConnectionError::Protocol( - "StartupMessage params: key without value".to_string(), - ) - })?; - - params.insert(name.to_owned(), value.to_owned()); - } - - FeStartupPacket::StartupMessage { - major_version, - minor_version, - params: StartupMessageParams { params }, - } + FeStartupPacket::StartupMessage { + major_version, + minor_version, + params: StartupMessageParams { params }, } - }; - - Ok(Some(FeMessage::StartupPacket(message))) - }) + } + }; + Ok(Some(message)) } } impl FeParseMessage { - fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> Result { // FIXME: the rust-postgres driver uses a named prepared statement // for copy_out(). We're not prepared to handle that correctly. 
For
        // now, just ignore the statement name, assuming that the client never
@@ -452,55 +392,82 @@ impl FeParseMessage {
         let _pstmt_name = read_cstr(&mut buf)?;
         let query_string = read_cstr(&mut buf)?;
+        if buf.remaining() < 2 {
+            return Err(ProtocolError::BadMessage(
+                "Parse message is malformed, nparams missing".to_string(),
+            ));
+        }
         let nparams = buf.get_i16();
-        ensure!(nparams == 0, "query params not implemented");
+        if nparams != 0 {
+            return Err(ProtocolError::BadMessage(
+                "query params not implemented".to_string(),
+            ));
+        }
         Ok(FeMessage::Parse(FeParseMessage { query_string }))
     }
 }
 
 impl FeDescribeMessage {
-    fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
+    fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
         let kind = buf.get_u8();
         let _pstmt_name = read_cstr(&mut buf)?;
 
         // FIXME: see FeParseMessage::parse
-        ensure!(
-            kind == b'S',
-            "only prepared statemement Describe is implemented"
-        );
+        if kind != b'S' {
+            return Err(ProtocolError::BadMessage(
+                "only prepared statement Describe is implemented".to_string(),
+            ));
+        }
 
         Ok(FeMessage::Describe(FeDescribeMessage { kind }))
     }
 }
 
 impl FeExecuteMessage {
-    fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
+    fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
         let portal_name = read_cstr(&mut buf)?;
+        if buf.remaining() < 4 {
+            return Err(ProtocolError::BadMessage(
+                "FeExecuteMessage message is malformed, maxrows missing".to_string(),
+            ));
+        }
         let maxrows = buf.get_i32();
-        ensure!(portal_name.is_empty(), "named portals not implemented");
-        ensure!(maxrows == 0, "row limit in Execute message not implemented");
+        if !portal_name.is_empty() {
+            return Err(ProtocolError::BadMessage(
+                "named portals not implemented".to_string(),
+            ));
+        }
+        if maxrows != 0 {
+            return Err(ProtocolError::BadMessage(
+                "row limit in Execute message not implemented".to_string(),
+            ));
+        }
 
         Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
     }
 }
 
 impl FeBindMessage {
-    fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
+    fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
         let portal_name = read_cstr(&mut buf)?;
         let _pstmt_name = read_cstr(&mut buf)?;
 
         // FIXME: see FeParseMessage::parse
-        ensure!(portal_name.is_empty(), "named portals not implemented");
+        if !portal_name.is_empty() {
+            return Err(ProtocolError::BadMessage(
+                "named portals not implemented".to_string(),
+            ));
+        }
 
         Ok(FeMessage::Bind(FeBindMessage))
     }
 }
 
 impl FeCloseMessage {
-    fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
+    fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
         let _kind = buf.get_u8();
         let _pstmt_or_portal_name = read_cstr(&mut buf)?;
@@ -529,6 +496,7 @@ pub enum BeMessage<'a> {
     CloseComplete,
     // None means column is NULL
     DataRow(&'a [Option<&'a [u8]>]),
+    // None errcode means internal_error will be sent.
     ErrorResponse(&'a str, Option<&'a [u8; 5]>),
     /// Single byte - used in response to SSLRequest/GSSENCRequest.
     EncryptionResponse(bool),
@@ -559,6 +527,11 @@ impl<'a> BeMessage<'a> {
         value: b"UTF8",
     };
 
+    pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
+        name: b"integer_datetimes",
+        value: b"on",
+    };
+
     /// Build a [`BeMessage::ParameterStatus`] holding the server version.
     pub fn server_version(version: &'a str) -> Self {
         Self::ParameterStatus {
@@ -637,7 +610,7 @@ impl RowDescriptor<'_> {
 #[derive(Debug)]
 pub struct XLogDataBody<'a> {
     pub wal_start: u64,
-    pub wal_end: u64,
+    pub wal_end: u64, // current end of WAL on the server
     pub timestamp: i64,
     pub data: &'a [u8],
 }
@@ -677,12 +650,11 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
 }
 
 /// Safe write of s into buf as cstring (String in the protocol).
-fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
+fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
     let bytes = s.as_ref();
     if bytes.contains(&0) {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidInput,
-            "string contains embedded null",
+        return Err(ProtocolError::BadMessage(
+            "string contains embedded null".to_owned(),
         ));
     }
     buf.put_slice(bytes);
@@ -690,22 +662,27 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
     Ok(())
 }
 
-fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
-    let pos = buf.iter().position(|x| *x == 0);
-    let result = buf.split_to(pos.context("missing terminator")?);
+/// Read cstring from buf, advancing it.
+fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
+    let pos = buf
+        .iter()
+        .position(|x| *x == 0)
+        .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
+    let result = buf.split_to(pos);
     buf.advance(1); // drop the null terminator
     Ok(result)
 }
 
 pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
+pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
 
 impl<'a> BeMessage<'a> {
-    /// Write message to the given buf.
-    // Unlike the reading side, we use BytesMut
-    // here as msg len precedes its body and it is handy to write it down first
-    // and then fill the length. With Write we would have to either calc it
-    // manually or have one more buffer.
+    /// Serialize `message` to the given `buf`.
+    /// Apart from smart memory management, BytesMut is good here as msg len
+    /// precedes its body and it is handy to write it down first and then fill
+    /// the length. With Write we would have to either calc it manually or have
+    /// one more buffer.
+ pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> { match message { BeMessage::AuthenticationOk => { buf.put_u8(b'R'); @@ -750,7 +727,7 @@ impl<'a> BeMessage<'a> { buf.put_slice(extra); } } - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -841,7 +818,7 @@ impl<'a> BeMessage<'a> { BeMessage::ErrorResponse(error_msg, pg_error_code) => { // 'E' signalizes ErrorResponse messages buf.put_u8(b'E'); - write_body(buf, |buf| { + write_body(buf, |buf| -> Result<(), ProtocolError> { buf.put_u8(b'S'); // severity buf.put_slice(b"ERROR\0"); @@ -854,7 +831,7 @@ impl<'a> BeMessage<'a> { write_cstr(error_msg, buf)?; buf.put_u8(0); // terminator - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -866,7 +843,7 @@ impl<'a> BeMessage<'a> { // 'N' signalizes NoticeResponse messages buf.put_u8(b'N'); - write_body(buf, |buf| { + write_body(buf, |buf| -> Result<(), ProtocolError> { buf.put_u8(b'S'); // severity buf.put_slice(b"NOTICE\0"); @@ -877,7 +854,7 @@ impl<'a> BeMessage<'a> { write_cstr(error_msg.as_bytes(), buf)?; buf.put_u8(0); // terminator - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -921,7 +898,7 @@ impl<'a> BeMessage<'a> { BeMessage::RowDescription(rows) => { buf.put_u8(b'T'); - write_body(buf, |buf| { + write_body(buf, |buf| -> Result<(), ProtocolError> { buf.put_i16(rows.len() as i16); // # of fields for row in rows.iter() { write_cstr(row.name, buf)?; @@ -932,7 +909,7 @@ impl<'a> BeMessage<'a> { buf.put_i32(-1); /* typmod */ buf.put_i16(0); /* format code */ } - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -999,7 +976,7 @@ impl ReplicationFeedback { // null-terminated string - key, // uint32 - value length in bytes // value itself - pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> { + pub fn serialize(&self, buf: &mut BytesMut) { buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys buf.put_slice(b"current_timeline_size\0"); buf.put_i32(8); @@ -1024,7 +1001,6 @@ impl ReplicationFeedback { buf.put_slice(b"ps_replytime\0"); buf.put_i32(8); buf.put_i64(timestamp); - Ok(()) } // Deserialize ReplicationFeedback message @@ -1092,7 +1068,7 @@ mod tests { // because it is rounded up to microseconds during serialization. rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000); let mut data = BytesMut::new(); - rf.serialize(&mut data).unwrap(); + rf.serialize(&mut data); let rf_parsed = ReplicationFeedback::parse(data.freeze()); assert_eq!(rf, rf_parsed); @@ -1107,7 +1083,7 @@ mod tests { // because it is rounded up to microseconds during serialization. 
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000); let mut data = BytesMut::new(); - rf.serialize(&mut data).unwrap(); + rf.serialize(&mut data); // Add an extra field to the buffer and adjust number of keys if let Some(first) = data.first_mut() { @@ -1149,15 +1125,6 @@ mod tests { let params = make_params("foo\\ bar \\ \\\\ baz\\ lol"); assert_eq!(split_options(¶ms), ["foo bar", " \\", "baz ", "lol"]); } - - // Make sure that `read` is sync/async callable - async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) { - let _ = FeMessage::read(&mut [].as_ref()); - let _ = FeMessage::read_fut(stream).await; - - let _ = FeStartupPacket::read(&mut [].as_ref()); - let _ = FeStartupPacket::read_fut(stream).await; - } } fn terminate_code(code: &[u8; 5]) -> [u8; 6] { diff --git a/libs/pq_proto/src/sync.rs b/libs/pq_proto/src/sync.rs deleted file mode 100644 index b7ff1fb70b..0000000000 --- a/libs/pq_proto/src/sync.rs +++ /dev/null @@ -1,179 +0,0 @@ -use pin_project_lite::pin_project; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::{io, task}; - -pin_project! { - /// We use this future to mark certain methods - /// as callable in both sync and async modes. - #[repr(transparent)] - pub struct SyncFuture { - #[pin] - inner: T, - _marker: PhantomData, - } -} - -/// This wrapper lets us synchronously wait for inner future's completion -/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**. -/// For instance, `S` may be substituted with types implementing -/// [`tokio::io::AsyncRead`], but it's not the only viable option. -impl SyncFuture { - /// NOTE: caller should carefully pick a type for `S`, - /// because we don't want to enable [`SyncFuture::wait`] when - /// it's in fact impossible to run the future synchronously. - /// Violation of this contract will not cause UB, but - /// panics and async event loop freezes won't please you. - /// - /// Example: - /// - /// ``` - /// # use pq_proto::sync::SyncFuture; - /// # use std::future::Future; - /// # use tokio::io::AsyncReadExt; - /// # - /// // Parse a pair of numbers from a stream - /// pub fn parse_pair( - /// stream: &mut Reader, - /// ) -> SyncFuture> + '_> - /// where - /// Reader: tokio::io::AsyncRead + Unpin, - /// { - /// // If `Reader` is a `SyncProof`, this will give caller - /// // an opportunity to use `SyncFuture::wait`, because - /// // `.await` will always result in `Poll::Ready`. - /// SyncFuture::new(async move { - /// let x = stream.read_u32().await?; - /// let y = stream.read_u64().await?; - /// Ok((x, y)) - /// }) - /// } - /// ``` - pub fn new(inner: T) -> Self { - Self { - inner, - _marker: PhantomData, - } - } -} - -impl Future for SyncFuture { - type Output = T::Output; - - /// In async code, [`SyncFuture`] behaves like a regular wrapper. - #[inline(always)] - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { - self.project().inner.poll(cx) - } -} - -/// Postulates that we can call [`SyncFuture::wait`]. -/// If implementer is also a [`Future`], it should always -/// return [`task::Poll::Ready`] from [`Future::poll`]. -/// -/// Each implementation should document which futures -/// specifically are being declared sync-proof. -pub trait SyncPostulate {} - -impl SyncPostulate for &T {} -impl SyncPostulate for &mut T {} - -impl SyncFuture { - /// Synchronously wait for future completion. 
- pub fn wait(mut self) -> T::Output { - const RAW_WAKER: task::RawWaker = task::RawWaker::new( - std::ptr::null(), - &task::RawWakerVTable::new( - |_| RAW_WAKER, - |_| panic!("SyncFuture: failed to wake"), - |_| panic!("SyncFuture: failed to wake by ref"), - |_| { /* drop is no-op */ }, - ), - ); - - // SAFETY: We never move `self` during this call; - // furthermore, it will be dropped in the end regardless of panics - let this = unsafe { Pin::new_unchecked(&mut self) }; - - // SAFETY: This waker doesn't do anything apart from panicking - let waker = unsafe { task::Waker::from_raw(RAW_WAKER) }; - let context = &mut task::Context::from_waker(&waker); - - match this.poll(context) { - task::Poll::Ready(res) => res, - _ => panic!("SyncFuture: unexpected pending!"), - } - } -} - -/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`], -/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`]. -/// NOTE: you **should not** use this in async code. -#[repr(transparent)] -pub struct AsyncishRead(pub T); - -/// This lets us call [`SyncFuture, _>::wait`], -/// and allows the future to await on any of the [`AsyncRead`] -/// and [`AsyncReadExt`] methods on `AsyncishRead`. -impl SyncPostulate for AsyncishRead {} - -impl tokio::io::AsyncRead for AsyncishRead { - #[inline(always)] - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> task::Poll> { - task::Poll::Ready( - // `Read::read` will block, meaning we don't need a real event loop! - self.0 - .read(buf.initialize_unfilled()) - .map(|sz| buf.advance(sz)), - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - // async helper(stream: &mut impl AsyncRead) -> io::Result - fn bytes_add( - stream: &mut Reader, - ) -> SyncFuture> + '_> - where - Reader: tokio::io::AsyncRead + Unpin, - { - SyncFuture::new(async move { - let a = stream.read_u32().await?; - let b = stream.read_u32().await?; - Ok(a + b) - }) - } - - #[test] - fn test_sync() { - let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat(); - let res = bytes_add(&mut AsyncishRead(&mut &bytes[..])) - .wait() - .unwrap(); - assert_eq!(res, 300); - } - - // We need a single-threaded executor for this test - #[tokio::test(flavor = "current_thread")] - async fn test_async() { - let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap(); - - let write = async move { - tx.write_u32(100).await?; - tx.write_u32(200).await?; - Ok(()) - }; - - let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap(); - assert_eq!(res, 300); - } -} diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index 1091a8bd5c..901f849801 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static { } pub struct Download { - pub download_stream: Pin>, + pub download_stream: Pin>, /// Extra key-value data, associated with the current remote file. 
pub metadata: Option, } diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index 6acdb6fa53..84b82472c6 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -12,42 +12,38 @@ anyhow.workspace = true bincode.workspace = true bytes.workspace = true heapless.workspace = true +hex = { workspace = true, features = ["serde"] } hyper = { workspace = true, features = ["full"] } futures = { workspace = true} +jsonwebtoken.workspace = true +nix.workspace = true +once_cell.workspace = true routerify.workspace = true serde.workspace = true serde_json.workspace = true +signal-hook.workspace = true thiserror.workspace = true tokio.workspace = true tokio-rustls.workspace = true tracing.workspace = true tracing-subscriber = { workspace = true, features = ["json"] } -nix.workspace = true -signal-hook.workspace = true rand.workspace = true -jsonwebtoken.workspace = true -hex = { workspace = true, features = ["serde"] } rustls.workspace = true -rustls-split.workspace = true -git-version.workspace = true serde_with.workspace = true -once_cell.workspace = true strum.workspace = true strum_macros.workspace = true - -metrics.workspace = true -pq_proto.workspace = true - -workspace_hack.workspace = true url.workspace = true uuid = { version = "1.2", features = ["v4", "serde"] } + +metrics.workspace = true +workspace_hack.workspace = true + [dev-dependencies] byteorder.workspace = true bytes.workspace = true +criterion.workspace = true hex-literal.workspace = true tempfile.workspace = true -criterion.workspace = true -rustls-pemfile.workspace = true [[bench]] name = "benchmarks" diff --git a/libs/utils/src/lib.rs b/libs/utils/src/lib.rs index 9ddd702c72..acb5273943 100644 --- a/libs/utils/src/lib.rs +++ b/libs/utils/src/lib.rs @@ -13,8 +13,6 @@ pub mod simple_rcu; pub mod vec_map; pub mod bin_ser; -pub mod postgres_backend; -pub mod postgres_backend_async; // helper functions for creating and fsyncing pub mod crashsafe; @@ -27,9 +25,6 @@ pub mod id; // http endpoint utils pub mod http; -// socket splitting utils -pub mod sock_split; - // common log initialisation routine pub mod logging; diff --git a/libs/utils/src/postgres_backend.rs b/libs/utils/src/postgres_backend.rs deleted file mode 100644 index f3e3835bda..0000000000 --- a/libs/utils/src/postgres_backend.rs +++ /dev/null @@ -1,485 +0,0 @@ -//! Server-side synchronous Postgres connection, as limited as we need. -//! To use, create PostgresBackend and run() it, passing the Handler -//! implementation determining how to process the queries. Currently its API -//! is rather narrow, but we can extend it once required. - -use crate::postgres_backend_async::{log_query_error, short_error, QueryError}; -use crate::sock_split::{BidiStream, ReadStream, WriteStream}; -use anyhow::Context; -use bytes::{Bytes, BytesMut}; -use pq_proto::{BeMessage, FeMessage, FeStartupPacket}; -use serde::{Deserialize, Serialize}; -use std::fmt; -use std::io::{self, Write}; -use std::net::{Shutdown, SocketAddr, TcpStream}; -use std::str::FromStr; -use std::sync::Arc; -use std::time::Duration; -use tracing::*; - -pub trait Handler { - /// Handle single query. - /// postgres_backend will issue ReadyForQuery after calling this (this - /// might be not what we want after CopyData streaming, but currently we don't - /// care). - fn process_query( - &mut self, - pgb: &mut PostgresBackend, - query_string: &str, - ) -> Result<(), QueryError>; - - /// Called on startup packet receival, allows to process params. 
- /// - /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users - /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow - /// to override whole init logic in implementations. - fn startup( - &mut self, - _pgb: &mut PostgresBackend, - _sm: &FeStartupPacket, - ) -> Result<(), QueryError> { - Ok(()) - } - - /// Check auth jwt - fn check_auth_jwt( - &mut self, - _pgb: &mut PostgresBackend, - _jwt_response: &[u8], - ) -> Result<(), QueryError> { - Err(QueryError::Other(anyhow::anyhow!("JWT auth failed"))) - } - - fn is_shutdown_requested(&self) -> bool { - false - } -} - -/// PostgresBackend protocol state. -/// XXX: The order of the constructors matters. -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)] -pub enum ProtoState { - Initialization, - Encrypted, - Authentication, - Established, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] -pub enum AuthType { - Trust, - // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT - NeonJWT, -} - -impl FromStr for AuthType { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s { - "Trust" => Ok(Self::Trust), - "NeonJWT" => Ok(Self::NeonJWT), - _ => anyhow::bail!("invalid value \"{s}\" for auth type"), - } - } -} - -impl fmt::Display for AuthType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - AuthType::Trust => "Trust", - AuthType::NeonJWT => "NeonJWT", - }) - } -} - -#[derive(Clone, Copy)] -pub enum ProcessMsgResult { - Continue, - Break, -} - -/// Always-writeable sock_split stream. -/// May not be readable. See [`PostgresBackend::take_stream_in`] -pub enum Stream { - Bidirectional(BidiStream), - WriteOnly(WriteStream), -} - -impl Stream { - fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how), - Self::WriteOnly(write_stream) => write_stream.shutdown(how), - } - } -} - -impl io::Write for Stream { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Self::Bidirectional(bidi_stream) => bidi_stream.write(buf), - Self::WriteOnly(write_stream) => write_stream.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - Self::Bidirectional(bidi_stream) => bidi_stream.flush(), - Self::WriteOnly(write_stream) => write_stream.flush(), - } - } -} - -pub struct PostgresBackend { - stream: Option, - // Output buffer. c.f. BeMessage::write why we are using BytesMut here. - buf_out: BytesMut, - - pub state: ProtoState, - - auth_type: AuthType, - - peer_addr: SocketAddr, - pub tls_config: Option>, -} - -pub fn query_from_cstring(query_string: Bytes) -> Vec { - let mut query_string = query_string.to_vec(); - if let Some(ch) = query_string.last() { - if *ch == 0 { - query_string.pop(); - } - } - query_string -} - -// Helper function for socket read loops -pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool { - for cause in error.chain() { - if let Some(io_error) = cause.downcast_ref::() { - if io_error.kind() == std::io::ErrorKind::WouldBlock { - return true; - } - } - } - false -} - -// Cast a byte slice to a string slice, dropping null terminator if there's one. 
-fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> { - let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes); - std::str::from_utf8(without_null).map_err(|e| e.into()) -} - -impl PostgresBackend { - pub fn new( - socket: TcpStream, - auth_type: AuthType, - tls_config: Option>, - set_read_timeout: bool, - ) -> io::Result { - let peer_addr = socket.peer_addr()?; - if set_read_timeout { - socket - .set_read_timeout(Some(Duration::from_secs(5))) - .unwrap(); - } - - Ok(Self { - stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))), - buf_out: BytesMut::with_capacity(10 * 1024), - state: ProtoState::Initialization, - auth_type, - tls_config, - peer_addr, - }) - } - - pub fn into_stream(self) -> Stream { - self.stream.unwrap() - } - - /// Get direct reference (into the Option) to the read stream. - fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> { - match &mut self.stream { - Some(Stream::Bidirectional(stream)) => Ok(stream), - _ => anyhow::bail!("reader taken"), - } - } - - pub fn get_peer_addr(&self) -> &SocketAddr { - &self.peer_addr - } - - pub fn take_stream_in(&mut self) -> Option { - let stream = self.stream.take(); - match stream { - Some(Stream::Bidirectional(bidi_stream)) => { - let (read, write) = bidi_stream.split(); - self.stream = Some(Stream::WriteOnly(write)); - Some(read) - } - stream => { - self.stream = stream; - None - } - } - } - - /// Read full message or return None if connection is closed. - pub fn read_message(&mut self) -> Result, QueryError> { - let (state, stream) = (self.state, self.get_stream_in()?); - - use ProtoState::*; - match state { - Initialization | Encrypted => FeStartupPacket::read(stream), - Authentication | Established => FeMessage::read(stream), - } - .map_err(QueryError::from) - } - - /// Write message into internal output buffer. - pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> { - BeMessage::write(&mut self.buf_out, message)?; - Ok(self) - } - - /// Flush output buffer into the socket. - pub fn flush(&mut self) -> io::Result<&mut Self> { - let stream = self.stream.as_mut().unwrap(); - stream.write_all(&self.buf_out)?; - self.buf_out.clear(); - Ok(self) - } - - /// Write message into internal buffer and flush it. - pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> { - self.write_message_noflush(message)?; - self.flush() - } - - // Wrapper for run_message_loop() that shuts down socket when we are done - pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> { - let ret = self.run_message_loop(handler); - if let Some(stream) = self.stream.as_mut() { - let _ = stream.shutdown(Shutdown::Both); - } - ret - } - - fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> { - trace!("postgres backend to {:?} started", self.peer_addr); - - let mut unnamed_query_string = Bytes::new(); - - while !handler.is_shutdown_requested() { - match self.read_message() { - Ok(message) => { - if let Some(msg) = message { - trace!("got message {msg:?}"); - - match self.process_message(handler, msg, &mut unnamed_query_string)? 
{ - ProcessMsgResult::Continue => continue, - ProcessMsgResult::Break => break, - } - } else { - break; - } - } - Err(e) => { - if let QueryError::Other(e) = &e { - if is_socket_read_timed_out(e) { - continue; - } - } - return Err(e); - } - } - } - - trace!("postgres backend to {:?} exited", self.peer_addr); - Ok(()) - } - - pub fn start_tls(&mut self) -> anyhow::Result<()> { - match self.stream.take() { - Some(Stream::Bidirectional(bidi_stream)) => { - let conn = rustls::ServerConnection::new(self.tls_config.clone().unwrap())?; - self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(conn)?)); - Ok(()) - } - stream => { - self.stream = stream; - anyhow::bail!("can't start TLs without bidi stream"); - } - } - } - - fn process_message( - &mut self, - handler: &mut impl Handler, - msg: FeMessage, - unnamed_query_string: &mut Bytes, - ) -> Result { - // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth - // TODO: change that to proper top-level match of protocol state with separate message handling for each state - if self.state < ProtoState::Established - && !matches!( - msg, - FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_) - ) - { - return Err(QueryError::Other(anyhow::anyhow!("protocol violation"))); - } - - let have_tls = self.tls_config.is_some(); - match msg { - FeMessage::StartupPacket(m) => { - trace!("got startup message {m:?}"); - - match m { - FeStartupPacket::SslRequest => { - debug!("SSL requested"); - - self.write_message(&BeMessage::EncryptionResponse(have_tls))?; - if have_tls { - self.start_tls()?; - self.state = ProtoState::Encrypted; - } - } - FeStartupPacket::GssEncRequest => { - debug!("GSS requested"); - self.write_message(&BeMessage::EncryptionResponse(false))?; - } - FeStartupPacket::StartupMessage { .. } => { - if have_tls && !matches!(self.state, ProtoState::Encrypted) { - self.write_message(&BeMessage::ErrorResponse( - "must connect with TLS", - None, - ))?; - return Err(QueryError::Other(anyhow::anyhow!( - "client did not connect with TLS" - ))); - } - - // NB: startup() may change self.auth_type -- we are using that in proxy code - // to bypass auth for new users. - handler.startup(self, &m)?; - - match self.auth_type { - AuthType::Trust => { - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? - // The async python driver requires a valid server_version - .write_message_noflush(&BeMessage::server_version("14.1"))? - .write_message(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - AuthType::NeonJWT => { - self.write_message(&BeMessage::AuthenticationCleartextPassword)?; - self.state = ProtoState::Authentication; - } - } - } - FeStartupPacket::CancelRequest { .. } => { - return Ok(ProcessMsgResult::Break); - } - } - } - - FeMessage::PasswordMessage(m) => { - trace!("got password message '{:?}'", m); - - assert!(self.state == ProtoState::Authentication); - - match self.auth_type { - AuthType::Trust => unreachable!(), - AuthType::NeonJWT => { - let (_, jwt_response) = m.split_last().context("protocol violation")?; - - if let Err(e) = handler.check_auth_jwt(self, jwt_response) { - self.write_message(&BeMessage::ErrorResponse( - &e.to_string(), - Some(e.pg_error_code()), - ))?; - return Err(e); - } - } - } - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? 
- .write_message(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - - FeMessage::Query(body) => { - // remove null terminator - let query_string = cstr_to_str(&body)?; - - trace!("got query {query_string:?}"); - if let Err(e) = handler.process_query(self, query_string) { - log_query_error(query_string, &e); - let short_error = short_error(&e); - self.write_message_noflush(&BeMessage::ErrorResponse( - &short_error, - Some(e.pg_error_code()), - ))?; - } - self.write_message(&BeMessage::ReadyForQuery)?; - } - - FeMessage::Parse(m) => { - *unnamed_query_string = m.query_string; - self.write_message(&BeMessage::ParseComplete)?; - } - - FeMessage::Describe(_) => { - self.write_message_noflush(&BeMessage::ParameterDescription)? - .write_message(&BeMessage::NoData)?; - } - - FeMessage::Bind(_) => { - self.write_message(&BeMessage::BindComplete)?; - } - - FeMessage::Close(_) => { - self.write_message(&BeMessage::CloseComplete)?; - } - - FeMessage::Execute(_) => { - let query_string = cstr_to_str(unnamed_query_string)?; - trace!("got execute {query_string:?}"); - if let Err(e) = handler.process_query(self, query_string) { - log_query_error(query_string, &e); - self.write_message(&BeMessage::ErrorResponse( - &e.to_string(), - Some(e.pg_error_code()), - ))?; - } - // NOTE there is no ReadyForQuery message. This handler is used - // for basebackup and it uses CopyOut which doesn't require - // ReadyForQuery message and backend just switches back to - // processing mode after sending CopyDone or ErrorResponse. - } - - FeMessage::Sync => { - self.write_message(&BeMessage::ReadyForQuery)?; - } - - FeMessage::Terminate => { - return Ok(ProcessMsgResult::Break); - } - - // We prefer explicit pattern matching to wildcards, because - // this helps us spot the places where new variants are missing - FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => { - return Err(QueryError::Other(anyhow::anyhow!( - "unexpected message type: {msg:?}" - ))); - } - } - - Ok(ProcessMsgResult::Continue) - } -} diff --git a/libs/utils/src/postgres_backend_async.rs b/libs/utils/src/postgres_backend_async.rs deleted file mode 100644 index 442b06ed01..0000000000 --- a/libs/utils/src/postgres_backend_async.rs +++ /dev/null @@ -1,636 +0,0 @@ -//! Server-side asynchronous Postgres connection, as limited as we need. -//! To use, create PostgresBackend and run() it, passing the Handler -//! implementation determining how to process the queries. Currently its API -//! is rather narrow, but we can extend it once required. - -use crate::postgres_backend::AuthType; -use anyhow::Context; -use bytes::{Buf, Bytes, BytesMut}; -use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR}; -use std::io; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; -use std::task::Poll; -use std::{future::Future, task::ready}; -use tracing::{debug, error, info, trace}; - -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; -use tokio_rustls::TlsAcceptor; - -pub fn is_expected_io_error(e: &io::Error) -> bool { - use io::ErrorKind::*; - matches!( - e.kind(), - ConnectionRefused | ConnectionAborted | ConnectionReset - ) -} - -/// An error, occurred during query processing: -/// either during the connection ([`ConnectionError`]) or before/after it. -#[derive(thiserror::Error, Debug)] -pub enum QueryError { - /// The connection was lost while processing the query. 
- #[error(transparent)] - Disconnected(#[from] ConnectionError), - /// Some other error - #[error(transparent)] - Other(#[from] anyhow::Error), -} - -impl From for QueryError { - fn from(e: io::Error) -> Self { - Self::Disconnected(ConnectionError::Socket(e)) - } -} - -impl QueryError { - pub fn pg_error_code(&self) -> &'static [u8; 5] { - match self { - Self::Disconnected(_) => b"08006", // connection failure - Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error - } - } -} - -#[async_trait::async_trait] -pub trait Handler { - /// Handle single query. - /// postgres_backend will issue ReadyForQuery after calling this (this - /// might be not what we want after CopyData streaming, but currently we don't - /// care). - async fn process_query( - &mut self, - pgb: &mut PostgresBackend, - query_string: &str, - ) -> Result<(), QueryError>; - - /// Called on startup packet receival, allows to process params. - /// - /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users - /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow - /// to override whole init logic in implementations. - fn startup( - &mut self, - _pgb: &mut PostgresBackend, - _sm: &FeStartupPacket, - ) -> Result<(), QueryError> { - Ok(()) - } - - /// Check auth jwt - fn check_auth_jwt( - &mut self, - _pgb: &mut PostgresBackend, - _jwt_response: &[u8], - ) -> Result<(), QueryError> { - Err(QueryError::Other(anyhow::anyhow!("JWT auth failed"))) - } -} - -/// PostgresBackend protocol state. -/// XXX: The order of the constructors matters. -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)] -pub enum ProtoState { - Initialization, - Encrypted, - Authentication, - Established, - Closed, -} - -#[derive(Clone, Copy)] -pub enum ProcessMsgResult { - Continue, - Break, -} - -/// Always-writeable sock_split stream. -/// May not be readable. See [`PostgresBackend::take_stream_in`] -pub enum Stream { - Unencrypted(BufReader), - Tls(Box>>), - Broken, -} - -impl AsyncWrite for Stream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.get_mut() { - Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf), - Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf), - Self::Broken => unreachable!(), - } - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - match self.get_mut() { - Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx), - Self::Tls(stream) => Pin::new(stream).poll_flush(cx), - Self::Broken => unreachable!(), - } - } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - match self.get_mut() { - Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx), - Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx), - Self::Broken => unreachable!(), - } - } -} -impl AsyncRead for Stream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - match self.get_mut() { - Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf), - Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf), - Self::Broken => unreachable!(), - } - } -} - -pub struct PostgresBackend { - stream: Stream, - - // Output buffer. c.f. BeMessage::write why we are using BytesMut here. - // The data between 0 and "current position" as tracked by the bytes::Buf - // implementation of BytesMut, have already been written. 
- buf_out: BytesMut, - - pub state: ProtoState, - - auth_type: AuthType, - - peer_addr: SocketAddr, - pub tls_config: Option>, -} - -pub fn query_from_cstring(query_string: Bytes) -> Vec { - let mut query_string = query_string.to_vec(); - if let Some(ch) = query_string.last() { - if *ch == 0 { - query_string.pop(); - } - } - query_string -} - -// Cast a byte slice to a string slice, dropping null terminator if there's one. -fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> { - let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes); - std::str::from_utf8(without_null).map_err(|e| e.into()) -} - -impl PostgresBackend { - pub fn new( - socket: tokio::net::TcpStream, - auth_type: AuthType, - tls_config: Option>, - ) -> io::Result { - let peer_addr = socket.peer_addr()?; - - Ok(Self { - stream: Stream::Unencrypted(BufReader::new(socket)), - buf_out: BytesMut::with_capacity(10 * 1024), - state: ProtoState::Initialization, - auth_type, - tls_config, - peer_addr, - }) - } - - pub fn get_peer_addr(&self) -> &SocketAddr { - &self.peer_addr - } - - /// Read full message or return None if connection is closed. - pub async fn read_message(&mut self) -> Result, QueryError> { - use ProtoState::*; - match self.state { - Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await, - Authentication | Established => FeMessage::read_fut(&mut self.stream).await, - Closed => Ok(None), - } - .map_err(QueryError::from) - } - - /// Flush output buffer into the socket. - pub async fn flush(&mut self) -> io::Result<()> { - while self.buf_out.has_remaining() { - let bytes_written = self.stream.write(self.buf_out.chunk()).await?; - self.buf_out.advance(bytes_written); - } - self.buf_out.clear(); - Ok(()) - } - - /// Write message into internal output buffer. - pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - BeMessage::write(&mut self.buf_out, message)?; - Ok(self) - } - - /// Returns an AsyncWrite implementation that wraps all the data written - /// to it in CopyData messages, and writes them to the connection - /// - /// The caller is responsible for sending CopyOutResponse and CopyDone messages. - pub fn copyout_writer(&mut self) -> CopyDataWriter { - CopyDataWriter { pgb: self } - } - - /// A polling function that tries to write all the data from 'buf_out' to the - /// underlying stream. - fn poll_write_buf( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - while self.buf_out.has_remaining() { - match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) { - Ok(bytes_written) => self.buf_out.advance(bytes_written), - Err(err) => return Poll::Ready(Err(err)), - } - } - Poll::Ready(Ok(())) - } - - fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - Pin::new(&mut self.stream).poll_flush(cx) - } - - // Wrapper for run_message_loop() that shuts down socket when we are done - pub async fn run( - mut self, - handler: &mut impl Handler, - shutdown_watcher: F, - ) -> Result<(), QueryError> - where - F: Fn() -> S, - S: Future, - { - let ret = self.run_message_loop(handler, shutdown_watcher).await; - let _ = self.stream.shutdown(); - ret - } - - async fn run_message_loop( - &mut self, - handler: &mut impl Handler, - shutdown_watcher: F, - ) -> Result<(), QueryError> - where - F: Fn() -> S, - S: Future, - { - trace!("postgres backend to {:?} started", self.peer_addr); - - tokio::select!( - biased; - - _ = shutdown_watcher() => { - // We were requested to shut down. 
- tracing::info!("shutdown request received during handshake"); - return Ok(()) - }, - - result = async { - while self.state < ProtoState::Established { - if let Some(msg) = self.read_message().await? { - trace!("got message {msg:?} during handshake"); - - match self.process_handshake_message(handler, msg).await? { - ProcessMsgResult::Continue => { - self.flush().await?; - continue; - } - ProcessMsgResult::Break => { - trace!("postgres backend to {:?} exited during handshake", self.peer_addr); - return Ok(()); - } - } - } else { - trace!("postgres backend to {:?} exited during handshake", self.peer_addr); - return Ok(()); - } - } - Ok::<(), QueryError>(()) - } => { - // Handshake complete. - result?; - } - ); - - // Authentication completed - let mut query_string = Bytes::new(); - while let Some(msg) = tokio::select!( - biased; - _ = shutdown_watcher() => { - // We were requested to shut down. - tracing::info!("shutdown request received in run_message_loop"); - Ok(None) - }, - msg = self.read_message() => { msg }, - )? { - trace!("got message {:?}", msg); - - let result = self.process_message(handler, msg, &mut query_string).await; - self.flush().await?; - match result? { - ProcessMsgResult::Continue => { - self.flush().await?; - continue; - } - ProcessMsgResult::Break => break, - } - } - - trace!("postgres backend to {:?} exited", self.peer_addr); - Ok(()) - } - - async fn start_tls(&mut self) -> anyhow::Result<()> { - if let Stream::Unencrypted(plain_stream) = - std::mem::replace(&mut self.stream, Stream::Broken) - { - let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap()); - let tls_stream = acceptor.accept(plain_stream).await?; - - self.stream = Stream::Tls(Box::new(tls_stream)); - return Ok(()); - }; - anyhow::bail!("TLS already started"); - } - - async fn process_handshake_message( - &mut self, - handler: &mut impl Handler, - msg: FeMessage, - ) -> Result { - assert!(self.state < ProtoState::Established); - let have_tls = self.tls_config.is_some(); - match msg { - FeMessage::StartupPacket(m) => { - trace!("got startup message {m:?}"); - - match m { - FeStartupPacket::SslRequest => { - debug!("SSL requested"); - - self.write_message_noflush(&BeMessage::EncryptionResponse(have_tls))?; - if have_tls { - self.start_tls().await?; - self.state = ProtoState::Encrypted; - } - } - FeStartupPacket::GssEncRequest => { - debug!("GSS requested"); - self.write_message_noflush(&BeMessage::EncryptionResponse(false))?; - } - FeStartupPacket::StartupMessage { .. } => { - if have_tls && !matches!(self.state, ProtoState::Encrypted) { - self.write_message_noflush(&BeMessage::ErrorResponse( - "must connect with TLS", - None, - ))?; - return Err(QueryError::Other(anyhow::anyhow!( - "client did not connect with TLS" - ))); - } - - // NB: startup() may change self.auth_type -- we are using that in proxy code - // to bypass auth for new users. - handler.startup(self, &m)?; - - match self.auth_type { - AuthType::Trust => { - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? - // The async python driver requires a valid server_version - .write_message_noflush(&BeMessage::server_version("14.1"))? - .write_message_noflush(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - AuthType::NeonJWT => { - self.write_message_noflush( - &BeMessage::AuthenticationCleartextPassword, - )?; - self.state = ProtoState::Authentication; - } - } - } - FeStartupPacket::CancelRequest { .. 
} => { - self.state = ProtoState::Closed; - return Ok(ProcessMsgResult::Break); - } - } - } - - FeMessage::PasswordMessage(m) => { - trace!("got password message '{:?}'", m); - - assert!(self.state == ProtoState::Authentication); - - match self.auth_type { - AuthType::Trust => unreachable!(), - AuthType::NeonJWT => { - let (_, jwt_response) = m.split_last().context("protocol violation")?; - - if let Err(e) = handler.check_auth_jwt(self, jwt_response) { - self.write_message_noflush(&BeMessage::ErrorResponse( - &e.to_string(), - Some(e.pg_error_code()), - ))?; - return Err(e); - } - } - } - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? - .write_message_noflush(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - - _ => { - self.state = ProtoState::Closed; - return Ok(ProcessMsgResult::Break); - } - } - Ok(ProcessMsgResult::Continue) - } - - async fn process_message( - &mut self, - handler: &mut impl Handler, - msg: FeMessage, - unnamed_query_string: &mut Bytes, - ) -> Result { - // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth - // TODO: change that to proper top-level match of protocol state with separate message handling for each state - assert!(self.state == ProtoState::Established); - - match msg { - FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => { - return Err(QueryError::Other(anyhow::anyhow!("protocol violation"))); - } - - FeMessage::Query(body) => { - // remove null terminator - let query_string = cstr_to_str(&body)?; - - trace!("got query {query_string:?}"); - if let Err(e) = handler.process_query(self, query_string).await { - log_query_error(query_string, &e); - let short_error = short_error(&e); - self.write_message_noflush(&BeMessage::ErrorResponse( - &short_error, - Some(e.pg_error_code()), - ))?; - } - self.write_message_noflush(&BeMessage::ReadyForQuery)?; - } - - FeMessage::Parse(m) => { - *unnamed_query_string = m.query_string; - self.write_message_noflush(&BeMessage::ParseComplete)?; - } - - FeMessage::Describe(_) => { - self.write_message_noflush(&BeMessage::ParameterDescription)? - .write_message_noflush(&BeMessage::NoData)?; - } - - FeMessage::Bind(_) => { - self.write_message_noflush(&BeMessage::BindComplete)?; - } - - FeMessage::Close(_) => { - self.write_message_noflush(&BeMessage::CloseComplete)?; - } - - FeMessage::Execute(_) => { - let query_string = cstr_to_str(unnamed_query_string)?; - trace!("got execute {query_string:?}"); - if let Err(e) = handler.process_query(self, query_string).await { - log_query_error(query_string, &e); - self.write_message_noflush(&BeMessage::ErrorResponse( - &e.to_string(), - Some(e.pg_error_code()), - ))?; - } - // NOTE there is no ReadyForQuery message. This handler is used - // for basebackup and it uses CopyOut which doesn't require - // ReadyForQuery message and backend just switches back to - // processing mode after sending CopyDone or ErrorResponse. 
- } - - FeMessage::Sync => { - self.write_message_noflush(&BeMessage::ReadyForQuery)?; - } - - FeMessage::Terminate => { - return Ok(ProcessMsgResult::Break); - } - - // We prefer explicit pattern matching to wildcards, because - // this helps us spot the places where new variants are missing - FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => { - return Err(QueryError::Other(anyhow::anyhow!( - "unexpected message type: {:?}", - msg - ))); - } - } - - Ok(ProcessMsgResult::Continue) - } -} - -/// -/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData -/// messages. -/// - -pub struct CopyDataWriter<'a> { - pgb: &'a mut PostgresBackend, -} - -impl<'a> AsyncWrite for CopyDataWriter<'a> { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - let this = self.get_mut(); - - // It's not strictly required to flush between each message, but makes it easier - // to view in wireshark, and usually the messages that the callers write are - // decently-sized anyway. - match ready!(this.pgb.poll_write_buf(cx)) { - Ok(()) => {} - Err(err) => return Poll::Ready(Err(err)), - } - - // CopyData - // XXX: if the input is large, we should split it into multiple messages. - // Not sure what the threshold should be, but the ultimate hard limit is that - // the length cannot exceed u32. - this.pgb.write_message_noflush(&BeMessage::CopyData(buf))?; - - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let this = self.get_mut(); - match ready!(this.pgb.poll_write_buf(cx)) { - Ok(()) => {} - Err(err) => return Poll::Ready(Err(err)), - } - this.pgb.poll_flush(cx) - } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let this = self.get_mut(); - match ready!(this.pgb.poll_write_buf(cx)) { - Ok(()) => {} - Err(err) => return Poll::Ready(Err(err)), - } - this.pgb.poll_flush(cx) - } -} - -pub fn short_error(e: &QueryError) -> String { - match e { - QueryError::Disconnected(connection_error) => connection_error.to_string(), - QueryError::Other(e) => format!("{e:#}"), - } -} - -pub(super) fn log_query_error(query: &str, e: &QueryError) { - match e { - QueryError::Disconnected(ConnectionError::Socket(io_error)) => { - if is_expected_io_error(io_error) { - info!("query handler for '{query}' failed with expected io error: {io_error}"); - } else { - error!("query handler for '{query}' failed with io error: {io_error}"); - } - } - QueryError::Disconnected(other_connection_error) => { - error!("query handler for '{query}' failed with connection error: {other_connection_error:?}") - } - QueryError::Other(e) => { - error!("query handler for '{query}' failed: {e:?}"); - } - } -} diff --git a/libs/utils/src/sock_split.rs b/libs/utils/src/sock_split.rs deleted file mode 100644 index b0e5a0bf6a..0000000000 --- a/libs/utils/src/sock_split.rs +++ /dev/null @@ -1,206 +0,0 @@ -use std::{ - io::{self, BufReader, Write}, - net::{Shutdown, TcpStream}, - sync::Arc, -}; - -use rustls::Connection; - -/// Wrapper supporting reads of a shared TcpStream. -pub struct ArcTcpRead(Arc); - -impl io::Read for ArcTcpRead { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - (&*self.0).read(buf) - } -} - -impl std::ops::Deref for ArcTcpRead { - type Target = TcpStream; - - fn deref(&self) -> &Self::Target { - self.0.deref() - } -} - -/// Wrapper around a TCP Stream supporting buffered reads. 
-pub struct BufStream(BufReader); - -impl io::Read for BufStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) - } -} - -impl io::Write for BufStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.get_ref().write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.get_ref().flush() - } -} - -impl BufStream { - /// Unwrap into the internal BufReader. - fn into_reader(self) -> BufReader { - self.0 - } - - /// Returns a reference to the underlying TcpStream. - fn get_ref(&self) -> &TcpStream { - &self.0.get_ref().0 - } -} - -pub enum ReadStream { - Tcp(BufReader), - Tls(rustls_split::ReadHalf), -} - -impl io::Read for ReadStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - Self::Tcp(reader) => reader.read(buf), - Self::Tls(read_half) => read_half.read(buf), - } - } -} - -impl ReadStream { - pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.get_ref().shutdown(how), - Self::Tls(write_half) => write_half.shutdown(how), - } - } -} - -pub enum WriteStream { - Tcp(Arc), - Tls(rustls_split::WriteHalf), -} - -impl WriteStream { - pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.shutdown(how), - Self::Tls(write_half) => write_half.shutdown(how), - } - } -} - -impl io::Write for WriteStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Self::Tcp(stream) => stream.as_ref().write(buf), - Self::Tls(write_half) => write_half.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.as_ref().flush(), - Self::Tls(write_half) => write_half.flush(), - } - } -} - -type TlsStream = rustls::StreamOwned; - -pub enum BidiStream { - Tcp(BufStream), - /// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`]. - Tls(Box>), -} - -impl BidiStream { - pub fn from_tcp(stream: TcpStream) -> Self { - Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream))))) - } - - pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.get_ref().shutdown(how), - Self::Tls(tls_boxed) => { - if how == Shutdown::Read { - tls_boxed.sock.get_ref().shutdown(how) - } else { - tls_boxed.conn.send_close_notify(); - let res = tls_boxed.flush(); - tls_boxed.sock.get_ref().shutdown(how)?; - res - } - } - } - } - - /// Split the bi-directional stream into two owned read and write halves. 
- pub fn split(self) -> (ReadStream, WriteStream) { - match self { - Self::Tcp(stream) => { - let reader = stream.into_reader(); - let stream: Arc = reader.get_ref().0.clone(); - - (ReadStream::Tcp(reader), WriteStream::Tcp(stream)) - } - Self::Tls(tls_boxed) => { - let reader = tls_boxed.sock.into_reader(); - let buffer_data = reader.buffer().to_owned(); - let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192); - let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192); - - // TODO would be nice to avoid the Arc here - let socket = Arc::try_unwrap(reader.into_inner().0).unwrap(); - - let (read_half, write_half) = rustls_split::split( - socket, - Connection::Server(tls_boxed.conn), - read_buf_cfg, - write_buf_cfg, - ); - (ReadStream::Tls(read_half), WriteStream::Tls(write_half)) - } - } - } - - pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result { - match self { - Self::Tcp(mut stream) => { - conn.complete_io(&mut stream)?; - assert!(!conn.is_handshaking()); - Ok(Self::Tls(Box::new(TlsStream::new(conn, stream)))) - } - Self::Tls { .. } => Err(io::Error::new( - io::ErrorKind::InvalidInput, - "TLS is already started on this stream", - )), - } - } -} - -impl io::Read for BidiStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - Self::Tcp(stream) => stream.read(buf), - Self::Tls(tls_boxed) => tls_boxed.read(buf), - } - } -} - -impl io::Write for BidiStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Self::Tcp(stream) => stream.write(buf), - Self::Tls(tls_boxed) => tls_boxed.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.flush(), - Self::Tls(tls_boxed) => tls_boxed.flush(), - } - } -} diff --git a/libs/utils/tests/ssl_test.rs b/libs/utils/tests/ssl_test.rs deleted file mode 100644 index fae707f049..0000000000 --- a/libs/utils/tests/ssl_test.rs +++ /dev/null @@ -1,238 +0,0 @@ -use std::{ - collections::HashMap, - io::{Cursor, Read, Write}, - net::{TcpListener, TcpStream}, - sync::Arc, -}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use once_cell::sync::Lazy; - -use utils::{ - postgres_backend::{AuthType, Handler, PostgresBackend}, - postgres_backend_async::QueryError, -}; - -fn make_tcp_pair() -> (TcpStream, TcpStream) { - let listener = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - let client_stream = TcpStream::connect(addr).unwrap(); - let (server_stream, _) = listener.accept().unwrap(); - (server_stream, client_stream) -} - -static KEY: Lazy = Lazy::new(|| { - let mut cursor = Cursor::new(include_bytes!("key.pem")); - rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone()) -}); - -static CERT: Lazy = Lazy::new(|| { - let mut cursor = Cursor::new(include_bytes!("cert.pem")); - rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone()) -}); - -#[test] -// [false-positive](https://github.com/rust-lang/rust-clippy/issues/9274), -// we resize the vector so doing some modifications after all -#[allow(clippy::read_zero_byte_vec)] -fn ssl() { - let (mut client_sock, server_sock) = make_tcp_pair(); - - const QUERY: &str = "hello world"; - - let client_jh = std::thread::spawn(move || { - // SSLRequest - client_sock.write_u32::(8).unwrap(); - client_sock.write_u32::(80877103).unwrap(); - - let ssl_response = client_sock.read_u8().unwrap(); - assert_eq!(b'S', ssl_response); - - let cfg = 
rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates({ - let mut store = rustls::RootCertStore::empty(); - store.add(&CERT).unwrap(); - store - }) - .with_no_client_auth(); - let client_config = Arc::new(cfg); - - let dns_name = "localhost".try_into().unwrap(); - let mut conn = rustls::ClientConnection::new(client_config, dns_name).unwrap(); - - conn.complete_io(&mut client_sock).unwrap(); - assert!(!conn.is_handshaking()); - - let mut stream = rustls::Stream::new(&mut conn, &mut client_sock); - - // StartupMessage - stream.write_u32::(9).unwrap(); - stream.write_u32::(196608).unwrap(); - stream.write_u8(0).unwrap(); - stream.flush().unwrap(); - - // wait for ReadyForQuery - let mut msg_buf = Vec::new(); - loop { - let msg = stream.read_u8().unwrap(); - let size = stream.read_u32::().unwrap() - 4; - msg_buf.resize(size as usize, 0); - stream.read_exact(&mut msg_buf).unwrap(); - - if msg == b'Z' { - // ReadyForQuery - break; - } - } - - // Query - stream.write_u8(b'Q').unwrap(); - stream - .write_u32::(4u32 + QUERY.len() as u32) - .unwrap(); - stream.write_all(QUERY.as_ref()).unwrap(); - stream.flush().unwrap(); - - // ReadyForQuery - let msg = stream.read_u8().unwrap(); - assert_eq!(msg, b'Z'); - }); - - struct TestHandler { - got_query: bool, - } - impl Handler for TestHandler { - fn process_query( - &mut self, - _pgb: &mut PostgresBackend, - query_string: &str, - ) -> Result<(), QueryError> { - self.got_query = query_string == QUERY; - Ok(()) - } - } - let mut handler = TestHandler { got_query: false }; - - let cfg = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(vec![CERT.clone()], KEY.clone()) - .unwrap(); - let tls_config = Some(Arc::new(cfg)); - - let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap(); - pgb.run(&mut handler).unwrap(); - assert!(handler.got_query); - - client_jh.join().unwrap(); - - // TODO consider shutdown behavior -} - -#[test] -fn no_ssl() { - let (mut client_sock, server_sock) = make_tcp_pair(); - - let client_jh = std::thread::spawn(move || { - let mut buf = BytesMut::new(); - - // SSLRequest - buf.put_u32(8); - buf.put_u32(80877103); - client_sock.write_all(&buf).unwrap(); - buf.clear(); - - let ssl_response = client_sock.read_u8().unwrap(); - assert_eq!(b'N', ssl_response); - }); - - struct TestHandler; - - impl Handler for TestHandler { - fn process_query( - &mut self, - _pgb: &mut PostgresBackend, - _query_string: &str, - ) -> Result<(), QueryError> { - panic!() - } - } - - let mut handler = TestHandler; - - let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None, true).unwrap(); - pgb.run(&mut handler).unwrap(); - - client_jh.join().unwrap(); -} - -#[test] -fn server_forces_ssl() { - let (mut client_sock, server_sock) = make_tcp_pair(); - - let client_jh = std::thread::spawn(move || { - // StartupMessage - client_sock.write_u32::(9).unwrap(); - client_sock.write_u32::(196608).unwrap(); - client_sock.write_u8(0).unwrap(); - client_sock.flush().unwrap(); - - // ErrorResponse - assert_eq!(client_sock.read_u8().unwrap(), b'E'); - let len = client_sock.read_u32::().unwrap() - 4; - - let mut body = vec![0; len as usize]; - client_sock.read_exact(&mut body).unwrap(); - let mut body = Bytes::from(body); - - let mut errors = HashMap::new(); - loop { - let field_type = body.get_u8(); - if field_type == 0u8 { - break; - } - - let end_idx = body.iter().position(|&b| b == 0u8).unwrap(); - let mut value = body.split_to(end_idx + 1); - 
assert_eq!(value[end_idx], 0u8); - value.truncate(end_idx); - let old = errors.insert(field_type, value); - assert!(old.is_none()); - } - - assert!(!body.has_remaining()); - - assert_eq!("must connect with TLS", errors.get(&b'M').unwrap()); - - // TODO read failure - }); - - struct TestHandler; - impl Handler for TestHandler { - fn process_query( - &mut self, - _pgb: &mut PostgresBackend, - _query_string: &str, - ) -> Result<(), QueryError> { - panic!() - } - } - let mut handler = TestHandler; - - let cfg = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(vec![CERT.clone()], KEY.clone()) - .unwrap(); - let tls_config = Some(Arc::new(cfg)); - - let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap(); - let res = pgb.run(&mut handler).unwrap_err(); - assert_eq!("client did not connect with TLS", format!("{}", res)); - - client_jh.join().unwrap(); - - // TODO consider shutdown behavior -} diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index d2f0b84863..8d6641a387 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -37,6 +37,7 @@ num-traits.workspace = true once_cell.workspace = true pin-project-lite.workspace = true postgres.workspace = true +postgres_backend.workspace = true postgres-protocol.workspace = true postgres-types.workspace = true rand.workspace = true diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 01a2c85d74..0441760eef 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -23,11 +23,10 @@ use pageserver::{ tenant::mgr, virtual_file, }; +use postgres_backend::AuthType; use utils::{ auth::JwtAuth, - logging, - postgres_backend::AuthType, - project_git_version, + logging, project_git_version, sentry_init::init_sentry, signals::{self, Signal}, tcp_listener, diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 309e5367a4..b3525c2cc6 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -21,10 +21,10 @@ use std::time::Duration; use toml_edit; use toml_edit::{Document, Item}; +use postgres_backend::AuthType; use utils::{ id::{NodeId, TenantId, TimelineId}, logging::LogFormat, - postgres_backend::AuthType, }; use crate::tenant::config::TenantConf; diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index b362e25424..40e11a70b7 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -20,7 +20,8 @@ use pageserver_api::models::{ PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamNblocksRequest, PagestreamNblocksResponse, }; -use pq_proto::ConnectionError; +use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError}; +use pq_proto::framed::ConnectionError; use pq_proto::FeStartupPacket; use pq_proto::{BeMessage, FeMessage, RowDescriptor}; use std::io; @@ -35,8 +36,6 @@ use utils::{ auth::{Claims, JwtAuth, Scope}, id::{TenantId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, - postgres_backend_async::{self, is_expected_io_error, PostgresBackend, QueryError}, simple_rcu::RcuReadGuard, }; @@ -68,7 +67,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { msg } + msg = pgb.read_message() => { msg.map_err(QueryError::from)} }; match msg { @@ -79,14 +78,16 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream continue, FeMessage::Terminate => { let msg = "client terminated connection with Terminate message during COPY"; - let 
query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); - pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; + let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; break; } m => { let msg = format!("unexpected message {m:?}"); - pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None))?; + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?; Err(io::Error::new(io::ErrorKind::Other, msg))?; break; } @@ -96,16 +97,17 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { let msg = "client closed connection during COPY"; - let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); - pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; + let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?; pgb.flush().await?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; } - Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => { + Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => { Err(io_error)?; } Err(other) => { - Err(io::Error::new(io::ErrorKind::Other, other))?; + Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?; } }; } @@ -212,7 +214,7 @@ async fn page_service_conn_main( // we've been requested to shut down Ok(()) } - Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => { + Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => { if is_expected_io_error(&io_error) { info!("Postgres client disconnected ({io_error})"); Ok(()) @@ -721,7 +723,7 @@ impl PageServerHandler { } #[async_trait::async_trait] -impl postgres_backend_async::Handler for PageServerHandler { +impl postgres_backend::Handler for PageServerHandler { fn check_auth_jwt( &mut self, _pgb: &mut PostgresBackend, @@ -1055,7 +1057,7 @@ impl From for QueryError { fn from(e: GetActiveTenantError) -> Self { match e { GetActiveTenantError::WaitForActiveTimeout { .. 
} => QueryError::Disconnected( - ConnectionError::Socket(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), + ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), ), GetActiveTenantError::Other(e) => QueryError::Other(e), } diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 7e06c398af..7194a4f3ed 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -33,10 +33,11 @@ use crate::{ walingest::WalIngest, walrecord::DecodedWALRecord, }; +use postgres_backend::is_expected_io_error; use postgres_connection::PgConnectionConfig; use postgres_ffi::waldecoder::WalStreamDecoder; use pq_proto::ReplicationFeedback; -use utils::{lsn::Lsn, postgres_backend_async::is_expected_io_error}; +use utils::lsn::Lsn; /// Status of the connection. #[derive(Debug, Clone, Copy)] @@ -353,7 +354,7 @@ pub async fn handle_walreceiver_connection( debug!("neon_status_update {status_update:?}"); let mut data = BytesMut::new(); - status_update.serialize(&mut data)?; + status_update.serialize(&mut data); physical_stream .as_mut() .zenith_status_update(data.len() as u64, &data) @@ -434,8 +435,8 @@ fn ignore_expected_errors(pg_error: postgres::Error) -> anyhow::Result> = Lazy::new(Default::default); @@ -33,7 +31,7 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N /// Console management API listener task. /// It spawns console response handlers needed for the link auth. -pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> { +pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> { scopeguard::defer! { info!("mgmt has shut down"); } @@ -42,18 +40,12 @@ pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> let (socket, peer_addr) = listener.accept().await?; info!("accepted connection from {peer_addr}"); - let socket = socket.into_std()?; socket .set_nodelay(true) .context("failed to set client socket option")?; - socket - .set_nonblocking(false) - .context("failed to set client socket option")?; - // TODO: replace with async tasks. - thread::spawn(move || { - let tid = std::thread::current().id(); - let span = info_span!("mgmt", thread = format_args!("{tid:?}")); + tokio::task::spawn(async move { + let span = info_span!("mgmt", peer = %peer_addr); let _enter = span.enter(); info!("started a new console management API thread"); @@ -61,16 +53,16 @@ pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> info!("console management API thread is about to finish"); } - if let Err(e) = handle_connection(socket) { + if let Err(e) = handle_connection(socket).await { error!("thread failed with an error: {e}"); } }); } } -fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { - let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?; - pgbackend.run(&mut MgmtHandler) +async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { + let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?; + pgbackend.run(&mut MgmtHandler, future::pending::<()>).await } /// A message received by `mgmt` when a compute node is ready. @@ -78,16 +70,21 @@ pub type ComputeReady = Result; // TODO: replace with an http-based protocol. 
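The mgmt handler below is the simplest consumer of the refactored crate's async Handler trait. A minimal sketch of such a handler, with the trait signature taken from this diff (any generic parameters on PostgresBackend are omitted here and are an assumption):

use postgres_backend::{PostgresBackend, QueryError};
use pq_proto::BeMessage;

struct NoopHandler;

#[async_trait::async_trait]
impl postgres_backend::Handler for NoopHandler {
    async fn process_query(
        &mut self,
        pgb: &mut PostgresBackend,
        _query: &str,
    ) -> Result<(), QueryError> {
        // Queue the reply without flushing; handlers in this diff switch to
        // write_message_noflush, presumably because the backend's run loop
        // flushes once the query has been processed.
        pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
        Ok(())
    }
}

Wiring it up would mirror handle_connection above, roughly PostgresBackend::new(socket, AuthType::Trust, None)?.run(&mut NoopHandler, future::pending::<()>).await.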
struct MgmtHandler; +#[async_trait::async_trait] impl postgres_backend::Handler for MgmtHandler { - fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> { - try_process_query(pgb, query).map_err(|e| { + async fn process_query( + &mut self, + pgb: &mut PostgresBackend, + query: &str, + ) -> Result<(), QueryError> { + try_process_query(pgb, query).await.map_err(|e| { error!("failed to process response: {e:?}"); e }) } } -fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> { +async fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> { let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?; let span = info_span!("event", session_id = resp.session_id); @@ -98,11 +95,11 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), Query Ok(()) => { pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? - .write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; } Err(e) => { error!("failed to deliver response to per-client task"); - pgb.write_message(&BeMessage::ErrorResponse(&e.to_string(), None))?; + pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?; } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 02a0fabe9a..5a802dafb2 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -1,45 +1,40 @@ use crate::error::UserFacingError; use anyhow::bail; -use bytes::BytesMut; use pin_project_lite::pin_project; -use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket}; +use pq_proto::framed::{ConnectionError, Framed}; +use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; use std::pin::Pin; use std::sync::Arc; use std::{io, task}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::server::TlsStream; -pin_project! { - /// Stream wrapper which implements libpq's protocol. - /// NOTE: This object deliberately doesn't implement [`AsyncRead`] - /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying - /// to pass random malformed bytes through the connection). - pub struct PqStream { - #[pin] - stream: S, - buffer: BytesMut, - } +/// Stream wrapper which implements libpq's protocol. +/// NOTE: This object deliberately doesn't implement [`AsyncRead`] +/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying +/// to pass random malformed bytes through the connection). +pub struct PqStream { + framed: Framed, } impl PqStream { /// Construct a new libpq protocol wrapper. pub fn new(stream: S) -> Self { Self { - stream, - buffer: Default::default(), + framed: Framed::new(stream), } } /// Extract the underlying stream. pub fn into_inner(self) -> S { - self.stream + self.framed.into_inner() } /// Get a shared reference to the underlying stream. pub fn get_ref(&self) -> &S { - &self.stream + self.framed.get_ref() } } @@ -50,16 +45,19 @@ fn err_connection() -> io::Error { impl PqStream { /// Receive [`FeStartupPacket`], which is a first packet sent by a client. 
pub async fn read_startup_packet(&mut self) -> io::Result { - // TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket` - let msg = FeStartupPacket::read_fut(&mut self.stream) + self.framed + .read_startup_message() .await .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection)?; + .ok_or_else(err_connection) + } - match msg { - FeMessage::StartupPacket(packet) => Ok(packet), - _ => panic!("unreachable state"), - } + async fn read_message(&mut self) -> io::Result { + self.framed + .read_message() + .await + .map_err(ConnectionError::into_io_error)? + .ok_or_else(err_connection) } pub async fn read_password_message(&mut self) -> io::Result { @@ -71,19 +69,14 @@ impl PqStream { )), } } - - async fn read_message(&mut self) -> io::Result { - FeMessage::read_fut(&mut self.stream) - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } } impl PqStream { /// Write the message into an internal buffer, but don't flush the underlying stream. pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - BeMessage::write(&mut self.buffer, message)?; + self.framed + .write_message(message) + .map_err(ProtocolError::into_io_error)?; Ok(self) } @@ -96,9 +89,7 @@ impl PqStream { /// Flush the output buffer into the underlying stream. pub async fn flush(&mut self) -> io::Result<&mut Self> { - self.stream.write_all(&self.buffer).await?; - self.buffer.clear(); - self.stream.flush().await?; + self.framed.flush().await?; Ok(self) } diff --git a/run_clippy.sh b/run_clippy.sh index fe0e745d7d..0558541089 100755 --- a/run_clippy.sh +++ b/run_clippy.sh @@ -11,12 +11,18 @@ # Not every feature is supported in macOS builds. Avoid running regular linting # script that checks every feature. +# +# manual-range-contains wants +# !(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len) +# instead of +# len < 4 || len > MAX_STARTUP_PACKET_LENGTH +# , let's disagree. 
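The lint referenced above rewrites explicit bound comparisons into RangeInclusive::contains. A generic illustration, not code from this repository:

fn in_bounds(len: u32, max: u32) -> bool {
    // clippy::manual_range_contains would suggest (8..=max).contains(&len) here
    len >= 8 && len <= max
}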
if [[ "$OSTYPE" == "darwin"* ]]; then # no extra features to test currently, add more here when needed - cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -D warnings + cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -A clippy::manual-range-contains -D warnings else # * `-A unknown_lints` – do not warn about unknown lint suppressions # that people with newer toolchains might use # * `-D warnings` - fail on any warnings (`cargo` returns non-zero exit status) - cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -D warnings + cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -A clippy::manual-range-contains -D warnings fi diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 4ee8d82203..88f950d6c8 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -35,6 +35,7 @@ toml_edit.workspace = true tracing.workspace = true url.workspace = true metrics.workspace = true +postgres_backend.workspace = true postgres_ffi.workspace = true pq_proto.workspace = true remote_storage.workspace = true diff --git a/safekeeper/src/bin/safekeeper.rs b/safekeeper/src/bin/safekeeper.rs index 683050e9cd..d2cb9f79b9 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ b/safekeeper/src/bin/safekeeper.rs @@ -236,7 +236,7 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> { let conf_cloned = conf.clone(); let safekeeper_thread = thread::Builder::new() - .name("safekeeper thread".into()) + .name("WAL service thread".into()) .spawn(|| wal_service::thread_main(conf_cloned, pg_listener)) .unwrap(); diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index 60df5dd372..be1c89c97b 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -1,27 +1,23 @@ //! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres //! protocol commands. 
+use anyhow::Context; +use std::str; +use tracing::{info, info_span, Instrument}; + use crate::auth::check_permission; use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage}; -use crate::receive_wal::ReceiveWalConn; - -use crate::send_wal::ReplicationConn; use crate::{GlobalTimelines, SafeKeeperConf}; -use anyhow::Context; - +use postgres_backend::QueryError; +use postgres_backend::{self, PostgresBackend}; use postgres_ffi::PG_TLI; -use regex::Regex; - use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID}; -use std::str; -use tracing::info; +use regex::Regex; use utils::auth::{Claims, Scope}; -use utils::postgres_backend_async::QueryError; use utils::{ id::{TenantId, TenantTimelineId, TimelineId}, lsn::Lsn, - postgres_backend::{self, PostgresBackend}, }; /// Safekeeper handler of postgres commands @@ -53,7 +49,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result { let start_lsn = caps .next() .map(|cap| cap[1].parse::()) - .context("failed to parse start LSN from START_REPLICATION command")??; + .context("parse start LSN from START_REPLICATION command")??; Ok(SafekeeperPostgresCommand::StartReplication { start_lsn }) } else if cmd.starts_with("IDENTIFY_SYSTEM") { Ok(SafekeeperPostgresCommand::IdentifySystem) @@ -67,6 +63,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result { } } +#[async_trait::async_trait] impl postgres_backend::Handler for SafekeeperPostgresHandler { // tenant_id and timeline_id are passed in connection string params fn startup( @@ -137,7 +134,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { Ok(()) } - fn process_query( + async fn process_query( &mut self, pgb: &mut PostgresBackend, query_string: &str, @@ -147,9 +144,10 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { .starts_with("set datestyle to ") { // important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect - pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; + pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; return Ok(()); } + let cmd = parse_cmd(query_string)?; info!( @@ -161,26 +159,23 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { let timeline_id = self.timeline_id.context("timelineid is required")?; self.check_permission(Some(tenant_id))?; self.ttid = TenantTimelineId::new(tenant_id, timeline_id); + let span_ttid = self.ttid; // satisfy borrow checker - let res = match cmd { - SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self), + match cmd { + SafekeeperPostgresCommand::StartWalPush => { + self.handle_start_wal_push(pgb) + .instrument(info_span!("WAL receiver", ttid = %span_ttid)) + .await + } SafekeeperPostgresCommand::StartReplication { start_lsn } => { - ReplicationConn::new(pgb).run(self, pgb, start_lsn) + self.handle_start_replication(pgb, start_lsn) + .instrument(info_span!("WAL sender", ttid = %span_ttid)) + .await } - SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb), - SafekeeperPostgresCommand::JSONCtrl { ref cmd } => handle_json_ctrl(self, pgb, cmd), - }; - - match res { - Ok(()) => Ok(()), - Err(QueryError::Disconnected(connection_error)) => { - info!("Timeline {tenant_id}/{timeline_id} query failed with connection error: {connection_error}"); - Err(QueryError::Disconnected(connection_error)) + SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await, + SafekeeperPostgresCommand::JSONCtrl { ref cmd } => { + handle_json_ctrl(self, pgb, cmd).await } - Err(QueryError::Other(e)) => 
Err(QueryError::Other(e.context(format!( - "Failed to process query for timeline {}", - self.ttid - )))), } } } @@ -217,7 +212,10 @@ impl SafekeeperPostgresHandler { /// /// Handle IDENTIFY_SYSTEM replication command /// - fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<(), QueryError> { + async fn handle_identify_system( + &mut self, + pgb: &mut PostgresBackend, + ) -> Result<(), QueryError> { let tli = GlobalTimelines::get(self.ttid)?; let lsn = if self.is_walproposer_recovery() { @@ -267,7 +265,7 @@ impl SafekeeperPostgresHandler { Some(lsn_bytes), None, ]))? - .write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?; + .write_message_noflush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?; Ok(()) } diff --git a/safekeeper/src/http/routes.rs b/safekeeper/src/http/routes.rs index a917d61678..7e5d7d1b47 100644 --- a/safekeeper/src/http/routes.rs +++ b/safekeeper/src/http/routes.rs @@ -8,6 +8,7 @@ use serde::Serialize; use serde::Serializer; use std::collections::{HashMap, HashSet}; use std::fmt::Display; + use std::sync::Arc; use storage_broker::proto::SafekeeperTimelineInfo; use storage_broker::proto::TenantTimelineId as ProtoTenantTimelineId; @@ -181,12 +182,9 @@ async fn timeline_create_handler(mut request: Request) -> Result anyhow::Result> { +async fn prepare_safekeeper( + ttid: TenantTimelineId, + pg_version: u32, +) -> anyhow::Result> { GlobalTimelines::create( ttid, ServerInfo { @@ -106,6 +110,7 @@ fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> anyhow::Result Lsn::INVALID, Lsn::INVALID, ) + .await } fn send_proposer_elected(tli: &Arc, term: Term, lsn: Lsn) -> anyhow::Result<()> { @@ -128,15 +133,15 @@ fn send_proposer_elected(tli: &Arc, term: Term, lsn: Lsn) -> anyhow::R } #[derive(Debug, Serialize, Deserialize)] -struct InsertedWAL { +pub struct InsertedWAL { begin_lsn: Lsn, - end_lsn: Lsn, + pub end_lsn: Lsn, append_response: AppendResponse, } /// Extend local WAL with new LogicalMessage record. To do that, /// create AppendRequest with new WAL and pass it to safekeeper. -fn append_logical_message( +pub fn append_logical_message( tli: &Arc, msg: &AppendLogicalMessage, ) -> anyhow::Result { diff --git a/safekeeper/src/lib.rs b/safekeeper/src/lib.rs index 891d73533f..818f2f9424 100644 --- a/safekeeper/src/lib.rs +++ b/safekeeper/src/lib.rs @@ -1,8 +1,7 @@ -use storage_broker::Uri; -// use remote_storage::RemoteStorageConfig; use std::path::PathBuf; use std::time::Duration; +use storage_broker::Uri; use utils::id::{NodeId, TenantId, TenantTimelineId}; diff --git a/safekeeper/src/receive_wal.rs b/safekeeper/src/receive_wal.rs index 671e5470a0..22c9871026 100644 --- a/safekeeper/src/receive_wal.rs +++ b/safekeeper/src/receive_wal.rs @@ -2,204 +2,284 @@ //! Gets messages from the network, passes them down to consensus module and //! sends replies back. 
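The rewritten receiver below separates network IO from WAL processing: a reader loop forwards ProposerAcceptorMessages over a bounded channel to a dedicated WalAcceptor (a separate thread in the real code), replies travel back over a second channel, and the two socket directions are driven concurrently with tokio::select!. A toy model of that queue-based split, using byte vectors in place of the real message types:

use tokio::sync::mpsc::channel;

#[tokio::main]
async fn main() {
    let (msg_tx, mut msg_rx) = channel::<Vec<u8>>(256); // network -> acceptor
    let (reply_tx, mut reply_rx) = channel::<Vec<u8>>(16); // acceptor -> network

    // Stand-in for WalAcceptor: drain the message queue, push one reply per message.
    let acceptor = tokio::spawn(async move {
        while let Some(msg) = msg_rx.recv().await {
            if reply_tx.send(msg).await.is_err() {
                break; // reply side is gone, stop processing
            }
        }
    });

    // Stand-in for the network side: feed messages in, print replies out.
    for i in 0..3u8 {
        msg_tx.send(vec![i]).await.unwrap();
        if let Some(reply) = reply_rx.recv().await {
            println!("reply: {reply:?}");
        }
    }
    drop(msg_tx); // closing the channel lets the acceptor task exit
    acceptor.await.unwrap();
}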
-use anyhow::anyhow; -use anyhow::Context; - -use bytes::BytesMut; -use tracing::*; -use utils::lsn::Lsn; -use utils::postgres_backend_async::QueryError; - +use crate::handler::SafekeeperPostgresHandler; +use crate::safekeeper::AcceptorProposerMessage; +use crate::safekeeper::ProposerAcceptorMessage; use crate::safekeeper::ServerInfo; use crate::timeline::Timeline; use crate::GlobalTimelines; - +use anyhow::{anyhow, Context}; +use bytes::BytesMut; +use nix::unistd::gettid; +use postgres_backend::CopyStreamHandlerEnd; +use postgres_backend::PostgresBackend; +use postgres_backend::PostgresBackendReader; +use postgres_backend::QueryError; +use pq_proto::BeMessage; use std::net::SocketAddr; -use std::sync::mpsc::channel; -use std::sync::mpsc::Receiver; - use std::sync::Arc; use std::thread; +use std::thread::JoinHandle; +use tokio::sync::mpsc::channel; +use tokio::sync::mpsc::error::TryRecvError; +use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc::Sender; +use tokio::task::spawn_blocking; +use tracing::*; +use utils::id::TenantTimelineId; +use utils::lsn::Lsn; -use crate::safekeeper::AcceptorProposerMessage; -use crate::safekeeper::ProposerAcceptorMessage; +const MSG_QUEUE_SIZE: usize = 256; +const REPLY_QUEUE_SIZE: usize = 16; -use crate::handler::SafekeeperPostgresHandler; -use pq_proto::{BeMessage, FeMessage}; -use utils::{postgres_backend::PostgresBackend, sock_split::ReadStream}; - -pub struct ReceiveWalConn<'pg> { - /// Postgres connection - pg_backend: &'pg mut PostgresBackend, - /// The cached result of `pg_backend.socket().peer_addr()` (roughly) - peer_addr: SocketAddr, -} - -impl<'pg> ReceiveWalConn<'pg> { - pub fn new(pg: &'pg mut PostgresBackend) -> ReceiveWalConn<'pg> { - let peer_addr = *pg.get_peer_addr(); - ReceiveWalConn { - pg_backend: pg, - peer_addr, +impl SafekeeperPostgresHandler { + /// Wrapper around handle_start_wal_push_guts handling result. Error is + /// handled here while we're still in walreceiver ttid span; with API + /// extension, this can probably be moved into postgres_backend. + pub async fn handle_start_wal_push( + &mut self, + pgb: &mut PostgresBackend, + ) -> Result<(), QueryError> { + if let Err(end) = self.handle_start_wal_push_guts(pgb).await { + // Log the result and probably send it to the client, closing the stream. + pgb.handle_copy_stream_end(end).await; } - } - - // Send message to the postgres - fn write_msg(&mut self, msg: &AcceptorProposerMessage) -> anyhow::Result<()> { - let mut buf = BytesMut::with_capacity(128); - msg.serialize(&mut buf)?; - self.pg_backend.write_message(&BeMessage::CopyData(&buf))?; Ok(()) } - /// Receive WAL from wal_proposer - pub fn run(&mut self, spg: &mut SafekeeperPostgresHandler) -> Result<(), QueryError> { - let _enter = info_span!("WAL acceptor", ttid = %spg.ttid).entered(); - + pub async fn handle_start_wal_push_guts( + &mut self, + pgb: &mut PostgresBackend, + ) -> Result<(), CopyStreamHandlerEnd> { // Notify the libpq client that it's allowed to send `CopyData` messages - self.pg_backend - .write_message(&BeMessage::CopyBothResponse)?; + pgb.write_message(&BeMessage::CopyBothResponse).await?; - let r = self - .pg_backend - .take_stream_in() - .ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?; - let mut poll_reader = ProposerPollStream::new(r)?; + // Experiments [1] confirm that doing network IO in one (this) thread and + // processing with disc IO in another significantly improves + // performance; we spawn off WalAcceptor thread for message processing + // to this end. 
+ // + // [1] https://github.com/neondatabase/neon/pull/1318 + let (msg_tx, msg_rx) = channel(MSG_QUEUE_SIZE); + let (reply_tx, reply_rx) = channel(REPLY_QUEUE_SIZE); + let mut acceptor_handle: Option>> = None; - // Receive information about server - let next_msg = poll_reader.recv_msg()?; - let tli = match next_msg { - ProposerAcceptorMessage::Greeting(ref greeting) => { - info!( - "start handshake with walproposer {} sysid {} timeline {}", - self.peer_addr, greeting.system_id, greeting.tli, - ); - let server_info = ServerInfo { - pg_version: greeting.pg_version, - system_id: greeting.system_id, - wal_seg_size: greeting.wal_seg_size, - }; - GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)? - } - _ => { - return Err(QueryError::Other(anyhow::anyhow!( - "unexpected message {next_msg:?} instead of greeting" - ))) - } + // Concurrently receive and send data; replies are not synchronized with + // sends, so this avoids deadlocks. + let mut pgb_reader = pgb.split().context("START_WAL_PUSH split")?; + let peer_addr = *pgb.get_peer_addr(); + let res = tokio::select! { + // todo: add read|write .context to these errors + r = read_network(self.ttid, &mut pgb_reader, peer_addr, msg_tx, &mut acceptor_handle, msg_rx, reply_tx) => r, + r = write_network(pgb, reply_rx) => r, }; - let mut next_msg = Some(next_msg); + // Join pg backend back. + pgb.unsplit(pgb_reader)?; - let mut first_time_through = true; - let mut _guard: Option = None; - loop { - if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) { - // poll AppendRequest's without blocking and write WAL to disk without flushing, - // while it's readily available - while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg { - let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request); - - let reply = tli.process_msg(&msg)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } - - next_msg = poll_reader.poll_msg(); - } - - // flush all written WAL to the disk - let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } - } else if let Some(msg) = next_msg.take() { - // process other message - let reply = tli.process_msg(&msg)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } - } - if first_time_through { - // Register the connection and defer unregister. Do that only - // after processing first message, as it sets wal_seg_size, - // wanted by many. - tli.on_compute_connect()?; - _guard = Some(ComputeConnectionGuard { - timeline: Arc::clone(&tli), - }); - first_time_through = false; + // Join the spawned WalAcceptor. At this point chans to/from it passed + // to network routines are dropped, so it will exit as soon as it + // touches them. + match acceptor_handle { + None => { + // failed even before spawning; read_network should have error + Err(res.expect_err("no error with WalAcceptor not spawn")) } + Some(handle) => { + let wal_acceptor_res = handle.join(); - // blocking wait for the next message - if next_msg.is_none() { - next_msg = Some(poll_reader.recv_msg()?); + // If there was any network error, return it. + res?; + + // Otherwise, WalAcceptor thread must have errored. 
+ match wal_acceptor_res { + Ok(Ok(_)) => Ok(()), // can't happen currently; would be if we add graceful termination + Ok(Err(e)) => Err(CopyStreamHandlerEnd::Other(e.context("WAL acceptor"))), + Err(_) => Err(CopyStreamHandlerEnd::Other(anyhow!( + "WalAcceptor thread panicked", + ))), + } } } } } -struct ProposerPollStream { - msg_rx: Receiver, - read_thread: Option>>, +/// Read next message from walproposer. +/// TODO: Return Ok(None) on graceful termination. +async fn read_message( + pgb_reader: &mut PostgresBackendReader, +) -> Result { + let copy_data = pgb_reader.read_copy_message().await?; + let msg = ProposerAcceptorMessage::parse(copy_data)?; + Ok(msg) } -impl ProposerPollStream { - fn new(mut r: ReadStream) -> anyhow::Result { - let (msg_tx, msg_rx) = channel(); - - let read_thread = thread::Builder::new() - .name("Read WAL thread".into()) - .spawn(move || -> Result<(), QueryError> { - loop { - let copy_data = match FeMessage::read(&mut r)? { - Some(FeMessage::CopyData(bytes)) => Ok(bytes), - Some(msg) => Err(QueryError::Other(anyhow::anyhow!( - "expected `CopyData` message, found {msg:?}" - ))), - None => Err(QueryError::from(std::io::Error::new( - std::io::ErrorKind::ConnectionAborted, - "walproposer closed the connection", - ))), - }?; - - let msg = ProposerAcceptorMessage::parse(copy_data)?; - msg_tx - .send(msg) - .context("Failed to send the proposer message")?; - } - // msg_tx will be dropped here, this will also close msg_rx - })?; - - Ok(Self { - msg_rx, - read_thread: Some(read_thread), - }) - } - - fn recv_msg(&mut self) -> Result { - self.msg_rx.recv().map_err(|_| { - // return error from the read thread - let res = match self.read_thread.take() { - Some(thread) => thread.join(), - None => return QueryError::Other(anyhow::anyhow!("read thread is gone")), +/// Read messages from socket and pass it to WalAcceptor thread. Returns Ok(()) +/// if msg_tx closed; it must mean WalAcceptor terminated, joining it should +/// tell the error. +async fn read_network( + ttid: TenantTimelineId, + pgb_reader: &mut PostgresBackendReader, + peer_addr: SocketAddr, + msg_tx: Sender, + // WalAcceptor is spawned when we learn server info from walproposer and + // create timeline; handle is put here. + acceptor_handle: &mut Option>>, + msg_rx: Receiver, + reply_tx: Sender, +) -> Result<(), CopyStreamHandlerEnd> { + // Receive information about server to create timeline, if not yet. + let next_msg = read_message(pgb_reader).await?; + let tli = match next_msg { + ProposerAcceptorMessage::Greeting(ref greeting) => { + info!( + "start handshake with walproposer {} sysid {} timeline {}", + peer_addr, greeting.system_id, greeting.tli, + ); + let server_info = ServerInfo { + pg_version: greeting.pg_version, + system_id: greeting.system_id, + wal_seg_size: greeting.wal_seg_size, }; + GlobalTimelines::create(ttid, server_info, Lsn::INVALID, Lsn::INVALID).await? 
+ } + _ => { + return Err(CopyStreamHandlerEnd::Other(anyhow::anyhow!( + "unexpected message {next_msg:?} instead of greeting" + ))) + } + }; - match res { - Ok(Ok(())) => { - QueryError::Other(anyhow::anyhow!("unexpected result from read thread")) - } - Err(err) => QueryError::Other(anyhow::anyhow!("read thread panicked: {err:?}")), - Ok(Err(err)) => err, + *acceptor_handle = Some( + WalAcceptor::spawn(tli.clone(), msg_rx, reply_tx).context("spawn WalAcceptor thread")?, + ); + + // Forward all messages to WalAcceptor + read_network_loop(pgb_reader, msg_tx, next_msg).await +} + +async fn read_network_loop( + pgb_reader: &mut PostgresBackendReader, + msg_tx: Sender, + mut next_msg: ProposerAcceptorMessage, +) -> Result<(), CopyStreamHandlerEnd> { + loop { + if msg_tx.send(next_msg).await.is_err() { + return Ok(()); // chan closed, WalAcceptor terminated + } + next_msg = read_message(pgb_reader).await?; + } +} + +/// Read replies from WalAcceptor and pass them back to socket. Returns Ok(()) +/// if reply_rx closed; it must mean WalAcceptor terminated, joining it should +/// tell the error. +async fn write_network( + pgb_writer: &mut PostgresBackend, + mut reply_rx: Receiver, +) -> Result<(), CopyStreamHandlerEnd> { + let mut buf = BytesMut::with_capacity(128); + + loop { + match reply_rx.recv().await { + Some(msg) => { + buf.clear(); + msg.serialize(&mut buf)?; + pgb_writer.write_message(&BeMessage::CopyData(&buf)).await?; } - }) + None => return Ok(()), // chan closed, WalAcceptor terminated + } + } +} + +/// Takes messages from msg_rx, processes and pushes replies to reply_tx. +struct WalAcceptor { + tli: Arc, + msg_rx: Receiver, + reply_tx: Sender, +} + +impl WalAcceptor { + /// Spawn thread with WalAcceptor running, return handle to it. + fn spawn( + tli: Arc, + msg_rx: Receiver, + reply_tx: Sender, + ) -> anyhow::Result>> { + let thread_name = format!("WAL acceptor {}", tli.ttid); + thread::Builder::new() + .name(thread_name) + .spawn(move || -> anyhow::Result<()> { + let mut wa = WalAcceptor { + tli, + msg_rx, + reply_tx, + }; + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + let span_ttid = wa.tli.ttid; // satisfy borrow checker + runtime.block_on( + wa.run() + .instrument(info_span!("WAL acceptor", tid = %gettid(), ttid = %span_ttid)), + ) + }) + .map_err(anyhow::Error::from) } - fn poll_msg(&mut self) -> Option { - let res = self.msg_rx.try_recv(); + /// The main loop. Returns Ok(()) if either msg_rx or reply_tx got closed; + /// it must mean that network thread terminated. + async fn run(&mut self) -> anyhow::Result<()> { + // Register the connection and defer unregister. + self.tli.on_compute_connect().await?; + let _guard = ComputeConnectionGuard { + timeline: Arc::clone(&self.tli), + }; - match res { - Err(_) => None, - Ok(msg) => Some(msg), + let mut next_msg: ProposerAcceptorMessage; + + loop { + let opt_msg = self.msg_rx.recv().await; + if opt_msg.is_none() { + return Ok(()); // chan closed, streaming terminated + } + next_msg = opt_msg.unwrap(); + + if matches!(next_msg, ProposerAcceptorMessage::AppendRequest(_)) { + // loop through AppendRequest's while it's readily available to + // write as many WAL as possible without fsyncing + while let ProposerAcceptorMessage::AppendRequest(append_request) = next_msg { + let noflush_msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request); + + if let Some(reply) = self.tli.process_msg(&noflush_msg)? 
{ + if self.reply_tx.send(reply).await.is_err() { + return Ok(()); // chan closed, streaming terminated + } + } + + match self.msg_rx.try_recv() { + Ok(msg) => next_msg = msg, + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => return Ok(()), // chan closed, streaming terminated + } + } + + // flush all written WAL to the disk + if let Some(reply) = self.tli.process_msg(&ProposerAcceptorMessage::FlushWAL)? { + if self.reply_tx.send(reply).await.is_err() { + return Ok(()); // chan closed, streaming terminated + } + } + } else { + // process message other than AppendRequest + if let Some(reply) = self.tli.process_msg(&next_msg)? { + if self.reply_tx.send(reply).await.is_err() { + return Ok(()); // chan closed, streaming terminated + } + } + } } } } @@ -210,8 +290,13 @@ struct ComputeConnectionGuard { impl Drop for ComputeConnectionGuard { fn drop(&mut self) { - if let Err(e) = self.timeline.on_compute_disconnect() { - error!("failed to unregister compute connection: {}", e); - } + let tli = self.timeline.clone(); + // tokio forbids to call blocking_send inside the runtime, and see + // comments in on_compute_disconnect why we call blocking_send. + spawn_blocking(move || { + if let Err(e) = tli.on_compute_disconnect() { + error!("failed to unregister compute connection: {}", e); + } + }); } } diff --git a/safekeeper/src/safekeeper.rs b/safekeeper/src/safekeeper.rs index fa973a3ede..2fcfce69f6 100644 --- a/safekeeper/src/safekeeper.rs +++ b/safekeeper/src/safekeeper.rs @@ -486,7 +486,7 @@ impl AcceptorProposerMessage { buf.put_u64_le(msg.hs_feedback.xmin); buf.put_u64_le(msg.hs_feedback.catalog_xmin); - msg.pageserver_feedback.serialize(buf)? + msg.pageserver_feedback.serialize(buf); } } diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 20600ab694..5c9f763b4a 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -5,24 +5,22 @@ use crate::handler::SafekeeperPostgresHandler; use crate::timeline::{ReplicaState, Timeline}; use crate::wal_storage::WalReader; use crate::GlobalTimelines; -use anyhow::Context; - +use anyhow::Context as AnyhowContext; use bytes::Bytes; +use postgres_backend::PostgresBackend; +use postgres_backend::{CopyStreamHandlerEnd, PostgresBackendReader, QueryError}; use postgres_ffi::get_current_timestamp; use postgres_ffi::{TimestampTz, MAX_SEND_SIZE}; +use pq_proto::{BeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody}; use serde::{Deserialize, Serialize}; use std::cmp::min; -use std::net::Shutdown; +use std::str; use std::sync::Arc; use std::time::Duration; -use std::{io, str, thread}; -use utils::postgres_backend_async::QueryError; - -use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody}; use tokio::sync::watch::Receiver; use tokio::time::timeout; use tracing::*; -use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend::PostgresBackend, sock_split::ReadStream}; +use utils::{bin_ser::BeSer, lsn::Lsn}; // See: https://www.postgresql.org/docs/13/protocol-replication.html const HOT_STANDBY_FEEDBACK_TAG_BYTE: u8 = b'h'; @@ -60,13 +58,6 @@ pub struct StandbyReply { pub reply_requested: bool, } -/// A network connection that's speaking the replication protocol. -pub struct ReplicationConn { - /// This is an `Option` because we will spawn a background thread that will - /// `take` it from us. 
- stream_in: Option, -} - /// Scope guard to unregister replication connection from timeline struct ReplicationConnGuard { replica: usize, // replica internal ID assigned by timeline @@ -79,230 +70,274 @@ impl Drop for ReplicationConnGuard { } } -impl ReplicationConn { - /// Create a new `ReplicationConn` - pub fn new(pgb: &mut PostgresBackend) -> Self { - Self { - stream_in: pgb.take_stream_in(), +impl SafekeeperPostgresHandler { + /// Wrapper around handle_start_replication_guts handling result. Error is + /// handled here while we're still in walsender ttid span; with API + /// extension, this can probably be moved into postgres_backend. + pub async fn handle_start_replication( + &mut self, + pgb: &mut PostgresBackend, + start_pos: Lsn, + ) -> Result<(), QueryError> { + if let Err(end) = self.handle_start_replication_guts(pgb, start_pos).await { + // Log the result and probably send it to the client, closing the stream. + pgb.handle_copy_stream_end(end).await; } - } - - /// Handle incoming messages from the network. - /// This is spawned into the background by `handle_start_replication`. - fn background_thread( - mut stream_in: ReadStream, - replica_guard: Arc, - ) -> anyhow::Result<()> { - let replica_id = replica_guard.replica; - let timeline = &replica_guard.timeline; - - let mut state = ReplicaState::new(); - // Wait for replica's feedback. - while let Some(msg) = FeMessage::read(&mut stream_in)? { - match &msg { - FeMessage::CopyData(m) => { - // There's three possible data messages that the client is supposed to send here: - // `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`. - - match m.first().cloned() { - Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => { - // Note: deserializing is on m[1..] because we skip the tag byte. - state.hs_feedback = HotStandbyFeedback::des(&m[1..]) - .context("failed to deserialize HotStandbyFeedback")?; - timeline.update_replica_state(replica_id, state); - } - Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => { - let _reply = StandbyReply::des(&m[1..]) - .context("failed to deserialize StandbyReply")?; - // This must be a regular postgres replica, - // because pageserver doesn't send this type of messages to safekeeper. - // Currently this is not implemented, so this message is ignored. - - warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet."); - // timeline.update_replica_state(replica_id, Some(state)); - } - Some(NEON_STATUS_UPDATE_TAG_BYTE) => { - // Note: deserializing is on m[9..] because we skip the tag byte and len bytes. - let buf = Bytes::copy_from_slice(&m[9..]); - let reply = ReplicationFeedback::parse(buf); - - trace!("ReplicationFeedback is {:?}", reply); - // Only pageserver sends ReplicationFeedback, so set the flag. - // This replica is the source of information to resend to compute. - state.pageserver_feedback = Some(reply); - - timeline.update_replica_state(replica_id, state); - } - _ => warn!("unexpected message {:?}", msg), - } - } - FeMessage::Sync => {} - FeMessage::CopyFail => { - // Shutdown the connection, because rust-postgres client cannot be dropped - // when connection is alive. - let _ = stream_in.shutdown(Shutdown::Both); - anyhow::bail!("Copy failed"); - } - _ => { - // We only handle `CopyData`, 'Sync', 'CopyFail' messages. Anything else is ignored. 
- info!("unexpected message {:?}", msg); - } - } - } - Ok(()) } - /// - /// Handle START_REPLICATION replication command - /// - pub fn run( + pub async fn handle_start_replication_guts( &mut self, - spg: &mut SafekeeperPostgresHandler, pgb: &mut PostgresBackend, - mut start_pos: Lsn, - ) -> Result<(), QueryError> { - let _enter = info_span!("WAL sender", ttid = %spg.ttid).entered(); - - let tli = GlobalTimelines::get(spg.ttid)?; - - // spawn the background thread which receives HotStandbyFeedback messages. - let bg_timeline = Arc::clone(&tli); - let bg_stream_in = self.stream_in.take().unwrap(); - let bg_timeline_id = spg.timeline_id.unwrap(); + start_pos: Lsn, + ) -> Result<(), CopyStreamHandlerEnd> { + let appname = self.appname.clone(); + let tli = GlobalTimelines::get(self.ttid)?; let state = ReplicaState::new(); // This replica_id is used below to check if it's time to stop replication. - let replica_id = bg_timeline.add_replica(state); + let replica_id = tli.add_replica(state); // Use a guard object to remove our entry from the timeline, when the background // thread and us have both finished using it. - let replica_guard = Arc::new(ReplicationConnGuard { + let _guard = Arc::new(ReplicationConnGuard { replica: replica_id, - timeline: bg_timeline, + timeline: tli.clone(), }); - let bg_replica_guard = Arc::clone(&replica_guard); - // TODO: here we got two threads, one for writing WAL and one for receiving - // feedback. If one of them fails, we should shutdown the other one too. - let _ = thread::Builder::new() - .name("HotStandbyFeedback thread".into()) - .spawn(move || { - let _enter = - info_span!("HotStandbyFeedback thread", timeline = %bg_timeline_id).entered(); - if let Err(err) = Self::background_thread(bg_stream_in, bg_replica_guard) { - error!("Replication background thread failed: {}", err); + // Walproposer gets special handling: safekeeper must give proposer all + // local WAL till the end, whether committed or not (walproposer will + // hang otherwise). That's because walproposer runs the consensus and + // synchronizes safekeepers on the most advanced one. + // + // There is a small risk of this WAL getting concurrently garbaged if + // another compute rises which collects majority and starts fixing log + // on this safekeeper itself. That's ok as (old) proposer will never be + // able to commit such WAL. + let stop_pos: Option = if self.is_walproposer_recovery() { + let wal_end = tli.get_flush_lsn(); + Some(wal_end) + } else { + None + }; + let end_pos = stop_pos.unwrap_or(Lsn::INVALID); + + info!( + "starting streaming from {:?} till {:?}", + start_pos, stop_pos + ); + + // switch to copy + pgb.write_message(&BeMessage::CopyBothResponse).await?; + + let (_, persisted_state) = tli.get_state(); + let wal_reader = WalReader::new( + self.conf.workdir.clone(), + self.conf.timeline_dir(&tli.ttid), + &persisted_state, + start_pos, + self.conf.wal_backup_enabled, + )?; + + // Split to concurrently receive and send data; replies are generally + // not synchronized with sends, so this avoids deadlocks. + let reader = pgb.split().context("START_REPLICATION split")?; + + let mut sender = WalSender { + pgb, + tli: tli.clone(), + appname, + start_pos, + end_pos, + stop_pos, + commit_lsn_watch_rx: tli.get_commit_lsn_watch_rx(), + replica_id, + wal_reader, + send_buf: [0; MAX_SEND_SIZE], + }; + let mut reply_reader = ReplyReader { + reader, + tli, + replica_id, + feedback: ReplicaState::new(), + }; + + let res = tokio::select! 
{ + // todo: add read|write .context to these errors + r = sender.run() => r, + r = reply_reader.run() => r, + }; + // Join pg backend back. + pgb.unsplit(reply_reader.reader)?; + + res + } +} + +/// A half driving sending WAL. +struct WalSender<'a> { + pgb: &'a mut PostgresBackend, + tli: Arc, + appname: Option, + // Position since which we are sending next chunk. + start_pos: Lsn, + // WAL up to this position is known to be locally available. + end_pos: Lsn, + // If present, terminate after reaching this position; used by walproposer + // in recovery. + stop_pos: Option, + commit_lsn_watch_rx: Receiver, + replica_id: usize, + wal_reader: WalReader, + // buffer for readling WAL into to send it + send_buf: [u8; MAX_SEND_SIZE], +} + +impl WalSender<'_> { + /// Send WAL until + /// - an error occurs + /// - if we are streaming to walproposer, we've streamed until stop_pos + /// (recovery finished) + /// - receiver is caughtup and there is no computes + /// + /// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ? + /// convenience. + async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> { + loop { + // If we are streaming to walproposer, check it is time to stop. + if let Some(stop_pos) = self.stop_pos { + if self.start_pos >= stop_pos { + // recovery finished + return Err(CopyStreamHandlerEnd::ServerInitiated(format!( + "ending streaming to walproposer at {}, recovery finished", + self.start_pos + ))); } - })?; - - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; - - runtime.block_on(async move { - let (inmem_state, persisted_state) = tli.get_state(); - // add persisted_state.timeline_start_lsn == Lsn(0) check - - // Walproposer gets special handling: safekeeper must give proposer all - // local WAL till the end, whether committed or not (walproposer will - // hang otherwise). That's because walproposer runs the consensus and - // synchronizes safekeepers on the most advanced one. - // - // There is a small risk of this WAL getting concurrently garbaged if - // another compute rises which collects majority and starts fixing log - // on this safekeeper itself. That's ok as (old) proposer will never be - // able to commit such WAL. - let stop_pos: Option = if spg.is_walproposer_recovery() { - let wal_end = tli.get_flush_lsn(); - Some(wal_end) } else { - None - }; + // Wait for the next portion if it is not there yet, or just + // update our end of WAL available for sending value, we + // communicate it to the receiver. + self.wait_wal().await?; + } - info!("Start replication from {:?} till {:?}", start_pos, stop_pos); + // try to send as much as available, capped by MAX_SEND_SIZE + let mut send_size = self + .end_pos + .checked_sub(self.start_pos) + .context("reading wal without waiting for it first")? 
+ .0 as usize; + send_size = min(send_size, self.send_buf.len()); + let send_buf = &mut self.send_buf[..send_size]; + // read wal into buffer + send_size = self.wal_reader.read(send_buf).await?; + let send_buf = &send_buf[..send_size]; - // switch to copy - pgb.write_message(&BeMessage::CopyBothResponse)?; - - let mut end_pos = stop_pos.unwrap_or(inmem_state.commit_lsn); - - let mut wal_reader = WalReader::new( - spg.conf.workdir.clone(), - spg.conf.timeline_dir(&tli.ttid), - &persisted_state, - start_pos, - spg.conf.wal_backup_enabled, - )?; - - // buffer for wal sending, limited by MAX_SEND_SIZE - let mut send_buf = vec![0u8; MAX_SEND_SIZE]; - - // watcher for commit_lsn updates - let mut commit_lsn_watch_rx = tli.get_commit_lsn_watch_rx(); - - loop { - if let Some(stop_pos) = stop_pos { - if start_pos >= stop_pos { - break; /* recovery finished */ - } - end_pos = stop_pos; - } else { - /* Wait until we have some data to stream */ - let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?; - - if let Some(lsn) = lsn { - end_pos = lsn; - } else { - // TODO: also check once in a while whether we are walsender - // to right pageserver. - if tli.should_walsender_stop(replica_id) { - // Shut down, timeline is suspended. - return Err(QueryError::from(io::Error::new( - io::ErrorKind::ConnectionAborted, - format!("end streaming to {:?}", spg.appname), - ))); - } - - // timeout expired: request pageserver status - pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive { - sent_ptr: end_pos.0, - timestamp: get_current_timestamp(), - request_reply: true, - }))?; - continue; - } - } - - let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize; - let send_size = min(send_size, send_buf.len()); - - let send_buf = &mut send_buf[..send_size]; - - // read wal into buffer - let send_size = wal_reader.read(send_buf).await?; - let send_buf = &send_buf[..send_size]; - - // Write some data to the network socket. - pgb.write_message(&BeMessage::XLogData(XLogDataBody { - wal_start: start_pos.0, - wal_end: end_pos.0, + // and send it + self.pgb + .write_message(&BeMessage::XLogData(XLogDataBody { + wal_start: self.start_pos.0, + wal_end: self.end_pos.0, timestamp: get_current_timestamp(), data: send_buf, })) - .context("Failed to send XLogData")?; + .await?; - start_pos += send_size as u64; - trace!("sent WAL up to {}", start_pos); + trace!( + "sent {} bytes of WAL {}-{}", + send_size, + self.start_pos, + self.start_pos + send_size as u64 + ); + self.start_pos += send_size as u64; + } + } + + /// wait until we have WAL to stream, sending keepalives and checking for + /// exit in the meanwhile + async fn wait_wal(&mut self) -> Result<(), CopyStreamHandlerEnd> { + loop { + if let Some(lsn) = wait_for_lsn(&mut self.commit_lsn_watch_rx, self.start_pos).await? { + self.end_pos = lsn; + return Ok(()); } + // Timed out waiting for WAL, check for termination and send KA + if self.tli.should_walsender_stop(self.replica_id) { + // Terminate if there is nothing more to send. + // TODO close the stream properly + return Err(CopyStreamHandlerEnd::ServerInitiated(format!( + "ending streaming to {:?} at {}, receiver is caughtup and there is no computes", + self.appname, self.start_pos, + ))); + } + self.pgb + .write_message(&BeMessage::KeepAlive(WalSndKeepAlive { + sent_ptr: self.end_pos.0, + timestamp: get_current_timestamp(), + request_reply: true, + })) + .await?; + } + } +} - Ok(()) - }) +/// A half driving receiving replies. 
+struct ReplyReader { + reader: PostgresBackendReader, + tli: Arc, + replica_id: usize, + feedback: ReplicaState, +} + +impl ReplyReader { + async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> { + loop { + let msg = self.reader.read_copy_message().await?; + self.handle_feedback(&msg)? + } + } + + fn handle_feedback(&mut self, msg: &Bytes) -> anyhow::Result<()> { + match msg.first().cloned() { + Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => { + // Note: deserializing is on m[1..] because we skip the tag byte. + self.feedback.hs_feedback = HotStandbyFeedback::des(&msg[1..]) + .context("failed to deserialize HotStandbyFeedback")?; + self.tli + .update_replica_state(self.replica_id, self.feedback); + } + Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => { + let _reply = + StandbyReply::des(&msg[1..]).context("failed to deserialize StandbyReply")?; + // This must be a regular postgres replica, + // because pageserver doesn't send this type of messages to safekeeper. + // Currently we just ignore this, tracking progress for them is not supported. + } + Some(NEON_STATUS_UPDATE_TAG_BYTE) => { + // pageserver sends this. + // Note: deserializing is on m[9..] because we skip the tag byte and len bytes. + let buf = Bytes::copy_from_slice(&msg[9..]); + let reply = ReplicationFeedback::parse(buf); + + trace!("ReplicationFeedback is {:?}", reply); + // Only pageserver sends ReplicationFeedback, so set the flag. + // This replica is the source of information to resend to compute. + self.feedback.pageserver_feedback = Some(reply); + + self.tli + .update_replica_state(self.replica_id, self.feedback); + } + _ => warn!("unexpected message {:?}", msg), + } + Ok(()) } } const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1); -// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn. +/// Wait until we have commit_lsn > lsn or timeout expires. Returns +/// - Ok(Some(commit_lsn)) if needed lsn is successfully observed; +/// - Ok(None) if timeout expired; +/// - Err in case of error (if watch channel is in trouble, shouldn't happen). async fn wait_for_lsn(rx: &mut Receiver, lsn: Lsn) -> anyhow::Result> { let commit_lsn: Lsn = *rx.borrow(); if commit_lsn > lsn { diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index 43c395574f..d5b849019e 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -1,4 +1,4 @@ -//! This module implements Timeline lifecycle management and has all neccessary code +//! This module implements Timeline lifecycle management and has all necessary code //! to glue together SafeKeeper and all other background services. use anyhow::{bail, Result}; @@ -518,7 +518,7 @@ impl Timeline { /// Register compute connection, starting timeline-related activity if it is /// not running yet. - pub fn on_compute_connect(&self) -> Result<()> { + pub async fn on_compute_connect(&self) -> Result<()> { if self.is_cancelled() { bail!(TimelineError::Cancelled(self.ttid)); } @@ -532,7 +532,7 @@ impl Timeline { // Wake up wal backup launcher, if offloading not started yet. if is_wal_backup_action_pending { // Can fail only if channel to a static thread got closed, which is not normal at all. - self.wal_backup_launcher_tx.blocking_send(self.ttid)?; + self.wal_backup_launcher_tx.send(self.ttid).await?; } Ok(()) } @@ -549,6 +549,11 @@ impl Timeline { // Wake up wal backup launcher, if it is time to stop the offloading. if is_wal_backup_action_pending { // Can fail only if channel to a static thread got closed, which is not normal at all. 
+ // + // Note: this is blocking_send because on_compute_disconnect is called in Drop, there is + // no async Drop and we use current thread runtimes. With current thread rt spawning + // task in drop impl is racy, as thread along with runtime might finish before the task. + // This should be switched send.await when/if we go to full async. self.wal_backup_launcher_tx.blocking_send(self.ttid)?; } Ok(()) diff --git a/safekeeper/src/timelines_global_map.rs b/safekeeper/src/timelines_global_map.rs index 66e0145042..fcf5ab6302 100644 --- a/safekeeper/src/timelines_global_map.rs +++ b/safekeeper/src/timelines_global_map.rs @@ -161,7 +161,7 @@ impl GlobalTimelines { /// Create a new timeline with the given id. If the timeline already exists, returns /// an existing timeline. - pub fn create( + pub async fn create( ttid: TenantTimelineId, server_info: ServerInfo, commit_lsn: Lsn, @@ -189,28 +189,20 @@ impl GlobalTimelines { // Take a lock and finish the initialization holding this mutex. No other threads // can interfere with creation after we will insert timeline into the map. - let mut shared_state = timeline.write_shared_state(); + { + let mut shared_state = timeline.write_shared_state(); - // We can get a race condition here in case of concurrent create calls, but only - // in theory. create() will return valid timeline on the next try. - TIMELINES_STATE - .lock() - .unwrap() - .try_insert(timeline.clone())?; + // We can get a race condition here in case of concurrent create calls, but only + // in theory. create() will return valid timeline on the next try. + TIMELINES_STATE + .lock() + .unwrap() + .try_insert(timeline.clone())?; - // Write the new timeline to the disk and start background workers. - // Bootstrap is transactional, so if it fails, the timeline will be deleted, - // and the state on disk should remain unchanged. - match timeline.bootstrap(&mut shared_state) { - Ok(_) => { - // We are done with bootstrap, release the lock, return the timeline. - drop(shared_state); - timeline - .wal_backup_launcher_tx - .blocking_send(timeline.ttid)?; - Ok(timeline) - } - Err(e) => { + // Write the new timeline to the disk and start background workers. + // Bootstrap is transactional, so if it fails, the timeline will be deleted, + // and the state on disk should remain unchanged. + if let Err(e) = timeline.bootstrap(&mut shared_state) { // Note: the most likely reason for bootstrap failure is that the timeline // directory already exists on disk. This happens when timeline is corrupted // and wasn't loaded from disk on startup because of that. We want to preserve @@ -222,9 +214,13 @@ impl GlobalTimelines { // Timeline failed to bootstrap, it cannot be used. Remove it from the map. TIMELINES_STATE.lock().unwrap().timelines.remove(&ttid); - Err(e) + return Err(e); } + // We are done with bootstrap, release the lock, return the timeline. + // {} block forces release before .await } + timeline.wal_backup_launcher_tx.send(timeline.ttid).await?; + Ok(timeline) } /// Get a timeline from the global map. If it's not present, it doesn't exist on disk, @@ -244,7 +240,7 @@ impl GlobalTimelines { } } - /// Returns all timelines. This is used for background timeline proccesses. + /// Returns all timelines. This is used for background timeline processes. 
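+    /// Note: takes the global map lock for the duration of the call.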
     pub fn get_all() -> Vec<Arc<Timeline>> {
         let global_lock = TIMELINES_STATE.lock().unwrap();
         global_lock
diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs
index fc971ca753..798b9abaf3 100644
--- a/safekeeper/src/wal_backup.rs
+++ b/safekeeper/src/wal_backup.rs
@@ -191,7 +191,7 @@ async fn wal_backup_launcher_main_loop(
             .map(|c| GenericRemoteStorage::from_config(c).expect("failed to create remote storage"))
     });
 
-    // Presense in this map means launcher is aware s3 offloading is needed for
+    // Presence in this map means launcher is aware s3 offloading is needed for
     // the timeline, but task is started only if it makes sense for to offload
     // from this safekeeper.
     let mut tasks: HashMap = HashMap::new();
@@ -467,7 +467,7 @@ async fn backup_object(source_file: &Path, target_file: &RemotePath, size: usize
 pub async fn read_object(
     file_path: &RemotePath,
     offset: u64,
-) -> anyhow::Result>> {
+) -> anyhow::Result>> {
     let storage = REMOTE_STORAGE
         .get()
         .context("Failed to get remote storage")?
diff --git a/safekeeper/src/wal_service.rs b/safekeeper/src/wal_service.rs
index 3ca651d060..8d63d604ad 100644
--- a/safekeeper/src/wal_service.rs
+++ b/safekeeper/src/wal_service.rs
@@ -2,50 +2,65 @@
 //! WAL service listens for client connections and
 //! receive WAL from wal_proposer and send it to WAL receivers
 //!
-use regex::Regex;
-use std::net::{TcpListener, TcpStream};
-use std::thread;
+use anyhow::{Context, Result};
+use nix::unistd::gettid;
+use postgres_backend::QueryError;
+use std::{future, thread};
+use tokio::net::TcpStream;
 use tracing::*;
-use utils::postgres_backend_async::QueryError;
 
 use crate::handler::SafekeeperPostgresHandler;
 use crate::SafeKeeperConf;
-use utils::postgres_backend::{AuthType, PostgresBackend};
+use postgres_backend::{AuthType, PostgresBackend};
 
 /// Accept incoming TCP connections and spawn them into a background thread.
-pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> ! {
-    loop {
-        match listener.accept() {
-            Ok((socket, peer_addr)) => {
-                debug!("accepted connection from {}", peer_addr);
-                let conf = conf.clone();
+pub fn thread_main(conf: SafeKeeperConf, pg_listener: std::net::TcpListener) {
+    let runtime = tokio::runtime::Builder::new_current_thread()
+        .enable_all()
+        .build()
+        .context("create runtime")
+        // todo catch error in main thread
+        .expect("failed to create runtime");
 
-                let _ = thread::Builder::new()
-                    .name("WAL service thread".into())
-                    .spawn(move || {
-                        if let Err(err) = handle_socket(socket, conf) {
-                            error!("connection handler exited: {}", err);
-                        }
-                    })
-                    .unwrap();
+    runtime
+        .block_on(async move {
+            // Tokio's from_std won't do this for us, per its comment.
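+            // (A std listener is created in blocking mode; tokio's
+            // TcpListener::from_std requires the socket to already be
+            // non-blocking, otherwise accept().await can stall the runtime.)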
+ pg_listener.set_nonblocking(true)?; + let listener = tokio::net::TcpListener::from_std(pg_listener)?; + + loop { + match listener.accept().await { + Ok((socket, peer_addr)) => { + debug!("accepted connection from {}", peer_addr); + let conf = conf.clone(); + + let _ = thread::Builder::new() + .name("WAL service thread".into()) + .spawn(move || { + if let Err(err) = handle_socket(socket, conf) { + error!("connection handler exited: {}", err); + } + }) + .unwrap(); + } + Err(e) => error!("Failed to accept connection: {}", e), + } } - Err(e) => error!("Failed to accept connection: {}", e), - } - } -} - -// Get unique thread id (Rust internal), with ThreadId removed for shorter printing -fn get_tid() -> u64 { - let tids = format!("{:?}", thread::current().id()); - let r = Regex::new(r"ThreadId\((\d+)\)").unwrap(); - let caps = r.captures(&tids).unwrap(); - caps.get(1).unwrap().as_str().parse().unwrap() + #[allow(unreachable_code)] // hint compiler the closure return type + Ok::<(), anyhow::Error>(()) + }) + .expect("listener failed") } /// This is run by `thread_main` above, inside a background thread. /// fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> { - let _enter = info_span!("", tid = ?get_tid()).entered(); + let _enter = info_span!("", tid = %gettid()).entered(); + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + let local = tokio::task::LocalSet::new(); socket.set_nodelay(true)?; @@ -54,9 +69,13 @@ fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryErr Some(_) => AuthType::NeonJWT, }; let mut conn_handler = SafekeeperPostgresHandler::new(conf); - let pgbackend = PostgresBackend::new(socket, auth_type, None, false)?; - // libpq replication protocol between safekeeper and replicas/pagers - pgbackend.run(&mut conn_handler)?; + let pgbackend = PostgresBackend::new(socket, auth_type, None)?; + // libpq protocol between safekeeper and walproposer / pageserver + // We don't use shutdown. + local.block_on( + &runtime, + pgbackend.run(&mut conn_handler, future::pending::<()>), + )?; Ok(()) } diff --git a/safekeeper/src/wal_storage.rs b/safekeeper/src/wal_storage.rs index 561104bd27..e83f72a3cf 100644 --- a/safekeeper/src/wal_storage.rs +++ b/safekeeper/src/wal_storage.rs @@ -461,7 +461,7 @@ pub struct WalReader { timeline_dir: PathBuf, wal_seg_size: usize, pos: Lsn, - wal_segment: Option>>, + wal_segment: Option>>, // S3 will be used to read WAL if LSN is not available locally enable_remote_read: bool, @@ -528,7 +528,7 @@ impl WalReader { } /// Open WAL segment at the current position of the reader. 
- async fn open_segment(&self) -> Result>> { + async fn open_segment(&self) -> Result>> { let xlogoff = self.pos.segment_offset(self.wal_seg_size); let segno = self.pos.segment_number(self.wal_seg_size); let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size); diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index c4b3d057f8..611361eaf1 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2049,8 +2049,10 @@ class NeonPageserver(PgProtocol): ".*Connection aborted: connection error: error communicating with the server: Broken pipe.*", ".*Connection aborted: connection error: error communicating with the server: Transport endpoint is not connected.*", ".*Connection aborted: connection error: error communicating with the server: Connection reset by peer.*", + # FIXME: replication patch for tokio_postgres regards any but CopyDone/CopyData message in CopyBoth stream as unexpected + ".*Connection aborted: connection error: unexpected message from server*", ".*kill_and_wait_impl.*: wait successful.*", - ".*Replication stream finished: db error: ERROR: Socket IO error: end streaming to Some.*", + ".*Replication stream finished: db error:.*ending streaming to Some*", ".*query handler for 'pagestream.*failed: Broken pipe.*", # pageserver notices compute shut down ".*query handler for 'pagestream.*failed: Connection reset by peer.*", # pageserver notices compute shut down # safekeeper connection can fail with this, in the window between timeline creation diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index 9e3b0ec02f..cea2125f4f 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -1106,8 +1106,8 @@ def test_delete_force(neon_env_builder: NeonEnvBuilder, auth_enabled: bool): # FIXME: are these expected? env.pageserver.allowed_errors.extend( [ - ".*Failed to process query for timeline .*: Timeline .* was not found in global map.*", - ".*Failed to process query for timeline .*: Timeline .* was cancelled and cannot be used anymore.*", + ".*Timeline .* was not found in global map.*", + ".*Timeline .* was cancelled and cannot be used anymore.*", ] ) diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index bd21095fff..28e8e4149c 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -14,14 +14,19 @@ publish = false ### BEGIN HAKARI SECTION [dependencies] anyhow = { version = "1", features = ["backtrace"] } +byteorder = { version = "1" } bytes = { version = "1", features = ["serde"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde"] } clap = { version = "4", features = ["derive", "string"] } crossbeam-utils = { version = "0.8" } +digest = { version = "0.10", features = ["mac", "std"] } either = { version = "1" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } futures = { version = "0.3" } +futures-channel = { version = "0.3", features = ["sink"] } +futures-core = { version = "0.3" } futures-executor = { version = "0.3" } +futures-sink = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } hashbrown = { version = "0.12", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] }