From 06b97c60f3f4318f4eb0ca6410e4f0e8d34a7a25 Mon Sep 17 00:00:00 2001 From: Arseny Sher Date: Thu, 2 Feb 2023 12:03:45 +0400 Subject: [PATCH] Remove sync postgres_backend_async, tidy up its split usage. - Use postgres_backend_async throughout safekeeper. - Use framed.rs in postgres_backend_async, similar to tokio_util::codec::Framed but with slightly more efficient split. Now IO functions are also cancellation safe. Also, there is no allocation in each message read anymore, and data in read messages still directly point to input buffer without copies. - In both safekeeper COPY streams, do read-write from the same thread/task with select! for easier error handling. - Tidy up finishing CopyBoth streams in safekeeper sending and receiving WAL -- join split parts back catching errors from them before returning. Initially I hoped to do that read-write without split at all, through polling IO: https://github.com/neondatabase/neon/pull/3522 However that turned out to be more complicated than I initially expected due to 1) borrow checking and 2) anon Future types. 1) required Rc> which is Send construct just to satisfy the checker; 2) can be workaround with transmute. But this is so messy that I decided to leave split. proxy stream.rs is adapted with minimal changes. It also benefits from framed.rs improvements described above. --- Cargo.lock | 46 +- Cargo.toml | 1 + control_plane/Cargo.toml | 1 + control_plane/src/bin/neon_local.rs | 2 +- control_plane/src/compute.rs | 2 +- control_plane/src/local_env.rs | 2 +- control_plane/src/pageserver.rs | 2 +- libs/postgres_backend/Cargo.toml | 26 + libs/postgres_backend/src/lib.rs | 911 ++++++++++++++++++ .../tests/cert.pem | 0 .../{utils => postgres_backend}/tests/key.pem | 0 libs/postgres_backend/tests/simple_select.rs | 139 +++ libs/pq_proto/Cargo.toml | 2 +- libs/pq_proto/src/framed.rs | 251 +++++ libs/pq_proto/src/lib.rs | 505 +++++----- libs/pq_proto/src/sync.rs | 179 ---- libs/remote_storage/src/lib.rs | 2 +- libs/utils/Cargo.toml | 24 +- libs/utils/src/lib.rs | 5 - libs/utils/src/postgres_backend.rs | 485 ---------- libs/utils/src/postgres_backend_async.rs | 636 ------------ libs/utils/src/sock_split.rs | 206 ---- libs/utils/tests/ssl_test.rs | 238 ----- pageserver/Cargo.toml | 1 + pageserver/src/bin/pageserver.rs | 5 +- pageserver/src/config.rs | 2 +- pageserver/src/page_service.rs | 30 +- .../walreceiver/walreceiver_connection.rs | 9 +- proxy/Cargo.toml | 1 + proxy/src/console/mgmt.rs | 43 +- proxy/src/stream.rs | 61 +- run_clippy.sh | 10 +- safekeeper/Cargo.toml | 1 + safekeeper/src/bin/safekeeper.rs | 2 +- safekeeper/src/handler.rs | 62 +- safekeeper/src/http/routes.rs | 10 +- safekeeper/src/json_ctrl.rs | 41 +- safekeeper/src/lib.rs | 3 +- safekeeper/src/receive_wal.rs | 421 ++++---- safekeeper/src/safekeeper.rs | 2 +- safekeeper/src/send_wal.rs | 453 +++++---- safekeeper/src/timeline.rs | 11 +- safekeeper/src/timelines_global_map.rs | 42 +- safekeeper/src/wal_backup.rs | 4 +- safekeeper/src/wal_service.rs | 87 +- safekeeper/src/wal_storage.rs | 4 +- test_runner/fixtures/neon_fixtures.py | 4 +- test_runner/regress/test_wal_acceptor.py | 4 +- workspace_hack/Cargo.toml | 5 + 49 files changed, 2345 insertions(+), 2638 deletions(-) create mode 100644 libs/postgres_backend/Cargo.toml create mode 100644 libs/postgres_backend/src/lib.rs rename libs/{utils => postgres_backend}/tests/cert.pem (100%) rename libs/{utils => postgres_backend}/tests/key.pem (100%) create mode 100644 libs/postgres_backend/tests/simple_select.rs create mode 100644 libs/pq_proto/src/framed.rs delete mode 100644 libs/pq_proto/src/sync.rs delete mode 100644 libs/utils/src/postgres_backend.rs delete mode 100644 libs/utils/src/postgres_backend_async.rs delete mode 100644 libs/utils/src/sock_split.rs delete mode 100644 libs/utils/tests/ssl_test.rs diff --git a/Cargo.lock b/Cargo.lock index 02b03e02fb..0a505573e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -913,6 +913,7 @@ dependencies = [ "once_cell", "pageserver_api", "postgres", + "postgres_backend", "postgres_connection", "regex", "reqwest", @@ -2454,6 +2455,7 @@ dependencies = [ "postgres", "postgres-protocol", "postgres-types", + "postgres_backend", "postgres_connection", "postgres_ffi", "pq_proto", @@ -2676,6 +2678,28 @@ dependencies = [ "postgres-protocol", ] +[[package]] +name = "postgres_backend" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "bytes", + "futures", + "once_cell", + "pq_proto", + "rustls", + "rustls-pemfile", + "serde", + "thiserror", + "tokio", + "tokio-postgres", + "tokio-postgres-rustls", + "tokio-rustls", + "tracing", + "workspace_hack", +] + [[package]] name = "postgres_connection" version = "0.1.0" @@ -2723,7 +2747,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" name = "pq_proto" version = "0.1.0" dependencies = [ - "anyhow", + "byteorder", "bytes", "pin-project-lite", "postgres-protocol", @@ -2898,6 +2922,7 @@ dependencies = [ "opentelemetry", "parking_lot", "pin-project-lite", + "postgres_backend", "pq_proto", "prometheus", "rand", @@ -3277,15 +3302,6 @@ dependencies = [ "base64 0.21.0", ] -[[package]] -name = "rustls-split" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78802c9612b4689d207acff746f38132ca1b12dadb55d471aa5f10fd580f47d3" -dependencies = [ - "rustls", -] - [[package]] name = "rustversion" version = "1.0.11" @@ -3321,6 +3337,7 @@ dependencies = [ "parking_lot", "postgres", "postgres-protocol", + "postgres_backend", "postgres_ffi", "pq_proto", "regex", @@ -4506,7 +4523,6 @@ dependencies = [ "bytes", "criterion", "futures", - "git-version", "heapless", "hex", "hex-literal", @@ -4515,12 +4531,9 @@ dependencies = [ "metrics", "nix", "once_cell", - "pq_proto", "rand", "routerify", "rustls", - "rustls-pemfile", - "rustls-split", "sentry", "serde", "serde_json", @@ -4835,14 +4848,19 @@ name = "workspace_hack" version = "0.1.0" dependencies = [ "anyhow", + "byteorder", "bytes", "chrono", "clap 4.1.4", "crossbeam-utils", + "digest", "either", "fail", "futures", + "futures-channel", + "futures-core", "futures-executor", + "futures-sink", "futures-util", "hashbrown 0.12.3", "indexmap", diff --git a/Cargo.toml b/Cargo.toml index ea22b04124..bbd4975603 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -133,6 +133,7 @@ heapless = { default-features=false, features=[], git = "https://github.com/japa consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" } metrics = { version = "0.1", path = "./libs/metrics/" } pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" } +postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" } postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" } postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" } pq_proto = { version = "0.1", path = "./libs/pq_proto/" } diff --git a/control_plane/Cargo.toml b/control_plane/Cargo.toml index 309887e1fa..ba39747e03 100644 --- a/control_plane/Cargo.toml +++ b/control_plane/Cargo.toml @@ -24,6 +24,7 @@ url.workspace = true # Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api # instead, so that recompile times are better. pageserver_api.workspace = true +postgres_backend.workspace = true safekeeper_api.workspace = true postgres_connection.workspace = true storage_broker.workspace = true diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 4b2aa3c957..49b1d31dbc 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -17,6 +17,7 @@ use pageserver_api::{ DEFAULT_HTTP_LISTEN_ADDR as DEFAULT_PAGESERVER_HTTP_ADDR, DEFAULT_PG_LISTEN_ADDR as DEFAULT_PAGESERVER_PG_ADDR, }; +use postgres_backend::AuthType; use safekeeper_api::{ DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT, DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT, @@ -30,7 +31,6 @@ use utils::{ auth::{Claims, Scope}, id::{NodeId, TenantId, TenantTimelineId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, project_git_version, }; diff --git a/control_plane/src/compute.rs b/control_plane/src/compute.rs index 8731cf2583..b7029aabc5 100644 --- a/control_plane/src/compute.rs +++ b/control_plane/src/compute.rs @@ -11,10 +11,10 @@ use std::sync::Arc; use std::time::Duration; use anyhow::{Context, Result}; +use postgres_backend::AuthType; use utils::{ id::{TenantId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, }; use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION}; diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index 003152c578..09180d96c4 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -5,6 +5,7 @@ use anyhow::{bail, ensure, Context}; +use postgres_backend::AuthType; use reqwest::Url; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; @@ -19,7 +20,6 @@ use std::process::{Command, Stdio}; use utils::{ auth::{encode_from_key_file, Claims, Scope}, id::{NodeId, TenantId, TenantTimelineId, TimelineId}, - postgres_backend::AuthType, }; use crate::safekeeper::SafekeeperNode; diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index c49bd39f09..4b7180c250 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -11,6 +11,7 @@ use anyhow::{bail, Context}; use pageserver_api::models::{ TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo, }; +use postgres_backend::AuthType; use postgres_connection::{parse_host_port, PgConnectionConfig}; use reqwest::blocking::{Client, RequestBuilder, Response}; use reqwest::{IntoUrl, Method}; @@ -20,7 +21,6 @@ use utils::{ http::error::HttpErrorBody, id::{TenantId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, }; use crate::{background_process, local_env::LocalEnv}; diff --git a/libs/postgres_backend/Cargo.toml b/libs/postgres_backend/Cargo.toml new file mode 100644 index 0000000000..8e249c09f7 --- /dev/null +++ b/libs/postgres_backend/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "postgres_backend" +version = "0.1.0" +edition.workspace = true +license.workspace = true + +[dependencies] +async-trait.workspace = true +anyhow.workspace = true +bytes.workspace = true +futures.workspace = true +rustls.workspace = true +serde.workspace = true +thiserror.workspace = true +tokio.workspace = true +tokio-rustls.workspace = true +tracing.workspace = true + +pq_proto.workspace = true +workspace_hack.workspace = true + +[dev-dependencies] +once_cell.workspace = true +rustls-pemfile.workspace = true +tokio-postgres.workspace = true +tokio-postgres-rustls.workspace = true \ No newline at end of file diff --git a/libs/postgres_backend/src/lib.rs b/libs/postgres_backend/src/lib.rs new file mode 100644 index 0000000000..15a9fc6dd0 --- /dev/null +++ b/libs/postgres_backend/src/lib.rs @@ -0,0 +1,911 @@ +//! Server-side asynchronous Postgres connection, as limited as we need. +//! To use, create PostgresBackend and run() it, passing the Handler +//! implementation determining how to process the queries. Currently its API +//! is rather narrow, but we can extend it once required. +use anyhow::Context; +use bytes::Bytes; +use futures::pin_mut; +use serde::{Deserialize, Serialize}; +use std::io::ErrorKind; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{ready, Poll}; +use std::{fmt, io}; +use std::{future::Future, str::FromStr}; +use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; +use tokio_rustls::TlsAcceptor; + +use tracing::{debug, error, info, trace}; + +use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter}; +use pq_proto::{ + BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR, + SQLSTATE_SUCCESSFUL_COMPLETION, +}; + +/// An error, occurred during query processing: +/// either during the connection ([`ConnectionError`]) or before/after it. +#[derive(thiserror::Error, Debug)] +pub enum QueryError { + /// The connection was lost while processing the query. + #[error(transparent)] + Disconnected(#[from] ConnectionError), + /// Some other error + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl From for QueryError { + fn from(e: io::Error) -> Self { + Self::Disconnected(ConnectionError::Io(e)) + } +} + +impl QueryError { + pub fn pg_error_code(&self) -> &'static [u8; 5] { + match self { + Self::Disconnected(_) => b"08006", // connection failure + Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error + } + } +} + +pub fn is_expected_io_error(e: &io::Error) -> bool { + use io::ErrorKind::*; + matches!( + e.kind(), + ConnectionRefused | ConnectionAborted | ConnectionReset + ) +} + +#[async_trait::async_trait] +pub trait Handler { + /// Handle single query. + /// postgres_backend will issue ReadyForQuery after calling this (this + /// might be not what we want after CopyData streaming, but currently we don't + /// care). It will also flush out the output buffer. + async fn process_query( + &mut self, + pgb: &mut PostgresBackend, + query_string: &str, + ) -> Result<(), QueryError>; + + /// Called on startup packet receival, allows to process params. + /// + /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users + /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow + /// to override whole init logic in implementations. + fn startup( + &mut self, + _pgb: &mut PostgresBackend, + _sm: &FeStartupPacket, + ) -> Result<(), QueryError> { + Ok(()) + } + + /// Check auth jwt + fn check_auth_jwt( + &mut self, + _pgb: &mut PostgresBackend, + _jwt_response: &[u8], + ) -> Result<(), QueryError> { + Err(QueryError::Other(anyhow::anyhow!("JWT auth failed"))) + } +} + +/// PostgresBackend protocol state. +/// XXX: The order of the constructors matters. +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum ProtoState { + /// Nothing happened yet. + Initialization, + /// Encryption handshake is done; waiting for encrypted Startup message. + Encrypted, + /// Waiting for password (auth token). + Authentication, + /// Performed handshake and auth, ReadyForQuery is issued. + Established, + Closed, +} + +#[derive(Clone, Copy)] +pub enum ProcessMsgResult { + Continue, + Break, +} + +/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite. +pub enum MaybeTlsStream { + Unencrypted(tokio::net::TcpStream), + Tls(Box>), +} + +impl AsyncWrite for MaybeTlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf), + Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf), + } + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + match self.get_mut() { + Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx), + Self::Tls(stream) => Pin::new(stream).poll_flush(cx), + } + } + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match self.get_mut() { + Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx), + Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} +impl AsyncRead for MaybeTlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf), + Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] +pub enum AuthType { + Trust, + // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT + NeonJWT, +} + +impl FromStr for AuthType { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s { + "Trust" => Ok(Self::Trust), + "NeonJWT" => Ok(Self::NeonJWT), + _ => anyhow::bail!("invalid value \"{s}\" for auth type"), + } + } +} + +impl fmt::Display for AuthType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + AuthType::Trust => "Trust", + AuthType::NeonJWT => "NeonJWT", + }) + } +} + +/// Either full duplex Framed or write only half; the latter is left in +/// PostgresBackend after call to `split`. In principle we could always store a +/// pair of splitted handles, but that would force to to pay splitting price +/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver). +enum MaybeWriteOnly { + Full(Framed), + WriteOnly(FramedWriter>), + Broken, // temporary value palmed off during the split +} + +impl MaybeWriteOnly { + async fn read_startup_message(&mut self) -> Result, ConnectionError> { + match self { + MaybeWriteOnly::Full(framed) => framed.read_startup_message().await, + MaybeWriteOnly::WriteOnly(_) => { + Err(io::Error::new(ErrorKind::Other, "reading from write only half").into()) + } + MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"), + } + } + + async fn read_message(&mut self) -> Result, ConnectionError> { + match self { + MaybeWriteOnly::Full(framed) => framed.read_message().await, + MaybeWriteOnly::WriteOnly(_) => { + Err(io::Error::new(ErrorKind::Other, "reading from write only half").into()) + } + MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"), + } + } + + fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> { + match self { + MaybeWriteOnly::Full(framed) => framed.write_message(msg), + MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg), + MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"), + } + } + + async fn flush(&mut self) -> io::Result<()> { + match self { + MaybeWriteOnly::Full(framed) => framed.flush().await, + MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await, + MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"), + } + } + + async fn shutdown(&mut self) -> io::Result<()> { + match self { + MaybeWriteOnly::Full(framed) => framed.shutdown().await, + MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await, + MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"), + } + } +} + +pub struct PostgresBackend { + framed: MaybeWriteOnly, + + pub state: ProtoState, + + auth_type: AuthType, + + peer_addr: SocketAddr, + pub tls_config: Option>, +} + +pub fn query_from_cstring(query_string: Bytes) -> Vec { + let mut query_string = query_string.to_vec(); + if let Some(ch) = query_string.last() { + if *ch == 0 { + query_string.pop(); + } + } + query_string +} + +/// Cast a byte slice to a string slice, dropping null terminator if there's one. +fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> { + let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes); + std::str::from_utf8(without_null).map_err(|e| e.into()) +} + +impl PostgresBackend { + pub fn new( + socket: tokio::net::TcpStream, + auth_type: AuthType, + tls_config: Option>, + ) -> io::Result { + let peer_addr = socket.peer_addr()?; + let stream = MaybeTlsStream::Unencrypted(socket); + + Ok(Self { + framed: MaybeWriteOnly::Full(Framed::new(stream)), + state: ProtoState::Initialization, + auth_type, + tls_config, + peer_addr, + }) + } + + pub fn get_peer_addr(&self) -> &SocketAddr { + &self.peer_addr + } + + /// Read full message or return None if connection is cleanly closed with no + /// unprocessed data. + pub async fn read_message(&mut self) -> Result, ConnectionError> { + if let ProtoState::Closed = self.state { + Ok(None) + } else { + let m = self.framed.read_message().await?; + trace!("read msg {:?}", m); + Ok(m) + } + } + + /// Write message into internal output buffer, doesn't flush it. Technically + /// error type can be only ProtocolError here (if, unlikely, serialization + /// fails), but callers typically wrap it anyway. + pub fn write_message_noflush( + &mut self, + message: &BeMessage<'_>, + ) -> Result<&mut Self, ConnectionError> { + self.framed.write_message_noflush(message)?; + trace!("wrote msg {:?}", message); + Ok(self) + } + + /// Flush output buffer into the socket. + pub async fn flush(&mut self) -> io::Result<()> { + self.framed.flush().await + } + + /// Polling version of `flush()`, saves the caller need to pin. + pub fn poll_flush( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let flush_fut = self.flush(); + pin_mut!(flush_fut); + flush_fut.poll(cx) + } + + /// Write message into internal output buffer and flush it to the stream. + pub async fn write_message( + &mut self, + message: &BeMessage<'_>, + ) -> Result<&mut Self, ConnectionError> { + self.write_message_noflush(message)?; + self.flush().await?; + Ok(self) + } + + /// Returns an AsyncWrite implementation that wraps all the data written + /// to it in CopyData messages, and writes them to the connection + /// + /// The caller is responsible for sending CopyOutResponse and CopyDone messages. + pub fn copyout_writer(&mut self) -> CopyDataWriter { + CopyDataWriter { pgb: self } + } + + /// Wrapper for run_message_loop() that shuts down socket when we are done + pub async fn run( + mut self, + handler: &mut impl Handler, + shutdown_watcher: F, + ) -> Result<(), QueryError> + where + F: Fn() -> S, + S: Future, + { + let ret = self.run_message_loop(handler, shutdown_watcher).await; + // socket might be already closed, e.g. if previously received error, + // so ignore result. + self.framed.shutdown().await.ok(); + ret + } + + async fn run_message_loop( + &mut self, + handler: &mut impl Handler, + shutdown_watcher: F, + ) -> Result<(), QueryError> + where + F: Fn() -> S, + S: Future, + { + trace!("postgres backend to {:?} started", self.peer_addr); + + tokio::select!( + biased; + + _ = shutdown_watcher() => { + // We were requested to shut down. + tracing::info!("shutdown request received during handshake"); + return Ok(()) + }, + + result = self.handshake(handler) => { + // Handshake complete. + result?; + if self.state == ProtoState::Closed { + return Ok(()); // EOF during handshake + } + } + ); + + // Authentication completed + let mut query_string = Bytes::new(); + while let Some(msg) = tokio::select!( + biased; + _ = shutdown_watcher() => { + // We were requested to shut down. + tracing::info!("shutdown request received in run_message_loop"); + Ok(None) + }, + msg = self.read_message() => { msg }, + )? { + trace!("got message {:?}", msg); + + let result = self.process_message(handler, msg, &mut query_string).await; + self.flush().await?; + match result? { + ProcessMsgResult::Continue => { + self.flush().await?; + continue; + } + ProcessMsgResult::Break => break, + } + } + + trace!("postgres backend to {:?} exited", self.peer_addr); + Ok(()) + } + + /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake. + async fn tls_upgrade( + src: MaybeTlsStream, + tls_config: Arc, + ) -> anyhow::Result { + match src { + MaybeTlsStream::Unencrypted(s) => { + let acceptor = TlsAcceptor::from(tls_config); + let tls_stream = acceptor.accept(s).await?; + Ok(MaybeTlsStream::Tls(Box::new(tls_stream))) + } + MaybeTlsStream::Tls(_) => { + anyhow::bail!("TLS already started"); + } + } + } + + async fn start_tls(&mut self) -> anyhow::Result<()> { + // temporary replace stream with fake to cook TLS one, Indiana Jones style + match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) { + MaybeWriteOnly::Full(framed) => { + let tls_config = self + .tls_config + .as_ref() + .context("start_tls called without conf")? + .clone(); + let tls_framed = framed + .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config)) + .await?; + // push back ready TLS stream + self.framed = MaybeWriteOnly::Full(tls_framed); + Ok(()) + } + MaybeWriteOnly::WriteOnly(_) => { + anyhow::bail!("TLS upgrade attempt in split state") + } + MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"), + } + } + + /// Split off owned read part from which messages can be read in different + /// task/thread. + pub fn split(&mut self) -> anyhow::Result { + // temporary replace stream with fake to cook split one, Indiana Jones style + match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) { + MaybeWriteOnly::Full(framed) => { + let (reader, writer) = framed.split(); + self.framed = MaybeWriteOnly::WriteOnly(writer); + Ok(PostgresBackendReader(reader)) + } + MaybeWriteOnly::WriteOnly(_) => { + anyhow::bail!("PostgresBackend is already split") + } + MaybeWriteOnly::Broken => panic!("split on framed in invalid state"), + } + } + + /// Join read part back. + pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> { + // temporary replace stream with fake to cook joined one, Indiana Jones style + match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) { + MaybeWriteOnly::Full(_) => { + anyhow::bail!("PostgresBackend is not split") + } + MaybeWriteOnly::WriteOnly(writer) => { + let joined = Framed::unsplit(reader.0, writer); + self.framed = MaybeWriteOnly::Full(joined); + Ok(()) + } + MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"), + } + } + + /// Perform handshake with the client, transitioning to Established. + /// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()). + async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> { + while self.state < ProtoState::Authentication { + match self.framed.read_startup_message().await? { + Some(msg) => { + self.process_startup_message(handler, msg).await?; + } + None => { + trace!( + "postgres backend to {:?} received EOF during handshake", + self.peer_addr + ); + self.state = ProtoState::Closed; + return Ok(()); + } + } + } + + // Perform auth, if needed. + if self.state == ProtoState::Authentication { + match self.framed.read_message().await? { + Some(FeMessage::PasswordMessage(m)) => { + assert!(self.auth_type == AuthType::NeonJWT); + + let (_, jwt_response) = m.split_last().context("protocol violation")?; + + if let Err(e) = handler.check_auth_jwt(self, jwt_response) { + self.write_message_noflush(&BeMessage::ErrorResponse( + &e.to_string(), + Some(e.pg_error_code()), + ))?; + return Err(e); + } + + self.write_message_noflush(&BeMessage::AuthenticationOk)? + .write_message_noflush(&BeMessage::CLIENT_ENCODING)? + .write_message(&BeMessage::ReadyForQuery) + .await?; + self.state = ProtoState::Established; + } + Some(m) => { + return Err(QueryError::Other(anyhow::anyhow!( + "Unexpected message {:?} while waiting for handshake", + m + ))); + } + None => { + trace!( + "postgres backend to {:?} received EOF during auth", + self.peer_addr + ); + self.state = ProtoState::Closed; + return Ok(()); + } + } + } + + Ok(()) + } + + /// Process startup packet: + /// - transition to Established if auth type is trust + /// - transition to Authentication if auth type is NeonJWT. + /// - or perform TLS handshake -- then need to call this again to receive + /// actual startup packet. + async fn process_startup_message( + &mut self, + handler: &mut impl Handler, + msg: FeStartupPacket, + ) -> Result<(), QueryError> { + assert!(self.state < ProtoState::Authentication); + let have_tls = self.tls_config.is_some(); + match msg { + FeStartupPacket::SslRequest => { + debug!("SSL requested"); + + self.write_message(&BeMessage::EncryptionResponse(have_tls)) + .await?; + + if have_tls { + self.start_tls().await?; + self.state = ProtoState::Encrypted; + } + } + FeStartupPacket::GssEncRequest => { + debug!("GSS requested"); + self.write_message(&BeMessage::EncryptionResponse(false)) + .await?; + } + FeStartupPacket::StartupMessage { .. } => { + if have_tls && !matches!(self.state, ProtoState::Encrypted) { + self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None)) + .await?; + return Err(QueryError::Other(anyhow::anyhow!( + "client did not connect with TLS" + ))); + } + + // NB: startup() may change self.auth_type -- we are using that in proxy code + // to bypass auth for new users. + handler.startup(self, &msg)?; + + match self.auth_type { + AuthType::Trust => { + self.write_message_noflush(&BeMessage::AuthenticationOk)? + .write_message_noflush(&BeMessage::CLIENT_ENCODING)? + .write_message_noflush(&BeMessage::INTEGER_DATETIMES)? + // The async python driver requires a valid server_version + .write_message_noflush(&BeMessage::server_version("14.1"))? + .write_message(&BeMessage::ReadyForQuery) + .await?; + self.state = ProtoState::Established; + } + AuthType::NeonJWT => { + self.write_message(&BeMessage::AuthenticationCleartextPassword) + .await?; + self.state = ProtoState::Authentication; + } + } + } + FeStartupPacket::CancelRequest { .. } => { + return Err(QueryError::Other(anyhow::anyhow!( + "Unexpected CancelRequest message during handshake" + ))); + } + } + Ok(()) + } + + async fn process_message( + &mut self, + handler: &mut impl Handler, + msg: FeMessage, + unnamed_query_string: &mut Bytes, + ) -> Result { + // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth + // TODO: change that to proper top-level match of protocol state with separate message handling for each state + assert!(self.state == ProtoState::Established); + + match msg { + FeMessage::Query(body) => { + // remove null terminator + let query_string = cstr_to_str(&body)?; + + trace!("got query {query_string:?}"); + if let Err(e) = handler.process_query(self, query_string).await { + log_query_error(query_string, &e); + let short_error = short_error(&e); + self.write_message_noflush(&BeMessage::ErrorResponse( + &short_error, + Some(e.pg_error_code()), + ))?; + } + self.write_message_noflush(&BeMessage::ReadyForQuery)?; + } + + FeMessage::Parse(m) => { + *unnamed_query_string = m.query_string; + self.write_message_noflush(&BeMessage::ParseComplete)?; + } + + FeMessage::Describe(_) => { + self.write_message_noflush(&BeMessage::ParameterDescription)? + .write_message_noflush(&BeMessage::NoData)?; + } + + FeMessage::Bind(_) => { + self.write_message_noflush(&BeMessage::BindComplete)?; + } + + FeMessage::Close(_) => { + self.write_message_noflush(&BeMessage::CloseComplete)?; + } + + FeMessage::Execute(_) => { + let query_string = cstr_to_str(unnamed_query_string)?; + trace!("got execute {query_string:?}"); + if let Err(e) = handler.process_query(self, query_string).await { + log_query_error(query_string, &e); + self.write_message_noflush(&BeMessage::ErrorResponse( + &e.to_string(), + Some(e.pg_error_code()), + ))?; + } + // NOTE there is no ReadyForQuery message. This handler is used + // for basebackup and it uses CopyOut which doesn't require + // ReadyForQuery message and backend just switches back to + // processing mode after sending CopyDone or ErrorResponse. + } + + FeMessage::Sync => { + self.write_message_noflush(&BeMessage::ReadyForQuery)?; + } + + FeMessage::Terminate => { + return Ok(ProcessMsgResult::Break); + } + + // We prefer explicit pattern matching to wildcards, because + // this helps us spot the places where new variants are missing + FeMessage::CopyData(_) + | FeMessage::CopyDone + | FeMessage::CopyFail + | FeMessage::PasswordMessage(_) => { + return Err(QueryError::Other(anyhow::anyhow!( + "unexpected message type: {msg:?}", + ))); + } + } + + Ok(ProcessMsgResult::Continue) + } + + /// Log as info/error result of handling COPY stream and send back + /// ErrorResponse if that makes sense. Shutdown the stream if we got + /// Terminate. TODO: transition into waiting for Sync msg if we initiate the + /// close. + pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) { + use CopyStreamHandlerEnd::*; + + let expected_end = match &end { + ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true, + CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error)) + if is_expected_io_error(io_error) => + { + true + } + _ => false, + }; + if expected_end { + info!("terminated: {:#}", end); + } else { + error!("terminated: {:?}", end); + } + + // Note: no current usages ever send this + if let CopyDone = &end { + if let Err(e) = self.write_message(&BeMessage::CopyDone).await { + error!("failed to send CopyDone: {}", e); + } + } + + if let Terminate = &end { + self.state = ProtoState::Closed; + } + + let err_to_send_and_errcode = match &end { + ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)), + Other(_) => Some((end.to_string(), SQLSTATE_INTERNAL_ERROR)), + // Note: CopyFail in duplex copy is somewhat unexpected (at least to + // PG walsender; evidently and per my docs reading client should + // finish it with CopyDone). It is not a problem to recover from it + // finishing the stream in both directions like we do, but note that + // sync rust-postgres client (which we don't use anymore) hangs if + // socket is not closed here. + // https://github.com/sfackler/rust-postgres/issues/755 + // https://github.com/neondatabase/neon/issues/935 + // + // Currently, the version of tokio_postgres replication patch we use + // sends this when it closes the stream (e.g. pageserver decided to + // switch conn to another safekeeper and client gets dropped). + // Moreover, seems like 'connection' task errors with 'unexpected + // message from server' when it receives ErrorResponse (anything but + // CopyData/CopyDone) back. + CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)), + _ => None, + }; + if let Some((err, errcode)) = err_to_send_and_errcode { + if let Err(ee) = self + .write_message(&BeMessage::ErrorResponse(&err, Some(errcode))) + .await + { + error!("failed to send ErrorResponse: {}", ee); + } + } + } +} + +pub struct PostgresBackendReader(FramedReader>); + +impl PostgresBackendReader { + /// Read full message or return None if connection is cleanly closed with no + /// unprocessed data. + pub async fn read_message(&mut self) -> Result, ConnectionError> { + let m = self.0.read_message().await?; + trace!("read msg {:?}", m); + Ok(m) + } + + /// Get CopyData contents of the next message in COPY stream or error + /// closing it. The error type is wider than actual errors which can happen + /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for + /// current callers. + pub async fn read_copy_message(&mut self) -> Result { + match self.read_message().await? { + Some(msg) => match msg { + FeMessage::CopyData(m) => Ok(m), + FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone), + FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail), + FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate), + _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol( + ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)), + ))), + }, + None => Err(CopyStreamHandlerEnd::EOF), + } + } +} + +/// +/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData +/// messages. +/// + +pub struct CopyDataWriter<'a> { + pgb: &'a mut PostgresBackend, +} + +impl<'a> AsyncWrite for CopyDataWriter<'a> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + + // It's not strictly required to flush between each message, but makes it easier + // to view in wireshark, and usually the messages that the callers write are + // decently-sized anyway. + if let Err(err) = ready!(this.pgb.poll_flush(cx)) { + return Poll::Ready(Err(err)); + } + + // CopyData + // XXX: if the input is large, we should split it into multiple messages. + // Not sure what the threshold should be, but the ultimate hard limit is that + // the length cannot exceed u32. + this.pgb + .write_message_noflush(&BeMessage::CopyData(buf)) + // write_message only writes to the buffer, so it can fail iff the + // message is invaid, but CopyData can't be invalid. + .map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?; + + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + this.pgb.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + this.pgb.poll_flush(cx) + } +} + +pub fn short_error(e: &QueryError) -> String { + match e { + QueryError::Disconnected(connection_error) => connection_error.to_string(), + QueryError::Other(e) => format!("{e:#}"), + } +} + +fn log_query_error(query: &str, e: &QueryError) { + match e { + QueryError::Disconnected(ConnectionError::Io(io_error)) => { + if is_expected_io_error(io_error) { + info!("query handler for '{query}' failed with expected io error: {io_error}"); + } else { + error!("query handler for '{query}' failed with io error: {io_error}"); + } + } + QueryError::Disconnected(other_connection_error) => { + error!("query handler for '{query}' failed with connection error: {other_connection_error:?}") + } + QueryError::Other(e) => { + error!("query handler for '{query}' failed: {e:?}"); + } + } +} + +/// Something finishing handling of COPY stream, see handle_copy_stream_end. +/// This is not always a real error, but it allows to use ? and thiserror impls. +#[derive(thiserror::Error, Debug)] +pub enum CopyStreamHandlerEnd { + /// Handler initiates the end of streaming. + #[error("{0}")] + ServerInitiated(String), + #[error("received CopyDone")] + CopyDone, + #[error("received CopyFail")] + CopyFail, + #[error("received Terminate")] + Terminate, + #[error("EOF on COPY stream")] + EOF, + /// The connection was lost + #[error(transparent)] + Disconnected(#[from] ConnectionError), + /// Some other error + #[error(transparent)] + Other(#[from] anyhow::Error), +} diff --git a/libs/utils/tests/cert.pem b/libs/postgres_backend/tests/cert.pem similarity index 100% rename from libs/utils/tests/cert.pem rename to libs/postgres_backend/tests/cert.pem diff --git a/libs/utils/tests/key.pem b/libs/postgres_backend/tests/key.pem similarity index 100% rename from libs/utils/tests/key.pem rename to libs/postgres_backend/tests/key.pem diff --git a/libs/postgres_backend/tests/simple_select.rs b/libs/postgres_backend/tests/simple_select.rs new file mode 100644 index 0000000000..a310171c70 --- /dev/null +++ b/libs/postgres_backend/tests/simple_select.rs @@ -0,0 +1,139 @@ +/// Test postgres_backend_async with tokio_postgres +use once_cell::sync::Lazy; +use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError}; +use pq_proto::{BeMessage, RowDescriptor}; +use std::io::Cursor; +use std::{future, sync::Arc}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_postgres::config::SslMode; +use tokio_postgres::tls::MakeTlsConnect; +use tokio_postgres::{Config, NoTls, SimpleQueryMessage}; +use tokio_postgres_rustls::MakeRustlsConnect; + +// generate client, server test streams +async fn make_tcp_pair() -> (TcpStream, TcpStream) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let client_stream = TcpStream::connect(addr).await.unwrap(); + let (server_stream, _) = listener.accept().await.unwrap(); + (client_stream, server_stream) +} + +struct TestHandler {} + +#[async_trait::async_trait] +impl Handler for TestHandler { + // return single col 'hey' for any query + async fn process_query( + &mut self, + pgb: &mut PostgresBackend, + _query_string: &str, + ) -> Result<(), QueryError> { + pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col( + b"hey", + )]))? + .write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))? + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + Ok(()) + } +} + +// test that basic select works +#[tokio::test] +async fn simple_select() { + let (client_sock, server_sock) = make_tcp_pair().await; + + // create and run pgbackend + let pgbackend = + PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation"); + + tokio::spawn(async move { + let mut handler = TestHandler {}; + pgbackend.run(&mut handler, future::pending::<()>).await + }); + + let conf = Config::new(); + let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect"); + // The connection object performs the actual communication with the database, + // so spawn it off to run on its own. + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0]; + if let SimpleQueryMessage::Row(row) = first_val { + let first_col = row.get(0).expect("first column"); + assert_eq!(first_col, "hey"); + } else { + panic!("expected SimpleQueryMessage::Row"); + } +} + +static KEY: Lazy = Lazy::new(|| { + let mut cursor = Cursor::new(include_bytes!("key.pem")); + rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone()) +}); + +static CERT: Lazy = Lazy::new(|| { + let mut cursor = Cursor::new(include_bytes!("cert.pem")); + rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone()) +}); + +// test that basic select with ssl works +#[tokio::test] +async fn simple_select_ssl() { + let (client_sock, server_sock) = make_tcp_pair().await; + + let server_cfg = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![CERT.clone()], KEY.clone()) + .unwrap(); + let tls_config = Some(Arc::new(server_cfg)); + let pgbackend = + PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation"); + + tokio::spawn(async move { + let mut handler = TestHandler {}; + pgbackend.run(&mut handler, future::pending::<()>).await + }); + + let client_cfg = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates({ + let mut store = rustls::RootCertStore::empty(); + store.add(&CERT).unwrap(); + store + }) + .with_no_client_auth(); + let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg); + let tls_connect = >::make_tls_connect( + &mut make_tls_connect, + "localhost", + ) + .expect("make_tls_connect"); + + let mut conf = Config::new(); + conf.ssl_mode(SslMode::Require); + let (client, connection) = conf + .connect_raw(client_sock, tls_connect) + .await + .expect("connect"); + // The connection object performs the actual communication with the database, + // so spawn it off to run on its own. + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0]; + if let SimpleQueryMessage::Row(row) = first_val { + let first_col = row.get(0).expect("first column"); + assert_eq!(first_col, "hey"); + } else { + panic!("expected SimpleQueryMessage::Row"); + } +} diff --git a/libs/pq_proto/Cargo.toml b/libs/pq_proto/Cargo.toml index bc90a7a2c1..76b71729ed 100644 --- a/libs/pq_proto/Cargo.toml +++ b/libs/pq_proto/Cargo.toml @@ -5,8 +5,8 @@ edition.workspace = true license.workspace = true [dependencies] -anyhow.workspace = true bytes.workspace = true +byteorder.workspace = true pin-project-lite.workspace = true postgres-protocol.workspace = true rand.workspace = true diff --git a/libs/pq_proto/src/framed.rs b/libs/pq_proto/src/framed.rs new file mode 100644 index 0000000000..4b8a03fefc --- /dev/null +++ b/libs/pq_proto/src/framed.rs @@ -0,0 +1,251 @@ +//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from +//! the async stream based on (and buffered with) BytesMut. All functions are +//! cancellation safe. +//! +//! It is similar to what tokio_util::codec::Framed with appropriate codec +//! provides, but `FramedReader` and `FramedWriter` read/write parts can be used +//! separately without using split from futures::stream::StreamExt (which +//! allocates box[1] in polling internally). tokio::io::split is used for splitting +//! instead. Plus we customize error messages more than a single type for all io +//! calls. +//! +//! [1] https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107 +use bytes::{Buf, BytesMut}; +use std::{ + future::Future, + io::{self, ErrorKind}, +}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; + +use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; + +const INITIAL_CAPACITY: usize = 8 * 1024; + +/// Error on postgres connection: either IO (physical transport error) or +/// protocol violation. +#[derive(thiserror::Error, Debug)] +pub enum ConnectionError { + #[error(transparent)] + Io(#[from] io::Error), + #[error(transparent)] + Protocol(#[from] ProtocolError), +} + +impl ConnectionError { + /// Proxy stream.rs uses only io::Error; provide it. + pub fn into_io_error(self) -> io::Error { + match self { + ConnectionError::Io(io) => io, + ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()), + } + } +} + +/// Wraps async io `stream`, providing messages to write/flush + read Postgres +/// messages. +pub struct Framed { + stream: S, + read_buf: BytesMut, + write_buf: BytesMut, +} + +impl Framed { + pub fn new(stream: S) -> Self { + Self { + stream, + read_buf: BytesMut::with_capacity(INITIAL_CAPACITY), + write_buf: BytesMut::with_capacity(INITIAL_CAPACITY), + } + } + + /// Get a shared reference to the underlying stream. + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Extract the underlying stream. + pub fn into_inner(self) -> S { + self.stream + } + + /// Return new Framed with stream type transformed by async f, for TLS + /// upgrade. + pub async fn map_stream(self, f: F) -> Result, E> + where + F: FnOnce(S) -> Fut, + Fut: Future>, + { + let stream = f(self.stream).await?; + Ok(Framed { + stream, + read_buf: self.read_buf, + write_buf: self.write_buf, + }) + } +} + +impl Framed { + pub async fn read_startup_message( + &mut self, + ) -> Result, ConnectionError> { + read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await + } + + pub async fn read_message(&mut self) -> Result, ConnectionError> { + read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await + } +} + +impl Framed { + /// Write next message to the output buffer; doesn't flush. + pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> { + BeMessage::write(&mut self.write_buf, msg) + } + + /// Flush out the buffer. This function is cancellation safe: it can be + /// interrupted and flushing will be continued in the next call. + pub async fn flush(&mut self) -> Result<(), io::Error> { + flush(&mut self.stream, &mut self.write_buf).await + } + + /// Flush out the buffer and shutdown the stream. + pub async fn shutdown(&mut self) -> Result<(), io::Error> { + shutdown(&mut self.stream, &mut self.write_buf).await + } +} + +impl Framed { + /// Split into owned read and write parts. Beware of potential issues with + /// using halves in different tasks on TLS stream: + /// https://github.com/tokio-rs/tls/issues/40 + pub fn split(self) -> (FramedReader>, FramedWriter>) { + let (read_half, write_half) = tokio::io::split(self.stream); + let reader = FramedReader { + stream: read_half, + read_buf: self.read_buf, + }; + let writer = FramedWriter { + stream: write_half, + write_buf: self.write_buf, + }; + (reader, writer) + } + + /// Join read and write parts back. + pub fn unsplit(reader: FramedReader>, writer: FramedWriter>) -> Self { + Self { + stream: reader.stream.unsplit(writer.stream), + read_buf: reader.read_buf, + write_buf: writer.write_buf, + } + } +} + +/// Read-only version of `Framed`. +pub struct FramedReader { + stream: S, + read_buf: BytesMut, +} + +impl FramedReader { + pub async fn read_message(&mut self) -> Result, ConnectionError> { + read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await + } +} + +/// Write-only version of `Framed`. +pub struct FramedWriter { + stream: S, + write_buf: BytesMut, +} + +impl FramedWriter { + /// Get a mut reference to the underlying stream. + pub fn get_mut(&mut self) -> &mut S { + &mut self.stream + } +} + +impl FramedWriter { + /// Write next message to the output buffer; doesn't flush. + pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> { + BeMessage::write(&mut self.write_buf, msg) + } + + /// Flush out the buffer. This function is cancellation safe: it can be + /// interrupted and flushing will be continued in the next call. + pub async fn flush(&mut self) -> Result<(), io::Error> { + flush(&mut self.stream, &mut self.write_buf).await + } + + /// Flush out the buffer and shutdown the stream. + pub async fn shutdown(&mut self) -> Result<(), io::Error> { + shutdown(&mut self.stream, &mut self.write_buf).await + } +} + +/// Read next message from the stream. Returns Ok(None), if EOF happened and we +/// don't have remaining data in the buffer. This function is cancellation safe: +/// you can drop future which is not yet complete and finalize reading message +/// with the next call. +/// +/// Parametrized to allow reading startup or usual message, having different +/// format. +async fn read_message( + stream: &mut S, + read_buf: &mut BytesMut, + parse: P, +) -> Result, ConnectionError> +where + P: Fn(&mut BytesMut) -> Result, ProtocolError>, +{ + loop { + if let Some(msg) = parse(read_buf)? { + return Ok(Some(msg)); + } + // If we can't build a frame yet, try to read more data and try again. + // Make sure we've got room for at least one byte to read to ensure + // that we don't get a spurious 0 that looks like EOF. + read_buf.reserve(1); + if stream.read_buf(read_buf).await? == 0 { + if read_buf.has_remaining() { + return Err(io::Error::new( + ErrorKind::UnexpectedEof, + "EOF with unprocessed data in the buffer", + ) + .into()); + } else { + return Ok(None); // clean EOF + } + } + } +} + +async fn flush( + stream: &mut S, + write_buf: &mut BytesMut, +) -> Result<(), io::Error> { + while write_buf.has_remaining() { + let bytes_written = stream.write(write_buf.chunk()).await?; + if bytes_written == 0 { + return Err(io::Error::new( + ErrorKind::WriteZero, + "failed to write message", + )); + } + // The advanced part will be garbage collected, likely during shifting + // data left on next attempt to write to buffer when free space is not + // enough. + write_buf.advance(bytes_written); + } + write_buf.clear(); + stream.flush().await +} + +async fn shutdown( + stream: &mut S, + write_buf: &mut BytesMut, +) -> Result<(), io::Error> { + flush(stream, write_buf).await?; + stream.shutdown().await +} diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index b7995c840c..3ebb14de5a 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -2,24 +2,18 @@ //! //! on message formats. -// Tools for calling certain async methods in sync contexts. -pub mod sync; +pub mod framed; -use anyhow::{ensure, Context, Result}; +use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use postgres_protocol::PG_EPOCH; use serde::{Deserialize, Serialize}; use std::{ borrow::Cow, collections::HashMap, - fmt, - future::Future, - io::{self, Cursor}, - str, + fmt, io, str, time::{Duration, SystemTime}, }; -use sync::{AsyncishRead, SyncFuture}; -use tokio::io::AsyncReadExt; use tracing::{trace, warn}; pub type Oid = u32; @@ -31,7 +25,6 @@ pub const TEXT_OID: Oid = 25; #[derive(Debug)] pub enum FeMessage { - StartupPacket(FeStartupPacket), // Simple query. Query(Bytes), // Extended query protocol. @@ -191,260 +184,207 @@ pub struct FeExecuteMessage { #[derive(Debug)] pub struct FeCloseMessage; -/// Retry a read on EINTR -/// -/// This runs the enclosed expression, and if it returns -/// Err(io::ErrorKind::Interrupted), retries it. -macro_rules! retry_read { - ( $x:expr ) => { - loop { - match $x { - Err(e) if e.kind() == io::ErrorKind::Interrupted => continue, - res => break res, - } - } - }; -} - -/// An error occured during connection being open. +/// An error occured while parsing or serializing raw stream into Postgres +/// messages. #[derive(thiserror::Error, Debug)] -pub enum ConnectionError { - /// IO error during writing to or reading from the connection socket. - #[error("Socket IO error: {0}")] - Socket(std::io::Error), - /// Invalid packet was received from client +pub enum ProtocolError { + /// Invalid packet was received from the client (e.g. unexpected message + /// type or broken len). #[error("Protocol error: {0}")] Protocol(String), - /// Failed to parse a protocol mesage + /// Failed to parse or, (unlikely), serialize a protocol message. #[error("Message parse error: {0}")] - MessageParse(anyhow::Error), + BadMessage(String), } -impl From for ConnectionError { - fn from(e: anyhow::Error) -> Self { - Self::MessageParse(e) - } -} - -impl ConnectionError { +impl ProtocolError { + /// Proxy stream.rs uses only io::Error; provide it. pub fn into_io_error(self) -> io::Error { - match self { - ConnectionError::Socket(io) => io, - other => io::Error::new(io::ErrorKind::Other, other.to_string()), - } + io::Error::new(io::ErrorKind::Other, self.to_string()) } } impl FeMessage { - /// Read one message from the stream. - /// This function returns `Ok(None)` in case of EOF. - /// One way to handle this properly: + /// Read and parse one message from the `buf` input buffer. If there is at + /// least one valid message, returns it, advancing `buf`; redundant copies + /// are avoided, as thanks to `bytes` crate ptrs in parsed message point + /// directly into the `buf` (processed data is garbage collected after + /// parsed message is dropped). /// - /// ``` - /// # use std::io; - /// # use pq_proto::FeMessage; - /// # - /// # fn process_message(msg: FeMessage) -> anyhow::Result<()> { - /// # Ok(()) - /// # }; - /// # - /// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> { - /// while let Some(msg) = FeMessage::read(stream)? { - /// process_message(msg)?; - /// } + /// Returns None if `buf` doesn't contain enough data for a single message. + /// For efficiency, tries to reserve large enough space in `buf` for the + /// next message in this case to save the repeated calls. /// - /// Ok(()) - /// } - /// ``` - #[inline(never)] - pub fn read( - stream: &mut (impl io::Read + Unpin), - ) -> Result, ConnectionError> { - Self::read_fut(&mut AsyncishRead(stream)).wait() - } + /// Returns Error if message is malformed, the only possible ErrorKind is + /// InvalidInput. + // + // Inspired by rust-postgres Message::parse. + pub fn parse(buf: &mut BytesMut) -> Result, ProtocolError> { + // Every message contains message type byte and 4 bytes len; can't do + // much without them. + if buf.len() < 5 { + let to_read = 5 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } - /// Read one message from the stream. - /// See documentation for `Self::read`. - pub fn read_fut( - stream: &mut Reader, - ) -> SyncFuture, ConnectionError>> + '_> - where - Reader: tokio::io::AsyncRead + Unpin, - { - // We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof. - // SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and - // AsyncReadExt methods of the stream. - SyncFuture::new(async move { - // Each libpq message begins with a message type byte, followed by message length - // If the client closes the connection, return None. But if the client closes the - // connection in the middle of a message, we will return an error. - let tag = match retry_read!(stream.read_u8().await) { - Ok(b) => b, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ConnectionError::Socket(e)), - }; + // We shouldn't advance `buf` as probably full message is not there yet, + // so can't directly use Bytes::get_u32 etc. + let tag = buf[0]; + let len = (&buf[1..5]).read_u32::().unwrap(); + if len < 4 { + return Err(ProtocolError::Protocol(format!( + "invalid message length {}", + len + ))); + } - // The message length includes itself, so it better be at least 4. - let len = retry_read!(stream.read_u32().await) - .map_err(ConnectionError::Socket)? - .checked_sub(4) - .ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?; + // length field includes itself, but not message type. + let total_len = len as usize + 1; + if buf.len() < total_len { + // Don't have full message yet. + let to_read = total_len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } - let body = { - let mut buffer = vec![0u8; len as usize]; - stream - .read_exact(&mut buffer) - .await - .map_err(ConnectionError::Socket)?; - Bytes::from(buffer) - }; + // got the message, advance buffer + let mut msg = buf.split_to(total_len).freeze(); + msg.advance(5); // consume message type and len - match tag { - b'Q' => Ok(Some(FeMessage::Query(body))), - b'P' => Ok(Some(FeParseMessage::parse(body)?)), - b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), - b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), - b'B' => Ok(Some(FeBindMessage::parse(body)?)), - b'C' => Ok(Some(FeCloseMessage::parse(body)?)), - b'S' => Ok(Some(FeMessage::Sync)), - b'X' => Ok(Some(FeMessage::Terminate)), - b'd' => Ok(Some(FeMessage::CopyData(body))), - b'c' => Ok(Some(FeMessage::CopyDone)), - b'f' => Ok(Some(FeMessage::CopyFail)), - b'p' => Ok(Some(FeMessage::PasswordMessage(body))), - tag => { - return Err(ConnectionError::Protocol(format!( - "unknown message tag: {tag},'{body:?}'" - ))) - } + match tag { + b'Q' => Ok(Some(FeMessage::Query(msg))), + b'P' => Ok(Some(FeParseMessage::parse(msg)?)), + b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)), + b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)), + b'B' => Ok(Some(FeBindMessage::parse(msg)?)), + b'C' => Ok(Some(FeCloseMessage::parse(msg)?)), + b'S' => Ok(Some(FeMessage::Sync)), + b'X' => Ok(Some(FeMessage::Terminate)), + b'd' => Ok(Some(FeMessage::CopyData(msg))), + b'c' => Ok(Some(FeMessage::CopyDone)), + b'f' => Ok(Some(FeMessage::CopyFail)), + b'p' => Ok(Some(FeMessage::PasswordMessage(msg))), + tag => { + return Err(ProtocolError::Protocol(format!( + "unknown message tag: {tag},'{msg:?}'" + ))) } - }) + } } } impl FeStartupPacket { - /// Read startup message from the stream. - // XXX: It's tempting yet undesirable to accept `stream` by value, - // since such a change will cause user-supplied &mut references to be consumed - pub fn read( - stream: &mut (impl io::Read + Unpin), - ) -> Result, ConnectionError> { - Self::read_fut(&mut AsyncishRead(stream)).wait() - } - - /// Read startup message from the stream. - // XXX: It's tempting yet undesirable to accept `stream` by value, - // since such a change will cause user-supplied &mut references to be consumed - pub fn read_fut( - stream: &mut Reader, - ) -> SyncFuture, ConnectionError>> + '_> - where - Reader: tokio::io::AsyncRead + Unpin, - { + /// Read and parse startup message from the `buf` input buffer. It is + /// different from [`FeMessage::parse`] because startup messages don't have + /// message type byte; otherwise, its comments apply. + pub fn parse(buf: &mut BytesMut) -> Result, ProtocolError> { const MAX_STARTUP_PACKET_LENGTH: usize = 10000; const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234; const CANCEL_REQUEST_CODE: u32 = 5678; const NEGOTIATE_SSL_CODE: u32 = 5679; const NEGOTIATE_GSS_CODE: u32 = 5680; - SyncFuture::new(async move { - // Read length. If the connection is closed before reading anything (or before - // reading 4 bytes, to be precise), return None to indicate that the connection - // was closed. This matches the PostgreSQL server's behavior, which avoids noise - // in the log if the client opens connection but closes it immediately. - let len = match retry_read!(stream.read_u32().await) { - Ok(len) => len as usize, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ConnectionError::Socket(e)), - }; + // need at least 4 bytes with packet len + if buf.len() < 4 { + let to_read = 4 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } - #[allow(clippy::manual_range_contains)] - if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { - return Err(ConnectionError::Protocol(format!( - "invalid message length {len}" + // We shouldn't advance `buf` as probably full message is not there yet, + // so can't directly use Bytes::get_u32 etc. + let len = (&buf[0..4]).read_u32::().unwrap() as usize; + if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { + return Err(ProtocolError::Protocol(format!( + "invalid startup packet message length {}", + len + ))); + } + + if buf.len() < len { + // Don't have full message yet. + let to_read = len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + // got the message, advance buffer + let mut msg = buf.split_to(len).freeze(); + msg.advance(4); // consume len + + let request_code = msg.get_u32(); + let req_hi = request_code >> 16; + let req_lo = request_code & ((1 << 16) - 1); + // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code. + let message = match (req_hi, req_lo) { + (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { + if msg.remaining() != 8 { + return Err(ProtocolError::BadMessage( + "CancelRequest message is malformed, backend PID / secret key missing" + .to_owned(), + )); + } + FeStartupPacket::CancelRequest(CancelKeyData { + backend_pid: msg.get_i32(), + cancel_key: msg.get_i32(), + }) + } + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { + // Requested upgrade to SSL (aka TLS) + FeStartupPacket::SslRequest + } + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => { + // Requested upgrade to GSSAPI + FeStartupPacket::GssEncRequest + } + (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { + return Err(ProtocolError::Protocol(format!( + "Unrecognized request code {unrecognized_code}" ))); } + // TODO bail if protocol major_version is not 3? + (major_version, minor_version) => { + // StartupMessage - let request_code = - retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?; + // Parse pairs of null-terminated strings (key, value). + // See `postgres: ProcessStartupPacket, build_startup_packet`. + let mut tokens = str::from_utf8(&msg) + .map_err(|_e| { + ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned()) + })? + .strip_suffix('\0') // drop packet's own null + .ok_or_else(|| { + ProtocolError::Protocol( + "StartupMessage params: missing null terminator".to_string(), + ) + })? + .split_terminator('\0'); - // the rest of startup packet are params - let params_len = len - 8; - let mut params_bytes = vec![0u8; params_len]; - stream - .read_exact(params_bytes.as_mut()) - .await - .map_err(ConnectionError::Socket)?; + let mut params = HashMap::new(); + while let Some(name) = tokens.next() { + let value = tokens.next().ok_or_else(|| { + ProtocolError::Protocol( + "StartupMessage params: key without value".to_string(), + ) + })?; - // Parse params depending on request code - let req_hi = request_code >> 16; - let req_lo = request_code & ((1 << 16) - 1); - let message = match (req_hi, req_lo) { - (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { - if params_len != 8 { - return Err(ConnectionError::Protocol( - "expected 8 bytes for CancelRequest params".to_string(), - )); - } - let mut cursor = Cursor::new(params_bytes); - FeStartupPacket::CancelRequest(CancelKeyData { - backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?, - cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?, - }) + params.insert(name.to_owned(), value.to_owned()); } - (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { - // Requested upgrade to SSL (aka TLS) - FeStartupPacket::SslRequest - } - (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => { - // Requested upgrade to GSSAPI - FeStartupPacket::GssEncRequest - } - (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { - return Err(ConnectionError::Protocol(format!( - "Unrecognized request code {unrecognized_code}" - ))); - } - // TODO bail if protocol major_version is not 3? - (major_version, minor_version) => { - // Parse pairs of null-terminated strings (key, value). - // See `postgres: ProcessStartupPacket, build_startup_packet`. - let mut tokens = str::from_utf8(¶ms_bytes) - .context("StartupMessage params: invalid utf-8")? - .strip_suffix('\0') // drop packet's own null - .ok_or_else(|| { - ConnectionError::Protocol( - "StartupMessage params: missing null terminator".to_string(), - ) - })? - .split_terminator('\0'); - let mut params = HashMap::new(); - while let Some(name) = tokens.next() { - let value = tokens.next().ok_or_else(|| { - ConnectionError::Protocol( - "StartupMessage params: key without value".to_string(), - ) - })?; - - params.insert(name.to_owned(), value.to_owned()); - } - - FeStartupPacket::StartupMessage { - major_version, - minor_version, - params: StartupMessageParams { params }, - } + FeStartupPacket::StartupMessage { + major_version, + minor_version, + params: StartupMessageParams { params }, } - }; - - Ok(Some(FeMessage::StartupPacket(message))) - }) + } + }; + Ok(Some(message)) } } impl FeParseMessage { - fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> Result { // FIXME: the rust-postgres driver uses a named prepared statement // for copy_out(). We're not prepared to handle that correctly. For // now, just ignore the statement name, assuming that the client never @@ -452,55 +392,82 @@ impl FeParseMessage { let _pstmt_name = read_cstr(&mut buf)?; let query_string = read_cstr(&mut buf)?; + if buf.remaining() < 2 { + return Err(ProtocolError::BadMessage( + "Parse message is malformed, nparams missing".to_string(), + )); + } let nparams = buf.get_i16(); - ensure!(nparams == 0, "query params not implemented"); + if nparams != 0 { + return Err(ProtocolError::BadMessage( + "query params not implemented".to_string(), + )); + } Ok(FeMessage::Parse(FeParseMessage { query_string })) } } impl FeDescribeMessage { - fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> Result { let kind = buf.get_u8(); let _pstmt_name = read_cstr(&mut buf)?; // FIXME: see FeParseMessage::parse - ensure!( - kind == b'S', - "only prepared statemement Describe is implemented" - ); + if kind != b'S' { + return Err(ProtocolError::BadMessage( + "only prepared statemement Describe is implemented".to_string(), + )); + } Ok(FeMessage::Describe(FeDescribeMessage { kind })) } } impl FeExecuteMessage { - fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> Result { let portal_name = read_cstr(&mut buf)?; + if buf.remaining() < 4 { + return Err(ProtocolError::BadMessage( + "FeExecuteMessage message is malformed, maxrows missing".to_string(), + )); + } let maxrows = buf.get_i32(); - ensure!(portal_name.is_empty(), "named portals not implemented"); - ensure!(maxrows == 0, "row limit in Execute message not implemented"); + if !portal_name.is_empty() { + return Err(ProtocolError::BadMessage( + "named portals not implemented".to_string(), + )); + } + if maxrows != 0 { + return Err(ProtocolError::BadMessage( + "row limit in Execute message not implemented".to_string(), + )); + } Ok(FeMessage::Execute(FeExecuteMessage { maxrows })) } } impl FeBindMessage { - fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> Result { let portal_name = read_cstr(&mut buf)?; let _pstmt_name = read_cstr(&mut buf)?; // FIXME: see FeParseMessage::parse - ensure!(portal_name.is_empty(), "named portals not implemented"); + if !portal_name.is_empty() { + return Err(ProtocolError::BadMessage( + "named portals not implemented".to_string(), + )); + } Ok(FeMessage::Bind(FeBindMessage)) } } impl FeCloseMessage { - fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> Result { let _kind = buf.get_u8(); let _pstmt_or_portal_name = read_cstr(&mut buf)?; @@ -529,6 +496,7 @@ pub enum BeMessage<'a> { CloseComplete, // None means column is NULL DataRow(&'a [Option<&'a [u8]>]), + // None errcode means internal_error will be sent. ErrorResponse(&'a str, Option<&'a [u8; 5]>), /// Single byte - used in response to SSLRequest/GSSENCRequest. EncryptionResponse(bool), @@ -559,6 +527,11 @@ impl<'a> BeMessage<'a> { value: b"UTF8", }; + pub const INTEGER_DATETIMES: Self = Self::ParameterStatus { + name: b"integer_datetimes", + value: b"on", + }; + /// Build a [`BeMessage::ParameterStatus`] holding the server version. pub fn server_version(version: &'a str) -> Self { Self::ParameterStatus { @@ -637,7 +610,7 @@ impl RowDescriptor<'_> { #[derive(Debug)] pub struct XLogDataBody<'a> { pub wal_start: u64, - pub wal_end: u64, + pub wal_end: u64, // current end of WAL on the server pub timestamp: i64, pub data: &'a [u8], } @@ -677,12 +650,11 @@ fn write_body(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R { } /// Safe write of s into buf as cstring (String in the protocol). -fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> { +fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> { let bytes = s.as_ref(); if bytes.contains(&0) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "string contains embedded null", + return Err(ProtocolError::BadMessage( + "string contains embedded null".to_owned(), )); } buf.put_slice(bytes); @@ -690,22 +662,27 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> { Ok(()) } -fn read_cstr(buf: &mut Bytes) -> anyhow::Result { - let pos = buf.iter().position(|x| *x == 0); - let result = buf.split_to(pos.context("missing terminator")?); +/// Read cstring from buf, advancing it. +fn read_cstr(buf: &mut Bytes) -> Result { + let pos = buf + .iter() + .position(|x| *x == 0) + .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?; + let result = buf.split_to(pos); buf.advance(1); // drop the null terminator Ok(result) } pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000"; +pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000"; impl<'a> BeMessage<'a> { - /// Write message to the given buf. - // Unlike the reading side, we use BytesMut - // here as msg len precedes its body and it is handy to write it down first - // and then fill the length. With Write we would have to either calc it - // manually or have one more buffer. - pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> { + /// Serialize `message` to the given `buf`. + /// Apart from smart memory managemet, BytesMut is good here as msg len + /// precedes its body and it is handy to write it down first and then fill + /// the length. With Write we would have to either calc it manually or have + /// one more buffer. + pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> { match message { BeMessage::AuthenticationOk => { buf.put_u8(b'R'); @@ -750,7 +727,7 @@ impl<'a> BeMessage<'a> { buf.put_slice(extra); } } - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -841,7 +818,7 @@ impl<'a> BeMessage<'a> { BeMessage::ErrorResponse(error_msg, pg_error_code) => { // 'E' signalizes ErrorResponse messages buf.put_u8(b'E'); - write_body(buf, |buf| { + write_body(buf, |buf| -> Result<(), ProtocolError> { buf.put_u8(b'S'); // severity buf.put_slice(b"ERROR\0"); @@ -854,7 +831,7 @@ impl<'a> BeMessage<'a> { write_cstr(error_msg, buf)?; buf.put_u8(0); // terminator - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -866,7 +843,7 @@ impl<'a> BeMessage<'a> { // 'N' signalizes NoticeResponse messages buf.put_u8(b'N'); - write_body(buf, |buf| { + write_body(buf, |buf| -> Result<(), ProtocolError> { buf.put_u8(b'S'); // severity buf.put_slice(b"NOTICE\0"); @@ -877,7 +854,7 @@ impl<'a> BeMessage<'a> { write_cstr(error_msg.as_bytes(), buf)?; buf.put_u8(0); // terminator - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -921,7 +898,7 @@ impl<'a> BeMessage<'a> { BeMessage::RowDescription(rows) => { buf.put_u8(b'T'); - write_body(buf, |buf| { + write_body(buf, |buf| -> Result<(), ProtocolError> { buf.put_i16(rows.len() as i16); // # of fields for row in rows.iter() { write_cstr(row.name, buf)?; @@ -932,7 +909,7 @@ impl<'a> BeMessage<'a> { buf.put_i32(-1); /* typmod */ buf.put_i16(0); /* format code */ } - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -999,7 +976,7 @@ impl ReplicationFeedback { // null-terminated string - key, // uint32 - value length in bytes // value itself - pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> { + pub fn serialize(&self, buf: &mut BytesMut) { buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys buf.put_slice(b"current_timeline_size\0"); buf.put_i32(8); @@ -1024,7 +1001,6 @@ impl ReplicationFeedback { buf.put_slice(b"ps_replytime\0"); buf.put_i32(8); buf.put_i64(timestamp); - Ok(()) } // Deserialize ReplicationFeedback message @@ -1092,7 +1068,7 @@ mod tests { // because it is rounded up to microseconds during serialization. rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000); let mut data = BytesMut::new(); - rf.serialize(&mut data).unwrap(); + rf.serialize(&mut data); let rf_parsed = ReplicationFeedback::parse(data.freeze()); assert_eq!(rf, rf_parsed); @@ -1107,7 +1083,7 @@ mod tests { // because it is rounded up to microseconds during serialization. rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000); let mut data = BytesMut::new(); - rf.serialize(&mut data).unwrap(); + rf.serialize(&mut data); // Add an extra field to the buffer and adjust number of keys if let Some(first) = data.first_mut() { @@ -1149,15 +1125,6 @@ mod tests { let params = make_params("foo\\ bar \\ \\\\ baz\\ lol"); assert_eq!(split_options(¶ms), ["foo bar", " \\", "baz ", "lol"]); } - - // Make sure that `read` is sync/async callable - async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) { - let _ = FeMessage::read(&mut [].as_ref()); - let _ = FeMessage::read_fut(stream).await; - - let _ = FeStartupPacket::read(&mut [].as_ref()); - let _ = FeStartupPacket::read_fut(stream).await; - } } fn terminate_code(code: &[u8; 5]) -> [u8; 6] { diff --git a/libs/pq_proto/src/sync.rs b/libs/pq_proto/src/sync.rs deleted file mode 100644 index b7ff1fb70b..0000000000 --- a/libs/pq_proto/src/sync.rs +++ /dev/null @@ -1,179 +0,0 @@ -use pin_project_lite::pin_project; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::{io, task}; - -pin_project! { - /// We use this future to mark certain methods - /// as callable in both sync and async modes. - #[repr(transparent)] - pub struct SyncFuture { - #[pin] - inner: T, - _marker: PhantomData, - } -} - -/// This wrapper lets us synchronously wait for inner future's completion -/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**. -/// For instance, `S` may be substituted with types implementing -/// [`tokio::io::AsyncRead`], but it's not the only viable option. -impl SyncFuture { - /// NOTE: caller should carefully pick a type for `S`, - /// because we don't want to enable [`SyncFuture::wait`] when - /// it's in fact impossible to run the future synchronously. - /// Violation of this contract will not cause UB, but - /// panics and async event loop freezes won't please you. - /// - /// Example: - /// - /// ``` - /// # use pq_proto::sync::SyncFuture; - /// # use std::future::Future; - /// # use tokio::io::AsyncReadExt; - /// # - /// // Parse a pair of numbers from a stream - /// pub fn parse_pair( - /// stream: &mut Reader, - /// ) -> SyncFuture> + '_> - /// where - /// Reader: tokio::io::AsyncRead + Unpin, - /// { - /// // If `Reader` is a `SyncProof`, this will give caller - /// // an opportunity to use `SyncFuture::wait`, because - /// // `.await` will always result in `Poll::Ready`. - /// SyncFuture::new(async move { - /// let x = stream.read_u32().await?; - /// let y = stream.read_u64().await?; - /// Ok((x, y)) - /// }) - /// } - /// ``` - pub fn new(inner: T) -> Self { - Self { - inner, - _marker: PhantomData, - } - } -} - -impl Future for SyncFuture { - type Output = T::Output; - - /// In async code, [`SyncFuture`] behaves like a regular wrapper. - #[inline(always)] - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { - self.project().inner.poll(cx) - } -} - -/// Postulates that we can call [`SyncFuture::wait`]. -/// If implementer is also a [`Future`], it should always -/// return [`task::Poll::Ready`] from [`Future::poll`]. -/// -/// Each implementation should document which futures -/// specifically are being declared sync-proof. -pub trait SyncPostulate {} - -impl SyncPostulate for &T {} -impl SyncPostulate for &mut T {} - -impl SyncFuture { - /// Synchronously wait for future completion. - pub fn wait(mut self) -> T::Output { - const RAW_WAKER: task::RawWaker = task::RawWaker::new( - std::ptr::null(), - &task::RawWakerVTable::new( - |_| RAW_WAKER, - |_| panic!("SyncFuture: failed to wake"), - |_| panic!("SyncFuture: failed to wake by ref"), - |_| { /* drop is no-op */ }, - ), - ); - - // SAFETY: We never move `self` during this call; - // furthermore, it will be dropped in the end regardless of panics - let this = unsafe { Pin::new_unchecked(&mut self) }; - - // SAFETY: This waker doesn't do anything apart from panicking - let waker = unsafe { task::Waker::from_raw(RAW_WAKER) }; - let context = &mut task::Context::from_waker(&waker); - - match this.poll(context) { - task::Poll::Ready(res) => res, - _ => panic!("SyncFuture: unexpected pending!"), - } - } -} - -/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`], -/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`]. -/// NOTE: you **should not** use this in async code. -#[repr(transparent)] -pub struct AsyncishRead(pub T); - -/// This lets us call [`SyncFuture, _>::wait`], -/// and allows the future to await on any of the [`AsyncRead`] -/// and [`AsyncReadExt`] methods on `AsyncishRead`. -impl SyncPostulate for AsyncishRead {} - -impl tokio::io::AsyncRead for AsyncishRead { - #[inline(always)] - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> task::Poll> { - task::Poll::Ready( - // `Read::read` will block, meaning we don't need a real event loop! - self.0 - .read(buf.initialize_unfilled()) - .map(|sz| buf.advance(sz)), - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - // async helper(stream: &mut impl AsyncRead) -> io::Result - fn bytes_add( - stream: &mut Reader, - ) -> SyncFuture> + '_> - where - Reader: tokio::io::AsyncRead + Unpin, - { - SyncFuture::new(async move { - let a = stream.read_u32().await?; - let b = stream.read_u32().await?; - Ok(a + b) - }) - } - - #[test] - fn test_sync() { - let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat(); - let res = bytes_add(&mut AsyncishRead(&mut &bytes[..])) - .wait() - .unwrap(); - assert_eq!(res, 300); - } - - // We need a single-threaded executor for this test - #[tokio::test(flavor = "current_thread")] - async fn test_async() { - let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap(); - - let write = async move { - tx.write_u32(100).await?; - tx.write_u32(200).await?; - Ok(()) - }; - - let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap(); - assert_eq!(res, 300); - } -} diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index 1091a8bd5c..901f849801 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static { } pub struct Download { - pub download_stream: Pin>, + pub download_stream: Pin>, /// Extra key-value data, associated with the current remote file. pub metadata: Option, } diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index 6acdb6fa53..84b82472c6 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -12,42 +12,38 @@ anyhow.workspace = true bincode.workspace = true bytes.workspace = true heapless.workspace = true +hex = { workspace = true, features = ["serde"] } hyper = { workspace = true, features = ["full"] } futures = { workspace = true} +jsonwebtoken.workspace = true +nix.workspace = true +once_cell.workspace = true routerify.workspace = true serde.workspace = true serde_json.workspace = true +signal-hook.workspace = true thiserror.workspace = true tokio.workspace = true tokio-rustls.workspace = true tracing.workspace = true tracing-subscriber = { workspace = true, features = ["json"] } -nix.workspace = true -signal-hook.workspace = true rand.workspace = true -jsonwebtoken.workspace = true -hex = { workspace = true, features = ["serde"] } rustls.workspace = true -rustls-split.workspace = true -git-version.workspace = true serde_with.workspace = true -once_cell.workspace = true strum.workspace = true strum_macros.workspace = true - -metrics.workspace = true -pq_proto.workspace = true - -workspace_hack.workspace = true url.workspace = true uuid = { version = "1.2", features = ["v4", "serde"] } + +metrics.workspace = true +workspace_hack.workspace = true + [dev-dependencies] byteorder.workspace = true bytes.workspace = true +criterion.workspace = true hex-literal.workspace = true tempfile.workspace = true -criterion.workspace = true -rustls-pemfile.workspace = true [[bench]] name = "benchmarks" diff --git a/libs/utils/src/lib.rs b/libs/utils/src/lib.rs index 9ddd702c72..acb5273943 100644 --- a/libs/utils/src/lib.rs +++ b/libs/utils/src/lib.rs @@ -13,8 +13,6 @@ pub mod simple_rcu; pub mod vec_map; pub mod bin_ser; -pub mod postgres_backend; -pub mod postgres_backend_async; // helper functions for creating and fsyncing pub mod crashsafe; @@ -27,9 +25,6 @@ pub mod id; // http endpoint utils pub mod http; -// socket splitting utils -pub mod sock_split; - // common log initialisation routine pub mod logging; diff --git a/libs/utils/src/postgres_backend.rs b/libs/utils/src/postgres_backend.rs deleted file mode 100644 index f3e3835bda..0000000000 --- a/libs/utils/src/postgres_backend.rs +++ /dev/null @@ -1,485 +0,0 @@ -//! Server-side synchronous Postgres connection, as limited as we need. -//! To use, create PostgresBackend and run() it, passing the Handler -//! implementation determining how to process the queries. Currently its API -//! is rather narrow, but we can extend it once required. - -use crate::postgres_backend_async::{log_query_error, short_error, QueryError}; -use crate::sock_split::{BidiStream, ReadStream, WriteStream}; -use anyhow::Context; -use bytes::{Bytes, BytesMut}; -use pq_proto::{BeMessage, FeMessage, FeStartupPacket}; -use serde::{Deserialize, Serialize}; -use std::fmt; -use std::io::{self, Write}; -use std::net::{Shutdown, SocketAddr, TcpStream}; -use std::str::FromStr; -use std::sync::Arc; -use std::time::Duration; -use tracing::*; - -pub trait Handler { - /// Handle single query. - /// postgres_backend will issue ReadyForQuery after calling this (this - /// might be not what we want after CopyData streaming, but currently we don't - /// care). - fn process_query( - &mut self, - pgb: &mut PostgresBackend, - query_string: &str, - ) -> Result<(), QueryError>; - - /// Called on startup packet receival, allows to process params. - /// - /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users - /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow - /// to override whole init logic in implementations. - fn startup( - &mut self, - _pgb: &mut PostgresBackend, - _sm: &FeStartupPacket, - ) -> Result<(), QueryError> { - Ok(()) - } - - /// Check auth jwt - fn check_auth_jwt( - &mut self, - _pgb: &mut PostgresBackend, - _jwt_response: &[u8], - ) -> Result<(), QueryError> { - Err(QueryError::Other(anyhow::anyhow!("JWT auth failed"))) - } - - fn is_shutdown_requested(&self) -> bool { - false - } -} - -/// PostgresBackend protocol state. -/// XXX: The order of the constructors matters. -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)] -pub enum ProtoState { - Initialization, - Encrypted, - Authentication, - Established, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] -pub enum AuthType { - Trust, - // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT - NeonJWT, -} - -impl FromStr for AuthType { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s { - "Trust" => Ok(Self::Trust), - "NeonJWT" => Ok(Self::NeonJWT), - _ => anyhow::bail!("invalid value \"{s}\" for auth type"), - } - } -} - -impl fmt::Display for AuthType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - AuthType::Trust => "Trust", - AuthType::NeonJWT => "NeonJWT", - }) - } -} - -#[derive(Clone, Copy)] -pub enum ProcessMsgResult { - Continue, - Break, -} - -/// Always-writeable sock_split stream. -/// May not be readable. See [`PostgresBackend::take_stream_in`] -pub enum Stream { - Bidirectional(BidiStream), - WriteOnly(WriteStream), -} - -impl Stream { - fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how), - Self::WriteOnly(write_stream) => write_stream.shutdown(how), - } - } -} - -impl io::Write for Stream { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Self::Bidirectional(bidi_stream) => bidi_stream.write(buf), - Self::WriteOnly(write_stream) => write_stream.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - Self::Bidirectional(bidi_stream) => bidi_stream.flush(), - Self::WriteOnly(write_stream) => write_stream.flush(), - } - } -} - -pub struct PostgresBackend { - stream: Option, - // Output buffer. c.f. BeMessage::write why we are using BytesMut here. - buf_out: BytesMut, - - pub state: ProtoState, - - auth_type: AuthType, - - peer_addr: SocketAddr, - pub tls_config: Option>, -} - -pub fn query_from_cstring(query_string: Bytes) -> Vec { - let mut query_string = query_string.to_vec(); - if let Some(ch) = query_string.last() { - if *ch == 0 { - query_string.pop(); - } - } - query_string -} - -// Helper function for socket read loops -pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool { - for cause in error.chain() { - if let Some(io_error) = cause.downcast_ref::() { - if io_error.kind() == std::io::ErrorKind::WouldBlock { - return true; - } - } - } - false -} - -// Cast a byte slice to a string slice, dropping null terminator if there's one. -fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> { - let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes); - std::str::from_utf8(without_null).map_err(|e| e.into()) -} - -impl PostgresBackend { - pub fn new( - socket: TcpStream, - auth_type: AuthType, - tls_config: Option>, - set_read_timeout: bool, - ) -> io::Result { - let peer_addr = socket.peer_addr()?; - if set_read_timeout { - socket - .set_read_timeout(Some(Duration::from_secs(5))) - .unwrap(); - } - - Ok(Self { - stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))), - buf_out: BytesMut::with_capacity(10 * 1024), - state: ProtoState::Initialization, - auth_type, - tls_config, - peer_addr, - }) - } - - pub fn into_stream(self) -> Stream { - self.stream.unwrap() - } - - /// Get direct reference (into the Option) to the read stream. - fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> { - match &mut self.stream { - Some(Stream::Bidirectional(stream)) => Ok(stream), - _ => anyhow::bail!("reader taken"), - } - } - - pub fn get_peer_addr(&self) -> &SocketAddr { - &self.peer_addr - } - - pub fn take_stream_in(&mut self) -> Option { - let stream = self.stream.take(); - match stream { - Some(Stream::Bidirectional(bidi_stream)) => { - let (read, write) = bidi_stream.split(); - self.stream = Some(Stream::WriteOnly(write)); - Some(read) - } - stream => { - self.stream = stream; - None - } - } - } - - /// Read full message or return None if connection is closed. - pub fn read_message(&mut self) -> Result, QueryError> { - let (state, stream) = (self.state, self.get_stream_in()?); - - use ProtoState::*; - match state { - Initialization | Encrypted => FeStartupPacket::read(stream), - Authentication | Established => FeMessage::read(stream), - } - .map_err(QueryError::from) - } - - /// Write message into internal output buffer. - pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> { - BeMessage::write(&mut self.buf_out, message)?; - Ok(self) - } - - /// Flush output buffer into the socket. - pub fn flush(&mut self) -> io::Result<&mut Self> { - let stream = self.stream.as_mut().unwrap(); - stream.write_all(&self.buf_out)?; - self.buf_out.clear(); - Ok(self) - } - - /// Write message into internal buffer and flush it. - pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> { - self.write_message_noflush(message)?; - self.flush() - } - - // Wrapper for run_message_loop() that shuts down socket when we are done - pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> { - let ret = self.run_message_loop(handler); - if let Some(stream) = self.stream.as_mut() { - let _ = stream.shutdown(Shutdown::Both); - } - ret - } - - fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> { - trace!("postgres backend to {:?} started", self.peer_addr); - - let mut unnamed_query_string = Bytes::new(); - - while !handler.is_shutdown_requested() { - match self.read_message() { - Ok(message) => { - if let Some(msg) = message { - trace!("got message {msg:?}"); - - match self.process_message(handler, msg, &mut unnamed_query_string)? { - ProcessMsgResult::Continue => continue, - ProcessMsgResult::Break => break, - } - } else { - break; - } - } - Err(e) => { - if let QueryError::Other(e) = &e { - if is_socket_read_timed_out(e) { - continue; - } - } - return Err(e); - } - } - } - - trace!("postgres backend to {:?} exited", self.peer_addr); - Ok(()) - } - - pub fn start_tls(&mut self) -> anyhow::Result<()> { - match self.stream.take() { - Some(Stream::Bidirectional(bidi_stream)) => { - let conn = rustls::ServerConnection::new(self.tls_config.clone().unwrap())?; - self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(conn)?)); - Ok(()) - } - stream => { - self.stream = stream; - anyhow::bail!("can't start TLs without bidi stream"); - } - } - } - - fn process_message( - &mut self, - handler: &mut impl Handler, - msg: FeMessage, - unnamed_query_string: &mut Bytes, - ) -> Result { - // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth - // TODO: change that to proper top-level match of protocol state with separate message handling for each state - if self.state < ProtoState::Established - && !matches!( - msg, - FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_) - ) - { - return Err(QueryError::Other(anyhow::anyhow!("protocol violation"))); - } - - let have_tls = self.tls_config.is_some(); - match msg { - FeMessage::StartupPacket(m) => { - trace!("got startup message {m:?}"); - - match m { - FeStartupPacket::SslRequest => { - debug!("SSL requested"); - - self.write_message(&BeMessage::EncryptionResponse(have_tls))?; - if have_tls { - self.start_tls()?; - self.state = ProtoState::Encrypted; - } - } - FeStartupPacket::GssEncRequest => { - debug!("GSS requested"); - self.write_message(&BeMessage::EncryptionResponse(false))?; - } - FeStartupPacket::StartupMessage { .. } => { - if have_tls && !matches!(self.state, ProtoState::Encrypted) { - self.write_message(&BeMessage::ErrorResponse( - "must connect with TLS", - None, - ))?; - return Err(QueryError::Other(anyhow::anyhow!( - "client did not connect with TLS" - ))); - } - - // NB: startup() may change self.auth_type -- we are using that in proxy code - // to bypass auth for new users. - handler.startup(self, &m)?; - - match self.auth_type { - AuthType::Trust => { - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? - // The async python driver requires a valid server_version - .write_message_noflush(&BeMessage::server_version("14.1"))? - .write_message(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - AuthType::NeonJWT => { - self.write_message(&BeMessage::AuthenticationCleartextPassword)?; - self.state = ProtoState::Authentication; - } - } - } - FeStartupPacket::CancelRequest { .. } => { - return Ok(ProcessMsgResult::Break); - } - } - } - - FeMessage::PasswordMessage(m) => { - trace!("got password message '{:?}'", m); - - assert!(self.state == ProtoState::Authentication); - - match self.auth_type { - AuthType::Trust => unreachable!(), - AuthType::NeonJWT => { - let (_, jwt_response) = m.split_last().context("protocol violation")?; - - if let Err(e) = handler.check_auth_jwt(self, jwt_response) { - self.write_message(&BeMessage::ErrorResponse( - &e.to_string(), - Some(e.pg_error_code()), - ))?; - return Err(e); - } - } - } - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? - .write_message(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - - FeMessage::Query(body) => { - // remove null terminator - let query_string = cstr_to_str(&body)?; - - trace!("got query {query_string:?}"); - if let Err(e) = handler.process_query(self, query_string) { - log_query_error(query_string, &e); - let short_error = short_error(&e); - self.write_message_noflush(&BeMessage::ErrorResponse( - &short_error, - Some(e.pg_error_code()), - ))?; - } - self.write_message(&BeMessage::ReadyForQuery)?; - } - - FeMessage::Parse(m) => { - *unnamed_query_string = m.query_string; - self.write_message(&BeMessage::ParseComplete)?; - } - - FeMessage::Describe(_) => { - self.write_message_noflush(&BeMessage::ParameterDescription)? - .write_message(&BeMessage::NoData)?; - } - - FeMessage::Bind(_) => { - self.write_message(&BeMessage::BindComplete)?; - } - - FeMessage::Close(_) => { - self.write_message(&BeMessage::CloseComplete)?; - } - - FeMessage::Execute(_) => { - let query_string = cstr_to_str(unnamed_query_string)?; - trace!("got execute {query_string:?}"); - if let Err(e) = handler.process_query(self, query_string) { - log_query_error(query_string, &e); - self.write_message(&BeMessage::ErrorResponse( - &e.to_string(), - Some(e.pg_error_code()), - ))?; - } - // NOTE there is no ReadyForQuery message. This handler is used - // for basebackup and it uses CopyOut which doesn't require - // ReadyForQuery message and backend just switches back to - // processing mode after sending CopyDone or ErrorResponse. - } - - FeMessage::Sync => { - self.write_message(&BeMessage::ReadyForQuery)?; - } - - FeMessage::Terminate => { - return Ok(ProcessMsgResult::Break); - } - - // We prefer explicit pattern matching to wildcards, because - // this helps us spot the places where new variants are missing - FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => { - return Err(QueryError::Other(anyhow::anyhow!( - "unexpected message type: {msg:?}" - ))); - } - } - - Ok(ProcessMsgResult::Continue) - } -} diff --git a/libs/utils/src/postgres_backend_async.rs b/libs/utils/src/postgres_backend_async.rs deleted file mode 100644 index 442b06ed01..0000000000 --- a/libs/utils/src/postgres_backend_async.rs +++ /dev/null @@ -1,636 +0,0 @@ -//! Server-side asynchronous Postgres connection, as limited as we need. -//! To use, create PostgresBackend and run() it, passing the Handler -//! implementation determining how to process the queries. Currently its API -//! is rather narrow, but we can extend it once required. - -use crate::postgres_backend::AuthType; -use anyhow::Context; -use bytes::{Buf, Bytes, BytesMut}; -use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR}; -use std::io; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; -use std::task::Poll; -use std::{future::Future, task::ready}; -use tracing::{debug, error, info, trace}; - -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; -use tokio_rustls::TlsAcceptor; - -pub fn is_expected_io_error(e: &io::Error) -> bool { - use io::ErrorKind::*; - matches!( - e.kind(), - ConnectionRefused | ConnectionAborted | ConnectionReset - ) -} - -/// An error, occurred during query processing: -/// either during the connection ([`ConnectionError`]) or before/after it. -#[derive(thiserror::Error, Debug)] -pub enum QueryError { - /// The connection was lost while processing the query. - #[error(transparent)] - Disconnected(#[from] ConnectionError), - /// Some other error - #[error(transparent)] - Other(#[from] anyhow::Error), -} - -impl From for QueryError { - fn from(e: io::Error) -> Self { - Self::Disconnected(ConnectionError::Socket(e)) - } -} - -impl QueryError { - pub fn pg_error_code(&self) -> &'static [u8; 5] { - match self { - Self::Disconnected(_) => b"08006", // connection failure - Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error - } - } -} - -#[async_trait::async_trait] -pub trait Handler { - /// Handle single query. - /// postgres_backend will issue ReadyForQuery after calling this (this - /// might be not what we want after CopyData streaming, but currently we don't - /// care). - async fn process_query( - &mut self, - pgb: &mut PostgresBackend, - query_string: &str, - ) -> Result<(), QueryError>; - - /// Called on startup packet receival, allows to process params. - /// - /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users - /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow - /// to override whole init logic in implementations. - fn startup( - &mut self, - _pgb: &mut PostgresBackend, - _sm: &FeStartupPacket, - ) -> Result<(), QueryError> { - Ok(()) - } - - /// Check auth jwt - fn check_auth_jwt( - &mut self, - _pgb: &mut PostgresBackend, - _jwt_response: &[u8], - ) -> Result<(), QueryError> { - Err(QueryError::Other(anyhow::anyhow!("JWT auth failed"))) - } -} - -/// PostgresBackend protocol state. -/// XXX: The order of the constructors matters. -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)] -pub enum ProtoState { - Initialization, - Encrypted, - Authentication, - Established, - Closed, -} - -#[derive(Clone, Copy)] -pub enum ProcessMsgResult { - Continue, - Break, -} - -/// Always-writeable sock_split stream. -/// May not be readable. See [`PostgresBackend::take_stream_in`] -pub enum Stream { - Unencrypted(BufReader), - Tls(Box>>), - Broken, -} - -impl AsyncWrite for Stream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.get_mut() { - Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf), - Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf), - Self::Broken => unreachable!(), - } - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - match self.get_mut() { - Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx), - Self::Tls(stream) => Pin::new(stream).poll_flush(cx), - Self::Broken => unreachable!(), - } - } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - match self.get_mut() { - Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx), - Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx), - Self::Broken => unreachable!(), - } - } -} -impl AsyncRead for Stream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - match self.get_mut() { - Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf), - Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf), - Self::Broken => unreachable!(), - } - } -} - -pub struct PostgresBackend { - stream: Stream, - - // Output buffer. c.f. BeMessage::write why we are using BytesMut here. - // The data between 0 and "current position" as tracked by the bytes::Buf - // implementation of BytesMut, have already been written. - buf_out: BytesMut, - - pub state: ProtoState, - - auth_type: AuthType, - - peer_addr: SocketAddr, - pub tls_config: Option>, -} - -pub fn query_from_cstring(query_string: Bytes) -> Vec { - let mut query_string = query_string.to_vec(); - if let Some(ch) = query_string.last() { - if *ch == 0 { - query_string.pop(); - } - } - query_string -} - -// Cast a byte slice to a string slice, dropping null terminator if there's one. -fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> { - let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes); - std::str::from_utf8(without_null).map_err(|e| e.into()) -} - -impl PostgresBackend { - pub fn new( - socket: tokio::net::TcpStream, - auth_type: AuthType, - tls_config: Option>, - ) -> io::Result { - let peer_addr = socket.peer_addr()?; - - Ok(Self { - stream: Stream::Unencrypted(BufReader::new(socket)), - buf_out: BytesMut::with_capacity(10 * 1024), - state: ProtoState::Initialization, - auth_type, - tls_config, - peer_addr, - }) - } - - pub fn get_peer_addr(&self) -> &SocketAddr { - &self.peer_addr - } - - /// Read full message or return None if connection is closed. - pub async fn read_message(&mut self) -> Result, QueryError> { - use ProtoState::*; - match self.state { - Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await, - Authentication | Established => FeMessage::read_fut(&mut self.stream).await, - Closed => Ok(None), - } - .map_err(QueryError::from) - } - - /// Flush output buffer into the socket. - pub async fn flush(&mut self) -> io::Result<()> { - while self.buf_out.has_remaining() { - let bytes_written = self.stream.write(self.buf_out.chunk()).await?; - self.buf_out.advance(bytes_written); - } - self.buf_out.clear(); - Ok(()) - } - - /// Write message into internal output buffer. - pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - BeMessage::write(&mut self.buf_out, message)?; - Ok(self) - } - - /// Returns an AsyncWrite implementation that wraps all the data written - /// to it in CopyData messages, and writes them to the connection - /// - /// The caller is responsible for sending CopyOutResponse and CopyDone messages. - pub fn copyout_writer(&mut self) -> CopyDataWriter { - CopyDataWriter { pgb: self } - } - - /// A polling function that tries to write all the data from 'buf_out' to the - /// underlying stream. - fn poll_write_buf( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - while self.buf_out.has_remaining() { - match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) { - Ok(bytes_written) => self.buf_out.advance(bytes_written), - Err(err) => return Poll::Ready(Err(err)), - } - } - Poll::Ready(Ok(())) - } - - fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - Pin::new(&mut self.stream).poll_flush(cx) - } - - // Wrapper for run_message_loop() that shuts down socket when we are done - pub async fn run( - mut self, - handler: &mut impl Handler, - shutdown_watcher: F, - ) -> Result<(), QueryError> - where - F: Fn() -> S, - S: Future, - { - let ret = self.run_message_loop(handler, shutdown_watcher).await; - let _ = self.stream.shutdown(); - ret - } - - async fn run_message_loop( - &mut self, - handler: &mut impl Handler, - shutdown_watcher: F, - ) -> Result<(), QueryError> - where - F: Fn() -> S, - S: Future, - { - trace!("postgres backend to {:?} started", self.peer_addr); - - tokio::select!( - biased; - - _ = shutdown_watcher() => { - // We were requested to shut down. - tracing::info!("shutdown request received during handshake"); - return Ok(()) - }, - - result = async { - while self.state < ProtoState::Established { - if let Some(msg) = self.read_message().await? { - trace!("got message {msg:?} during handshake"); - - match self.process_handshake_message(handler, msg).await? { - ProcessMsgResult::Continue => { - self.flush().await?; - continue; - } - ProcessMsgResult::Break => { - trace!("postgres backend to {:?} exited during handshake", self.peer_addr); - return Ok(()); - } - } - } else { - trace!("postgres backend to {:?} exited during handshake", self.peer_addr); - return Ok(()); - } - } - Ok::<(), QueryError>(()) - } => { - // Handshake complete. - result?; - } - ); - - // Authentication completed - let mut query_string = Bytes::new(); - while let Some(msg) = tokio::select!( - biased; - _ = shutdown_watcher() => { - // We were requested to shut down. - tracing::info!("shutdown request received in run_message_loop"); - Ok(None) - }, - msg = self.read_message() => { msg }, - )? { - trace!("got message {:?}", msg); - - let result = self.process_message(handler, msg, &mut query_string).await; - self.flush().await?; - match result? { - ProcessMsgResult::Continue => { - self.flush().await?; - continue; - } - ProcessMsgResult::Break => break, - } - } - - trace!("postgres backend to {:?} exited", self.peer_addr); - Ok(()) - } - - async fn start_tls(&mut self) -> anyhow::Result<()> { - if let Stream::Unencrypted(plain_stream) = - std::mem::replace(&mut self.stream, Stream::Broken) - { - let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap()); - let tls_stream = acceptor.accept(plain_stream).await?; - - self.stream = Stream::Tls(Box::new(tls_stream)); - return Ok(()); - }; - anyhow::bail!("TLS already started"); - } - - async fn process_handshake_message( - &mut self, - handler: &mut impl Handler, - msg: FeMessage, - ) -> Result { - assert!(self.state < ProtoState::Established); - let have_tls = self.tls_config.is_some(); - match msg { - FeMessage::StartupPacket(m) => { - trace!("got startup message {m:?}"); - - match m { - FeStartupPacket::SslRequest => { - debug!("SSL requested"); - - self.write_message_noflush(&BeMessage::EncryptionResponse(have_tls))?; - if have_tls { - self.start_tls().await?; - self.state = ProtoState::Encrypted; - } - } - FeStartupPacket::GssEncRequest => { - debug!("GSS requested"); - self.write_message_noflush(&BeMessage::EncryptionResponse(false))?; - } - FeStartupPacket::StartupMessage { .. } => { - if have_tls && !matches!(self.state, ProtoState::Encrypted) { - self.write_message_noflush(&BeMessage::ErrorResponse( - "must connect with TLS", - None, - ))?; - return Err(QueryError::Other(anyhow::anyhow!( - "client did not connect with TLS" - ))); - } - - // NB: startup() may change self.auth_type -- we are using that in proxy code - // to bypass auth for new users. - handler.startup(self, &m)?; - - match self.auth_type { - AuthType::Trust => { - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? - // The async python driver requires a valid server_version - .write_message_noflush(&BeMessage::server_version("14.1"))? - .write_message_noflush(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - AuthType::NeonJWT => { - self.write_message_noflush( - &BeMessage::AuthenticationCleartextPassword, - )?; - self.state = ProtoState::Authentication; - } - } - } - FeStartupPacket::CancelRequest { .. } => { - self.state = ProtoState::Closed; - return Ok(ProcessMsgResult::Break); - } - } - } - - FeMessage::PasswordMessage(m) => { - trace!("got password message '{:?}'", m); - - assert!(self.state == ProtoState::Authentication); - - match self.auth_type { - AuthType::Trust => unreachable!(), - AuthType::NeonJWT => { - let (_, jwt_response) = m.split_last().context("protocol violation")?; - - if let Err(e) = handler.check_auth_jwt(self, jwt_response) { - self.write_message_noflush(&BeMessage::ErrorResponse( - &e.to_string(), - Some(e.pg_error_code()), - ))?; - return Err(e); - } - } - } - self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeMessage::CLIENT_ENCODING)? - .write_message_noflush(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - - _ => { - self.state = ProtoState::Closed; - return Ok(ProcessMsgResult::Break); - } - } - Ok(ProcessMsgResult::Continue) - } - - async fn process_message( - &mut self, - handler: &mut impl Handler, - msg: FeMessage, - unnamed_query_string: &mut Bytes, - ) -> Result { - // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth - // TODO: change that to proper top-level match of protocol state with separate message handling for each state - assert!(self.state == ProtoState::Established); - - match msg { - FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => { - return Err(QueryError::Other(anyhow::anyhow!("protocol violation"))); - } - - FeMessage::Query(body) => { - // remove null terminator - let query_string = cstr_to_str(&body)?; - - trace!("got query {query_string:?}"); - if let Err(e) = handler.process_query(self, query_string).await { - log_query_error(query_string, &e); - let short_error = short_error(&e); - self.write_message_noflush(&BeMessage::ErrorResponse( - &short_error, - Some(e.pg_error_code()), - ))?; - } - self.write_message_noflush(&BeMessage::ReadyForQuery)?; - } - - FeMessage::Parse(m) => { - *unnamed_query_string = m.query_string; - self.write_message_noflush(&BeMessage::ParseComplete)?; - } - - FeMessage::Describe(_) => { - self.write_message_noflush(&BeMessage::ParameterDescription)? - .write_message_noflush(&BeMessage::NoData)?; - } - - FeMessage::Bind(_) => { - self.write_message_noflush(&BeMessage::BindComplete)?; - } - - FeMessage::Close(_) => { - self.write_message_noflush(&BeMessage::CloseComplete)?; - } - - FeMessage::Execute(_) => { - let query_string = cstr_to_str(unnamed_query_string)?; - trace!("got execute {query_string:?}"); - if let Err(e) = handler.process_query(self, query_string).await { - log_query_error(query_string, &e); - self.write_message_noflush(&BeMessage::ErrorResponse( - &e.to_string(), - Some(e.pg_error_code()), - ))?; - } - // NOTE there is no ReadyForQuery message. This handler is used - // for basebackup and it uses CopyOut which doesn't require - // ReadyForQuery message and backend just switches back to - // processing mode after sending CopyDone or ErrorResponse. - } - - FeMessage::Sync => { - self.write_message_noflush(&BeMessage::ReadyForQuery)?; - } - - FeMessage::Terminate => { - return Ok(ProcessMsgResult::Break); - } - - // We prefer explicit pattern matching to wildcards, because - // this helps us spot the places where new variants are missing - FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => { - return Err(QueryError::Other(anyhow::anyhow!( - "unexpected message type: {:?}", - msg - ))); - } - } - - Ok(ProcessMsgResult::Continue) - } -} - -/// -/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData -/// messages. -/// - -pub struct CopyDataWriter<'a> { - pgb: &'a mut PostgresBackend, -} - -impl<'a> AsyncWrite for CopyDataWriter<'a> { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - let this = self.get_mut(); - - // It's not strictly required to flush between each message, but makes it easier - // to view in wireshark, and usually the messages that the callers write are - // decently-sized anyway. - match ready!(this.pgb.poll_write_buf(cx)) { - Ok(()) => {} - Err(err) => return Poll::Ready(Err(err)), - } - - // CopyData - // XXX: if the input is large, we should split it into multiple messages. - // Not sure what the threshold should be, but the ultimate hard limit is that - // the length cannot exceed u32. - this.pgb.write_message_noflush(&BeMessage::CopyData(buf))?; - - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let this = self.get_mut(); - match ready!(this.pgb.poll_write_buf(cx)) { - Ok(()) => {} - Err(err) => return Poll::Ready(Err(err)), - } - this.pgb.poll_flush(cx) - } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let this = self.get_mut(); - match ready!(this.pgb.poll_write_buf(cx)) { - Ok(()) => {} - Err(err) => return Poll::Ready(Err(err)), - } - this.pgb.poll_flush(cx) - } -} - -pub fn short_error(e: &QueryError) -> String { - match e { - QueryError::Disconnected(connection_error) => connection_error.to_string(), - QueryError::Other(e) => format!("{e:#}"), - } -} - -pub(super) fn log_query_error(query: &str, e: &QueryError) { - match e { - QueryError::Disconnected(ConnectionError::Socket(io_error)) => { - if is_expected_io_error(io_error) { - info!("query handler for '{query}' failed with expected io error: {io_error}"); - } else { - error!("query handler for '{query}' failed with io error: {io_error}"); - } - } - QueryError::Disconnected(other_connection_error) => { - error!("query handler for '{query}' failed with connection error: {other_connection_error:?}") - } - QueryError::Other(e) => { - error!("query handler for '{query}' failed: {e:?}"); - } - } -} diff --git a/libs/utils/src/sock_split.rs b/libs/utils/src/sock_split.rs deleted file mode 100644 index b0e5a0bf6a..0000000000 --- a/libs/utils/src/sock_split.rs +++ /dev/null @@ -1,206 +0,0 @@ -use std::{ - io::{self, BufReader, Write}, - net::{Shutdown, TcpStream}, - sync::Arc, -}; - -use rustls::Connection; - -/// Wrapper supporting reads of a shared TcpStream. -pub struct ArcTcpRead(Arc); - -impl io::Read for ArcTcpRead { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - (&*self.0).read(buf) - } -} - -impl std::ops::Deref for ArcTcpRead { - type Target = TcpStream; - - fn deref(&self) -> &Self::Target { - self.0.deref() - } -} - -/// Wrapper around a TCP Stream supporting buffered reads. -pub struct BufStream(BufReader); - -impl io::Read for BufStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) - } -} - -impl io::Write for BufStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.get_ref().write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.get_ref().flush() - } -} - -impl BufStream { - /// Unwrap into the internal BufReader. - fn into_reader(self) -> BufReader { - self.0 - } - - /// Returns a reference to the underlying TcpStream. - fn get_ref(&self) -> &TcpStream { - &self.0.get_ref().0 - } -} - -pub enum ReadStream { - Tcp(BufReader), - Tls(rustls_split::ReadHalf), -} - -impl io::Read for ReadStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - Self::Tcp(reader) => reader.read(buf), - Self::Tls(read_half) => read_half.read(buf), - } - } -} - -impl ReadStream { - pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.get_ref().shutdown(how), - Self::Tls(write_half) => write_half.shutdown(how), - } - } -} - -pub enum WriteStream { - Tcp(Arc), - Tls(rustls_split::WriteHalf), -} - -impl WriteStream { - pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.shutdown(how), - Self::Tls(write_half) => write_half.shutdown(how), - } - } -} - -impl io::Write for WriteStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Self::Tcp(stream) => stream.as_ref().write(buf), - Self::Tls(write_half) => write_half.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.as_ref().flush(), - Self::Tls(write_half) => write_half.flush(), - } - } -} - -type TlsStream = rustls::StreamOwned; - -pub enum BidiStream { - Tcp(BufStream), - /// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`]. - Tls(Box>), -} - -impl BidiStream { - pub fn from_tcp(stream: TcpStream) -> Self { - Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream))))) - } - - pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.get_ref().shutdown(how), - Self::Tls(tls_boxed) => { - if how == Shutdown::Read { - tls_boxed.sock.get_ref().shutdown(how) - } else { - tls_boxed.conn.send_close_notify(); - let res = tls_boxed.flush(); - tls_boxed.sock.get_ref().shutdown(how)?; - res - } - } - } - } - - /// Split the bi-directional stream into two owned read and write halves. - pub fn split(self) -> (ReadStream, WriteStream) { - match self { - Self::Tcp(stream) => { - let reader = stream.into_reader(); - let stream: Arc = reader.get_ref().0.clone(); - - (ReadStream::Tcp(reader), WriteStream::Tcp(stream)) - } - Self::Tls(tls_boxed) => { - let reader = tls_boxed.sock.into_reader(); - let buffer_data = reader.buffer().to_owned(); - let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192); - let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192); - - // TODO would be nice to avoid the Arc here - let socket = Arc::try_unwrap(reader.into_inner().0).unwrap(); - - let (read_half, write_half) = rustls_split::split( - socket, - Connection::Server(tls_boxed.conn), - read_buf_cfg, - write_buf_cfg, - ); - (ReadStream::Tls(read_half), WriteStream::Tls(write_half)) - } - } - } - - pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result { - match self { - Self::Tcp(mut stream) => { - conn.complete_io(&mut stream)?; - assert!(!conn.is_handshaking()); - Ok(Self::Tls(Box::new(TlsStream::new(conn, stream)))) - } - Self::Tls { .. } => Err(io::Error::new( - io::ErrorKind::InvalidInput, - "TLS is already started on this stream", - )), - } - } -} - -impl io::Read for BidiStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - Self::Tcp(stream) => stream.read(buf), - Self::Tls(tls_boxed) => tls_boxed.read(buf), - } - } -} - -impl io::Write for BidiStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Self::Tcp(stream) => stream.write(buf), - Self::Tls(tls_boxed) => tls_boxed.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - Self::Tcp(stream) => stream.flush(), - Self::Tls(tls_boxed) => tls_boxed.flush(), - } - } -} diff --git a/libs/utils/tests/ssl_test.rs b/libs/utils/tests/ssl_test.rs deleted file mode 100644 index fae707f049..0000000000 --- a/libs/utils/tests/ssl_test.rs +++ /dev/null @@ -1,238 +0,0 @@ -use std::{ - collections::HashMap, - io::{Cursor, Read, Write}, - net::{TcpListener, TcpStream}, - sync::Arc, -}; - -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use once_cell::sync::Lazy; - -use utils::{ - postgres_backend::{AuthType, Handler, PostgresBackend}, - postgres_backend_async::QueryError, -}; - -fn make_tcp_pair() -> (TcpStream, TcpStream) { - let listener = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - let client_stream = TcpStream::connect(addr).unwrap(); - let (server_stream, _) = listener.accept().unwrap(); - (server_stream, client_stream) -} - -static KEY: Lazy = Lazy::new(|| { - let mut cursor = Cursor::new(include_bytes!("key.pem")); - rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone()) -}); - -static CERT: Lazy = Lazy::new(|| { - let mut cursor = Cursor::new(include_bytes!("cert.pem")); - rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone()) -}); - -#[test] -// [false-positive](https://github.com/rust-lang/rust-clippy/issues/9274), -// we resize the vector so doing some modifications after all -#[allow(clippy::read_zero_byte_vec)] -fn ssl() { - let (mut client_sock, server_sock) = make_tcp_pair(); - - const QUERY: &str = "hello world"; - - let client_jh = std::thread::spawn(move || { - // SSLRequest - client_sock.write_u32::(8).unwrap(); - client_sock.write_u32::(80877103).unwrap(); - - let ssl_response = client_sock.read_u8().unwrap(); - assert_eq!(b'S', ssl_response); - - let cfg = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates({ - let mut store = rustls::RootCertStore::empty(); - store.add(&CERT).unwrap(); - store - }) - .with_no_client_auth(); - let client_config = Arc::new(cfg); - - let dns_name = "localhost".try_into().unwrap(); - let mut conn = rustls::ClientConnection::new(client_config, dns_name).unwrap(); - - conn.complete_io(&mut client_sock).unwrap(); - assert!(!conn.is_handshaking()); - - let mut stream = rustls::Stream::new(&mut conn, &mut client_sock); - - // StartupMessage - stream.write_u32::(9).unwrap(); - stream.write_u32::(196608).unwrap(); - stream.write_u8(0).unwrap(); - stream.flush().unwrap(); - - // wait for ReadyForQuery - let mut msg_buf = Vec::new(); - loop { - let msg = stream.read_u8().unwrap(); - let size = stream.read_u32::().unwrap() - 4; - msg_buf.resize(size as usize, 0); - stream.read_exact(&mut msg_buf).unwrap(); - - if msg == b'Z' { - // ReadyForQuery - break; - } - } - - // Query - stream.write_u8(b'Q').unwrap(); - stream - .write_u32::(4u32 + QUERY.len() as u32) - .unwrap(); - stream.write_all(QUERY.as_ref()).unwrap(); - stream.flush().unwrap(); - - // ReadyForQuery - let msg = stream.read_u8().unwrap(); - assert_eq!(msg, b'Z'); - }); - - struct TestHandler { - got_query: bool, - } - impl Handler for TestHandler { - fn process_query( - &mut self, - _pgb: &mut PostgresBackend, - query_string: &str, - ) -> Result<(), QueryError> { - self.got_query = query_string == QUERY; - Ok(()) - } - } - let mut handler = TestHandler { got_query: false }; - - let cfg = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(vec![CERT.clone()], KEY.clone()) - .unwrap(); - let tls_config = Some(Arc::new(cfg)); - - let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap(); - pgb.run(&mut handler).unwrap(); - assert!(handler.got_query); - - client_jh.join().unwrap(); - - // TODO consider shutdown behavior -} - -#[test] -fn no_ssl() { - let (mut client_sock, server_sock) = make_tcp_pair(); - - let client_jh = std::thread::spawn(move || { - let mut buf = BytesMut::new(); - - // SSLRequest - buf.put_u32(8); - buf.put_u32(80877103); - client_sock.write_all(&buf).unwrap(); - buf.clear(); - - let ssl_response = client_sock.read_u8().unwrap(); - assert_eq!(b'N', ssl_response); - }); - - struct TestHandler; - - impl Handler for TestHandler { - fn process_query( - &mut self, - _pgb: &mut PostgresBackend, - _query_string: &str, - ) -> Result<(), QueryError> { - panic!() - } - } - - let mut handler = TestHandler; - - let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None, true).unwrap(); - pgb.run(&mut handler).unwrap(); - - client_jh.join().unwrap(); -} - -#[test] -fn server_forces_ssl() { - let (mut client_sock, server_sock) = make_tcp_pair(); - - let client_jh = std::thread::spawn(move || { - // StartupMessage - client_sock.write_u32::(9).unwrap(); - client_sock.write_u32::(196608).unwrap(); - client_sock.write_u8(0).unwrap(); - client_sock.flush().unwrap(); - - // ErrorResponse - assert_eq!(client_sock.read_u8().unwrap(), b'E'); - let len = client_sock.read_u32::().unwrap() - 4; - - let mut body = vec![0; len as usize]; - client_sock.read_exact(&mut body).unwrap(); - let mut body = Bytes::from(body); - - let mut errors = HashMap::new(); - loop { - let field_type = body.get_u8(); - if field_type == 0u8 { - break; - } - - let end_idx = body.iter().position(|&b| b == 0u8).unwrap(); - let mut value = body.split_to(end_idx + 1); - assert_eq!(value[end_idx], 0u8); - value.truncate(end_idx); - let old = errors.insert(field_type, value); - assert!(old.is_none()); - } - - assert!(!body.has_remaining()); - - assert_eq!("must connect with TLS", errors.get(&b'M').unwrap()); - - // TODO read failure - }); - - struct TestHandler; - impl Handler for TestHandler { - fn process_query( - &mut self, - _pgb: &mut PostgresBackend, - _query_string: &str, - ) -> Result<(), QueryError> { - panic!() - } - } - let mut handler = TestHandler; - - let cfg = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(vec![CERT.clone()], KEY.clone()) - .unwrap(); - let tls_config = Some(Arc::new(cfg)); - - let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap(); - let res = pgb.run(&mut handler).unwrap_err(); - assert_eq!("client did not connect with TLS", format!("{}", res)); - - client_jh.join().unwrap(); - - // TODO consider shutdown behavior -} diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index d2f0b84863..8d6641a387 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -37,6 +37,7 @@ num-traits.workspace = true once_cell.workspace = true pin-project-lite.workspace = true postgres.workspace = true +postgres_backend.workspace = true postgres-protocol.workspace = true postgres-types.workspace = true rand.workspace = true diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 01a2c85d74..0441760eef 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -23,11 +23,10 @@ use pageserver::{ tenant::mgr, virtual_file, }; +use postgres_backend::AuthType; use utils::{ auth::JwtAuth, - logging, - postgres_backend::AuthType, - project_git_version, + logging, project_git_version, sentry_init::init_sentry, signals::{self, Signal}, tcp_listener, diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 309e5367a4..b3525c2cc6 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -21,10 +21,10 @@ use std::time::Duration; use toml_edit; use toml_edit::{Document, Item}; +use postgres_backend::AuthType; use utils::{ id::{NodeId, TenantId, TimelineId}, logging::LogFormat, - postgres_backend::AuthType, }; use crate::tenant::config::TenantConf; diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index b362e25424..40e11a70b7 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -20,7 +20,8 @@ use pageserver_api::models::{ PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamNblocksRequest, PagestreamNblocksResponse, }; -use pq_proto::ConnectionError; +use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError}; +use pq_proto::framed::ConnectionError; use pq_proto::FeStartupPacket; use pq_proto::{BeMessage, FeMessage, RowDescriptor}; use std::io; @@ -35,8 +36,6 @@ use utils::{ auth::{Claims, JwtAuth, Scope}, id::{TenantId, TimelineId}, lsn::Lsn, - postgres_backend::AuthType, - postgres_backend_async::{self, is_expected_io_error, PostgresBackend, QueryError}, simple_rcu::RcuReadGuard, }; @@ -68,7 +67,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { msg } + msg = pgb.read_message() => { msg.map_err(QueryError::from)} }; match msg { @@ -79,14 +78,16 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream continue, FeMessage::Terminate => { let msg = "client terminated connection with Terminate message during COPY"; - let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); - pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; + let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; break; } m => { let msg = format!("unexpected message {m:?}"); - pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None))?; + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?; Err(io::Error::new(io::ErrorKind::Other, msg))?; break; } @@ -96,16 +97,17 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { let msg = "client closed connection during COPY"; - let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); - pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?; + let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + // error can't happen here, ErrorResponse serialization should be always ok + pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?; pgb.flush().await?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; } - Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => { + Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => { Err(io_error)?; } Err(other) => { - Err(io::Error::new(io::ErrorKind::Other, other))?; + Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?; } }; } @@ -212,7 +214,7 @@ async fn page_service_conn_main( // we've been requested to shut down Ok(()) } - Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => { + Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => { if is_expected_io_error(&io_error) { info!("Postgres client disconnected ({io_error})"); Ok(()) @@ -721,7 +723,7 @@ impl PageServerHandler { } #[async_trait::async_trait] -impl postgres_backend_async::Handler for PageServerHandler { +impl postgres_backend::Handler for PageServerHandler { fn check_auth_jwt( &mut self, _pgb: &mut PostgresBackend, @@ -1055,7 +1057,7 @@ impl From for QueryError { fn from(e: GetActiveTenantError) -> Self { match e { GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected( - ConnectionError::Socket(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), + ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), ), GetActiveTenantError::Other(e) => QueryError::Other(e), } diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 7e06c398af..7194a4f3ed 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -33,10 +33,11 @@ use crate::{ walingest::WalIngest, walrecord::DecodedWALRecord, }; +use postgres_backend::is_expected_io_error; use postgres_connection::PgConnectionConfig; use postgres_ffi::waldecoder::WalStreamDecoder; use pq_proto::ReplicationFeedback; -use utils::{lsn::Lsn, postgres_backend_async::is_expected_io_error}; +use utils::lsn::Lsn; /// Status of the connection. #[derive(Debug, Clone, Copy)] @@ -353,7 +354,7 @@ pub async fn handle_walreceiver_connection( debug!("neon_status_update {status_update:?}"); let mut data = BytesMut::new(); - status_update.serialize(&mut data)?; + status_update.serialize(&mut data); physical_stream .as_mut() .zenith_status_update(data.len() as u64, &data) @@ -434,8 +435,8 @@ fn ignore_expected_errors(pg_error: postgres::Error) -> anyhow::Result> = Lazy::new(Default::default); @@ -33,7 +31,7 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N /// Console management API listener task. /// It spawns console response handlers needed for the link auth. -pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> { +pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> { scopeguard::defer! { info!("mgmt has shut down"); } @@ -42,18 +40,12 @@ pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> let (socket, peer_addr) = listener.accept().await?; info!("accepted connection from {peer_addr}"); - let socket = socket.into_std()?; socket .set_nodelay(true) .context("failed to set client socket option")?; - socket - .set_nonblocking(false) - .context("failed to set client socket option")?; - // TODO: replace with async tasks. - thread::spawn(move || { - let tid = std::thread::current().id(); - let span = info_span!("mgmt", thread = format_args!("{tid:?}")); + tokio::task::spawn(async move { + let span = info_span!("mgmt", peer = %peer_addr); let _enter = span.enter(); info!("started a new console management API thread"); @@ -61,16 +53,16 @@ pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> info!("console management API thread is about to finish"); } - if let Err(e) = handle_connection(socket) { + if let Err(e) = handle_connection(socket).await { error!("thread failed with an error: {e}"); } }); } } -fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { - let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?; - pgbackend.run(&mut MgmtHandler) +async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { + let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?; + pgbackend.run(&mut MgmtHandler, future::pending::<()>).await } /// A message received by `mgmt` when a compute node is ready. @@ -78,16 +70,21 @@ pub type ComputeReady = Result; // TODO: replace with an http-based protocol. struct MgmtHandler; +#[async_trait::async_trait] impl postgres_backend::Handler for MgmtHandler { - fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> { - try_process_query(pgb, query).map_err(|e| { + async fn process_query( + &mut self, + pgb: &mut PostgresBackend, + query: &str, + ) -> Result<(), QueryError> { + try_process_query(pgb, query).await.map_err(|e| { error!("failed to process response: {e:?}"); e }) } } -fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> { +async fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> { let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?; let span = info_span!("event", session_id = resp.session_id); @@ -98,11 +95,11 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), Query Ok(()) => { pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? - .write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; } Err(e) => { error!("failed to deliver response to per-client task"); - pgb.write_message(&BeMessage::ErrorResponse(&e.to_string(), None))?; + pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?; } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 02a0fabe9a..5a802dafb2 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -1,45 +1,40 @@ use crate::error::UserFacingError; use anyhow::bail; -use bytes::BytesMut; use pin_project_lite::pin_project; -use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket}; +use pq_proto::framed::{ConnectionError, Framed}; +use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; use std::pin::Pin; use std::sync::Arc; use std::{io, task}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::server::TlsStream; -pin_project! { - /// Stream wrapper which implements libpq's protocol. - /// NOTE: This object deliberately doesn't implement [`AsyncRead`] - /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying - /// to pass random malformed bytes through the connection). - pub struct PqStream { - #[pin] - stream: S, - buffer: BytesMut, - } +/// Stream wrapper which implements libpq's protocol. +/// NOTE: This object deliberately doesn't implement [`AsyncRead`] +/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying +/// to pass random malformed bytes through the connection). +pub struct PqStream { + framed: Framed, } impl PqStream { /// Construct a new libpq protocol wrapper. pub fn new(stream: S) -> Self { Self { - stream, - buffer: Default::default(), + framed: Framed::new(stream), } } /// Extract the underlying stream. pub fn into_inner(self) -> S { - self.stream + self.framed.into_inner() } /// Get a shared reference to the underlying stream. pub fn get_ref(&self) -> &S { - &self.stream + self.framed.get_ref() } } @@ -50,16 +45,19 @@ fn err_connection() -> io::Error { impl PqStream { /// Receive [`FeStartupPacket`], which is a first packet sent by a client. pub async fn read_startup_packet(&mut self) -> io::Result { - // TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket` - let msg = FeStartupPacket::read_fut(&mut self.stream) + self.framed + .read_startup_message() .await .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection)?; + .ok_or_else(err_connection) + } - match msg { - FeMessage::StartupPacket(packet) => Ok(packet), - _ => panic!("unreachable state"), - } + async fn read_message(&mut self) -> io::Result { + self.framed + .read_message() + .await + .map_err(ConnectionError::into_io_error)? + .ok_or_else(err_connection) } pub async fn read_password_message(&mut self) -> io::Result { @@ -71,19 +69,14 @@ impl PqStream { )), } } - - async fn read_message(&mut self) -> io::Result { - FeMessage::read_fut(&mut self.stream) - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } } impl PqStream { /// Write the message into an internal buffer, but don't flush the underlying stream. pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - BeMessage::write(&mut self.buffer, message)?; + self.framed + .write_message(message) + .map_err(ProtocolError::into_io_error)?; Ok(self) } @@ -96,9 +89,7 @@ impl PqStream { /// Flush the output buffer into the underlying stream. pub async fn flush(&mut self) -> io::Result<&mut Self> { - self.stream.write_all(&self.buffer).await?; - self.buffer.clear(); - self.stream.flush().await?; + self.framed.flush().await?; Ok(self) } diff --git a/run_clippy.sh b/run_clippy.sh index fe0e745d7d..0558541089 100755 --- a/run_clippy.sh +++ b/run_clippy.sh @@ -11,12 +11,18 @@ # Not every feature is supported in macOS builds. Avoid running regular linting # script that checks every feature. +# +# manual-range-contains wants +# !(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len) +# instead of +# len < 4 || len > MAX_STARTUP_PACKET_LENGTH +# , let's disagree. if [[ "$OSTYPE" == "darwin"* ]]; then # no extra features to test currently, add more here when needed - cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -D warnings + cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -A clippy::manual-range-contains -D warnings else # * `-A unknown_lints` – do not warn about unknown lint suppressions # that people with newer toolchains might use # * `-D warnings` - fail on any warnings (`cargo` returns non-zero exit status) - cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -D warnings + cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -A clippy::manual-range-contains -D warnings fi diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 4ee8d82203..88f950d6c8 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -35,6 +35,7 @@ toml_edit.workspace = true tracing.workspace = true url.workspace = true metrics.workspace = true +postgres_backend.workspace = true postgres_ffi.workspace = true pq_proto.workspace = true remote_storage.workspace = true diff --git a/safekeeper/src/bin/safekeeper.rs b/safekeeper/src/bin/safekeeper.rs index 683050e9cd..d2cb9f79b9 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ b/safekeeper/src/bin/safekeeper.rs @@ -236,7 +236,7 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> { let conf_cloned = conf.clone(); let safekeeper_thread = thread::Builder::new() - .name("safekeeper thread".into()) + .name("WAL service thread".into()) .spawn(|| wal_service::thread_main(conf_cloned, pg_listener)) .unwrap(); diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index 60df5dd372..be1c89c97b 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -1,27 +1,23 @@ //! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres //! protocol commands. +use anyhow::Context; +use std::str; +use tracing::{info, info_span, Instrument}; + use crate::auth::check_permission; use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage}; -use crate::receive_wal::ReceiveWalConn; - -use crate::send_wal::ReplicationConn; use crate::{GlobalTimelines, SafeKeeperConf}; -use anyhow::Context; - +use postgres_backend::QueryError; +use postgres_backend::{self, PostgresBackend}; use postgres_ffi::PG_TLI; -use regex::Regex; - use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID}; -use std::str; -use tracing::info; +use regex::Regex; use utils::auth::{Claims, Scope}; -use utils::postgres_backend_async::QueryError; use utils::{ id::{TenantId, TenantTimelineId, TimelineId}, lsn::Lsn, - postgres_backend::{self, PostgresBackend}, }; /// Safekeeper handler of postgres commands @@ -53,7 +49,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result { let start_lsn = caps .next() .map(|cap| cap[1].parse::()) - .context("failed to parse start LSN from START_REPLICATION command")??; + .context("parse start LSN from START_REPLICATION command")??; Ok(SafekeeperPostgresCommand::StartReplication { start_lsn }) } else if cmd.starts_with("IDENTIFY_SYSTEM") { Ok(SafekeeperPostgresCommand::IdentifySystem) @@ -67,6 +63,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result { } } +#[async_trait::async_trait] impl postgres_backend::Handler for SafekeeperPostgresHandler { // tenant_id and timeline_id are passed in connection string params fn startup( @@ -137,7 +134,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { Ok(()) } - fn process_query( + async fn process_query( &mut self, pgb: &mut PostgresBackend, query_string: &str, @@ -147,9 +144,10 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { .starts_with("set datestyle to ") { // important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect - pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; + pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; return Ok(()); } + let cmd = parse_cmd(query_string)?; info!( @@ -161,26 +159,23 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler { let timeline_id = self.timeline_id.context("timelineid is required")?; self.check_permission(Some(tenant_id))?; self.ttid = TenantTimelineId::new(tenant_id, timeline_id); + let span_ttid = self.ttid; // satisfy borrow checker - let res = match cmd { - SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self), + match cmd { + SafekeeperPostgresCommand::StartWalPush => { + self.handle_start_wal_push(pgb) + .instrument(info_span!("WAL receiver", ttid = %span_ttid)) + .await + } SafekeeperPostgresCommand::StartReplication { start_lsn } => { - ReplicationConn::new(pgb).run(self, pgb, start_lsn) + self.handle_start_replication(pgb, start_lsn) + .instrument(info_span!("WAL sender", ttid = %span_ttid)) + .await } - SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb), - SafekeeperPostgresCommand::JSONCtrl { ref cmd } => handle_json_ctrl(self, pgb, cmd), - }; - - match res { - Ok(()) => Ok(()), - Err(QueryError::Disconnected(connection_error)) => { - info!("Timeline {tenant_id}/{timeline_id} query failed with connection error: {connection_error}"); - Err(QueryError::Disconnected(connection_error)) + SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await, + SafekeeperPostgresCommand::JSONCtrl { ref cmd } => { + handle_json_ctrl(self, pgb, cmd).await } - Err(QueryError::Other(e)) => Err(QueryError::Other(e.context(format!( - "Failed to process query for timeline {}", - self.ttid - )))), } } } @@ -217,7 +212,10 @@ impl SafekeeperPostgresHandler { /// /// Handle IDENTIFY_SYSTEM replication command /// - fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<(), QueryError> { + async fn handle_identify_system( + &mut self, + pgb: &mut PostgresBackend, + ) -> Result<(), QueryError> { let tli = GlobalTimelines::get(self.ttid)?; let lsn = if self.is_walproposer_recovery() { @@ -267,7 +265,7 @@ impl SafekeeperPostgresHandler { Some(lsn_bytes), None, ]))? - .write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?; + .write_message_noflush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?; Ok(()) } diff --git a/safekeeper/src/http/routes.rs b/safekeeper/src/http/routes.rs index a917d61678..7e5d7d1b47 100644 --- a/safekeeper/src/http/routes.rs +++ b/safekeeper/src/http/routes.rs @@ -8,6 +8,7 @@ use serde::Serialize; use serde::Serializer; use std::collections::{HashMap, HashSet}; use std::fmt::Display; + use std::sync::Arc; use storage_broker::proto::SafekeeperTimelineInfo; use storage_broker::proto::TenantTimelineId as ProtoTenantTimelineId; @@ -181,12 +182,9 @@ async fn timeline_create_handler(mut request: Request) -> Result anyhow::Result> { +async fn prepare_safekeeper( + ttid: TenantTimelineId, + pg_version: u32, +) -> anyhow::Result> { GlobalTimelines::create( ttid, ServerInfo { @@ -106,6 +110,7 @@ fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> anyhow::Result Lsn::INVALID, Lsn::INVALID, ) + .await } fn send_proposer_elected(tli: &Arc, term: Term, lsn: Lsn) -> anyhow::Result<()> { @@ -128,15 +133,15 @@ fn send_proposer_elected(tli: &Arc, term: Term, lsn: Lsn) -> anyhow::R } #[derive(Debug, Serialize, Deserialize)] -struct InsertedWAL { +pub struct InsertedWAL { begin_lsn: Lsn, - end_lsn: Lsn, + pub end_lsn: Lsn, append_response: AppendResponse, } /// Extend local WAL with new LogicalMessage record. To do that, /// create AppendRequest with new WAL and pass it to safekeeper. -fn append_logical_message( +pub fn append_logical_message( tli: &Arc, msg: &AppendLogicalMessage, ) -> anyhow::Result { diff --git a/safekeeper/src/lib.rs b/safekeeper/src/lib.rs index 891d73533f..818f2f9424 100644 --- a/safekeeper/src/lib.rs +++ b/safekeeper/src/lib.rs @@ -1,8 +1,7 @@ -use storage_broker::Uri; -// use remote_storage::RemoteStorageConfig; use std::path::PathBuf; use std::time::Duration; +use storage_broker::Uri; use utils::id::{NodeId, TenantId, TenantTimelineId}; diff --git a/safekeeper/src/receive_wal.rs b/safekeeper/src/receive_wal.rs index 671e5470a0..22c9871026 100644 --- a/safekeeper/src/receive_wal.rs +++ b/safekeeper/src/receive_wal.rs @@ -2,204 +2,284 @@ //! Gets messages from the network, passes them down to consensus module and //! sends replies back. -use anyhow::anyhow; -use anyhow::Context; - -use bytes::BytesMut; -use tracing::*; -use utils::lsn::Lsn; -use utils::postgres_backend_async::QueryError; - +use crate::handler::SafekeeperPostgresHandler; +use crate::safekeeper::AcceptorProposerMessage; +use crate::safekeeper::ProposerAcceptorMessage; use crate::safekeeper::ServerInfo; use crate::timeline::Timeline; use crate::GlobalTimelines; - +use anyhow::{anyhow, Context}; +use bytes::BytesMut; +use nix::unistd::gettid; +use postgres_backend::CopyStreamHandlerEnd; +use postgres_backend::PostgresBackend; +use postgres_backend::PostgresBackendReader; +use postgres_backend::QueryError; +use pq_proto::BeMessage; use std::net::SocketAddr; -use std::sync::mpsc::channel; -use std::sync::mpsc::Receiver; - use std::sync::Arc; use std::thread; +use std::thread::JoinHandle; +use tokio::sync::mpsc::channel; +use tokio::sync::mpsc::error::TryRecvError; +use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc::Sender; +use tokio::task::spawn_blocking; +use tracing::*; +use utils::id::TenantTimelineId; +use utils::lsn::Lsn; -use crate::safekeeper::AcceptorProposerMessage; -use crate::safekeeper::ProposerAcceptorMessage; +const MSG_QUEUE_SIZE: usize = 256; +const REPLY_QUEUE_SIZE: usize = 16; -use crate::handler::SafekeeperPostgresHandler; -use pq_proto::{BeMessage, FeMessage}; -use utils::{postgres_backend::PostgresBackend, sock_split::ReadStream}; - -pub struct ReceiveWalConn<'pg> { - /// Postgres connection - pg_backend: &'pg mut PostgresBackend, - /// The cached result of `pg_backend.socket().peer_addr()` (roughly) - peer_addr: SocketAddr, -} - -impl<'pg> ReceiveWalConn<'pg> { - pub fn new(pg: &'pg mut PostgresBackend) -> ReceiveWalConn<'pg> { - let peer_addr = *pg.get_peer_addr(); - ReceiveWalConn { - pg_backend: pg, - peer_addr, +impl SafekeeperPostgresHandler { + /// Wrapper around handle_start_wal_push_guts handling result. Error is + /// handled here while we're still in walreceiver ttid span; with API + /// extension, this can probably be moved into postgres_backend. + pub async fn handle_start_wal_push( + &mut self, + pgb: &mut PostgresBackend, + ) -> Result<(), QueryError> { + if let Err(end) = self.handle_start_wal_push_guts(pgb).await { + // Log the result and probably send it to the client, closing the stream. + pgb.handle_copy_stream_end(end).await; } - } - - // Send message to the postgres - fn write_msg(&mut self, msg: &AcceptorProposerMessage) -> anyhow::Result<()> { - let mut buf = BytesMut::with_capacity(128); - msg.serialize(&mut buf)?; - self.pg_backend.write_message(&BeMessage::CopyData(&buf))?; Ok(()) } - /// Receive WAL from wal_proposer - pub fn run(&mut self, spg: &mut SafekeeperPostgresHandler) -> Result<(), QueryError> { - let _enter = info_span!("WAL acceptor", ttid = %spg.ttid).entered(); - + pub async fn handle_start_wal_push_guts( + &mut self, + pgb: &mut PostgresBackend, + ) -> Result<(), CopyStreamHandlerEnd> { // Notify the libpq client that it's allowed to send `CopyData` messages - self.pg_backend - .write_message(&BeMessage::CopyBothResponse)?; + pgb.write_message(&BeMessage::CopyBothResponse).await?; - let r = self - .pg_backend - .take_stream_in() - .ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?; - let mut poll_reader = ProposerPollStream::new(r)?; + // Experiments [1] confirm that doing network IO in one (this) thread and + // processing with disc IO in another significantly improves + // performance; we spawn off WalAcceptor thread for message processing + // to this end. + // + // [1] https://github.com/neondatabase/neon/pull/1318 + let (msg_tx, msg_rx) = channel(MSG_QUEUE_SIZE); + let (reply_tx, reply_rx) = channel(REPLY_QUEUE_SIZE); + let mut acceptor_handle: Option>> = None; - // Receive information about server - let next_msg = poll_reader.recv_msg()?; - let tli = match next_msg { - ProposerAcceptorMessage::Greeting(ref greeting) => { - info!( - "start handshake with walproposer {} sysid {} timeline {}", - self.peer_addr, greeting.system_id, greeting.tli, - ); - let server_info = ServerInfo { - pg_version: greeting.pg_version, - system_id: greeting.system_id, - wal_seg_size: greeting.wal_seg_size, - }; - GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)? - } - _ => { - return Err(QueryError::Other(anyhow::anyhow!( - "unexpected message {next_msg:?} instead of greeting" - ))) - } + // Concurrently receive and send data; replies are not synchronized with + // sends, so this avoids deadlocks. + let mut pgb_reader = pgb.split().context("START_WAL_PUSH split")?; + let peer_addr = *pgb.get_peer_addr(); + let res = tokio::select! { + // todo: add read|write .context to these errors + r = read_network(self.ttid, &mut pgb_reader, peer_addr, msg_tx, &mut acceptor_handle, msg_rx, reply_tx) => r, + r = write_network(pgb, reply_rx) => r, }; - let mut next_msg = Some(next_msg); + // Join pg backend back. + pgb.unsplit(pgb_reader)?; - let mut first_time_through = true; - let mut _guard: Option = None; - loop { - if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) { - // poll AppendRequest's without blocking and write WAL to disk without flushing, - // while it's readily available - while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg { - let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request); - - let reply = tli.process_msg(&msg)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } - - next_msg = poll_reader.poll_msg(); - } - - // flush all written WAL to the disk - let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } - } else if let Some(msg) = next_msg.take() { - // process other message - let reply = tli.process_msg(&msg)?; - if let Some(reply) = reply { - self.write_msg(&reply)?; - } - } - if first_time_through { - // Register the connection and defer unregister. Do that only - // after processing first message, as it sets wal_seg_size, - // wanted by many. - tli.on_compute_connect()?; - _guard = Some(ComputeConnectionGuard { - timeline: Arc::clone(&tli), - }); - first_time_through = false; + // Join the spawned WalAcceptor. At this point chans to/from it passed + // to network routines are dropped, so it will exit as soon as it + // touches them. + match acceptor_handle { + None => { + // failed even before spawning; read_network should have error + Err(res.expect_err("no error with WalAcceptor not spawn")) } + Some(handle) => { + let wal_acceptor_res = handle.join(); - // blocking wait for the next message - if next_msg.is_none() { - next_msg = Some(poll_reader.recv_msg()?); + // If there was any network error, return it. + res?; + + // Otherwise, WalAcceptor thread must have errored. + match wal_acceptor_res { + Ok(Ok(_)) => Ok(()), // can't happen currently; would be if we add graceful termination + Ok(Err(e)) => Err(CopyStreamHandlerEnd::Other(e.context("WAL acceptor"))), + Err(_) => Err(CopyStreamHandlerEnd::Other(anyhow!( + "WalAcceptor thread panicked", + ))), + } } } } } -struct ProposerPollStream { - msg_rx: Receiver, - read_thread: Option>>, +/// Read next message from walproposer. +/// TODO: Return Ok(None) on graceful termination. +async fn read_message( + pgb_reader: &mut PostgresBackendReader, +) -> Result { + let copy_data = pgb_reader.read_copy_message().await?; + let msg = ProposerAcceptorMessage::parse(copy_data)?; + Ok(msg) } -impl ProposerPollStream { - fn new(mut r: ReadStream) -> anyhow::Result { - let (msg_tx, msg_rx) = channel(); - - let read_thread = thread::Builder::new() - .name("Read WAL thread".into()) - .spawn(move || -> Result<(), QueryError> { - loop { - let copy_data = match FeMessage::read(&mut r)? { - Some(FeMessage::CopyData(bytes)) => Ok(bytes), - Some(msg) => Err(QueryError::Other(anyhow::anyhow!( - "expected `CopyData` message, found {msg:?}" - ))), - None => Err(QueryError::from(std::io::Error::new( - std::io::ErrorKind::ConnectionAborted, - "walproposer closed the connection", - ))), - }?; - - let msg = ProposerAcceptorMessage::parse(copy_data)?; - msg_tx - .send(msg) - .context("Failed to send the proposer message")?; - } - // msg_tx will be dropped here, this will also close msg_rx - })?; - - Ok(Self { - msg_rx, - read_thread: Some(read_thread), - }) - } - - fn recv_msg(&mut self) -> Result { - self.msg_rx.recv().map_err(|_| { - // return error from the read thread - let res = match self.read_thread.take() { - Some(thread) => thread.join(), - None => return QueryError::Other(anyhow::anyhow!("read thread is gone")), +/// Read messages from socket and pass it to WalAcceptor thread. Returns Ok(()) +/// if msg_tx closed; it must mean WalAcceptor terminated, joining it should +/// tell the error. +async fn read_network( + ttid: TenantTimelineId, + pgb_reader: &mut PostgresBackendReader, + peer_addr: SocketAddr, + msg_tx: Sender, + // WalAcceptor is spawned when we learn server info from walproposer and + // create timeline; handle is put here. + acceptor_handle: &mut Option>>, + msg_rx: Receiver, + reply_tx: Sender, +) -> Result<(), CopyStreamHandlerEnd> { + // Receive information about server to create timeline, if not yet. + let next_msg = read_message(pgb_reader).await?; + let tli = match next_msg { + ProposerAcceptorMessage::Greeting(ref greeting) => { + info!( + "start handshake with walproposer {} sysid {} timeline {}", + peer_addr, greeting.system_id, greeting.tli, + ); + let server_info = ServerInfo { + pg_version: greeting.pg_version, + system_id: greeting.system_id, + wal_seg_size: greeting.wal_seg_size, }; + GlobalTimelines::create(ttid, server_info, Lsn::INVALID, Lsn::INVALID).await? + } + _ => { + return Err(CopyStreamHandlerEnd::Other(anyhow::anyhow!( + "unexpected message {next_msg:?} instead of greeting" + ))) + } + }; - match res { - Ok(Ok(())) => { - QueryError::Other(anyhow::anyhow!("unexpected result from read thread")) - } - Err(err) => QueryError::Other(anyhow::anyhow!("read thread panicked: {err:?}")), - Ok(Err(err)) => err, + *acceptor_handle = Some( + WalAcceptor::spawn(tli.clone(), msg_rx, reply_tx).context("spawn WalAcceptor thread")?, + ); + + // Forward all messages to WalAcceptor + read_network_loop(pgb_reader, msg_tx, next_msg).await +} + +async fn read_network_loop( + pgb_reader: &mut PostgresBackendReader, + msg_tx: Sender, + mut next_msg: ProposerAcceptorMessage, +) -> Result<(), CopyStreamHandlerEnd> { + loop { + if msg_tx.send(next_msg).await.is_err() { + return Ok(()); // chan closed, WalAcceptor terminated + } + next_msg = read_message(pgb_reader).await?; + } +} + +/// Read replies from WalAcceptor and pass them back to socket. Returns Ok(()) +/// if reply_rx closed; it must mean WalAcceptor terminated, joining it should +/// tell the error. +async fn write_network( + pgb_writer: &mut PostgresBackend, + mut reply_rx: Receiver, +) -> Result<(), CopyStreamHandlerEnd> { + let mut buf = BytesMut::with_capacity(128); + + loop { + match reply_rx.recv().await { + Some(msg) => { + buf.clear(); + msg.serialize(&mut buf)?; + pgb_writer.write_message(&BeMessage::CopyData(&buf)).await?; } - }) + None => return Ok(()), // chan closed, WalAcceptor terminated + } + } +} + +/// Takes messages from msg_rx, processes and pushes replies to reply_tx. +struct WalAcceptor { + tli: Arc, + msg_rx: Receiver, + reply_tx: Sender, +} + +impl WalAcceptor { + /// Spawn thread with WalAcceptor running, return handle to it. + fn spawn( + tli: Arc, + msg_rx: Receiver, + reply_tx: Sender, + ) -> anyhow::Result>> { + let thread_name = format!("WAL acceptor {}", tli.ttid); + thread::Builder::new() + .name(thread_name) + .spawn(move || -> anyhow::Result<()> { + let mut wa = WalAcceptor { + tli, + msg_rx, + reply_tx, + }; + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + let span_ttid = wa.tli.ttid; // satisfy borrow checker + runtime.block_on( + wa.run() + .instrument(info_span!("WAL acceptor", tid = %gettid(), ttid = %span_ttid)), + ) + }) + .map_err(anyhow::Error::from) } - fn poll_msg(&mut self) -> Option { - let res = self.msg_rx.try_recv(); + /// The main loop. Returns Ok(()) if either msg_rx or reply_tx got closed; + /// it must mean that network thread terminated. + async fn run(&mut self) -> anyhow::Result<()> { + // Register the connection and defer unregister. + self.tli.on_compute_connect().await?; + let _guard = ComputeConnectionGuard { + timeline: Arc::clone(&self.tli), + }; - match res { - Err(_) => None, - Ok(msg) => Some(msg), + let mut next_msg: ProposerAcceptorMessage; + + loop { + let opt_msg = self.msg_rx.recv().await; + if opt_msg.is_none() { + return Ok(()); // chan closed, streaming terminated + } + next_msg = opt_msg.unwrap(); + + if matches!(next_msg, ProposerAcceptorMessage::AppendRequest(_)) { + // loop through AppendRequest's while it's readily available to + // write as many WAL as possible without fsyncing + while let ProposerAcceptorMessage::AppendRequest(append_request) = next_msg { + let noflush_msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request); + + if let Some(reply) = self.tli.process_msg(&noflush_msg)? { + if self.reply_tx.send(reply).await.is_err() { + return Ok(()); // chan closed, streaming terminated + } + } + + match self.msg_rx.try_recv() { + Ok(msg) => next_msg = msg, + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => return Ok(()), // chan closed, streaming terminated + } + } + + // flush all written WAL to the disk + if let Some(reply) = self.tli.process_msg(&ProposerAcceptorMessage::FlushWAL)? { + if self.reply_tx.send(reply).await.is_err() { + return Ok(()); // chan closed, streaming terminated + } + } + } else { + // process message other than AppendRequest + if let Some(reply) = self.tli.process_msg(&next_msg)? { + if self.reply_tx.send(reply).await.is_err() { + return Ok(()); // chan closed, streaming terminated + } + } + } } } } @@ -210,8 +290,13 @@ struct ComputeConnectionGuard { impl Drop for ComputeConnectionGuard { fn drop(&mut self) { - if let Err(e) = self.timeline.on_compute_disconnect() { - error!("failed to unregister compute connection: {}", e); - } + let tli = self.timeline.clone(); + // tokio forbids to call blocking_send inside the runtime, and see + // comments in on_compute_disconnect why we call blocking_send. + spawn_blocking(move || { + if let Err(e) = tli.on_compute_disconnect() { + error!("failed to unregister compute connection: {}", e); + } + }); } } diff --git a/safekeeper/src/safekeeper.rs b/safekeeper/src/safekeeper.rs index fa973a3ede..2fcfce69f6 100644 --- a/safekeeper/src/safekeeper.rs +++ b/safekeeper/src/safekeeper.rs @@ -486,7 +486,7 @@ impl AcceptorProposerMessage { buf.put_u64_le(msg.hs_feedback.xmin); buf.put_u64_le(msg.hs_feedback.catalog_xmin); - msg.pageserver_feedback.serialize(buf)? + msg.pageserver_feedback.serialize(buf); } } diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 20600ab694..5c9f763b4a 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -5,24 +5,22 @@ use crate::handler::SafekeeperPostgresHandler; use crate::timeline::{ReplicaState, Timeline}; use crate::wal_storage::WalReader; use crate::GlobalTimelines; -use anyhow::Context; - +use anyhow::Context as AnyhowContext; use bytes::Bytes; +use postgres_backend::PostgresBackend; +use postgres_backend::{CopyStreamHandlerEnd, PostgresBackendReader, QueryError}; use postgres_ffi::get_current_timestamp; use postgres_ffi::{TimestampTz, MAX_SEND_SIZE}; +use pq_proto::{BeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody}; use serde::{Deserialize, Serialize}; use std::cmp::min; -use std::net::Shutdown; +use std::str; use std::sync::Arc; use std::time::Duration; -use std::{io, str, thread}; -use utils::postgres_backend_async::QueryError; - -use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody}; use tokio::sync::watch::Receiver; use tokio::time::timeout; use tracing::*; -use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend::PostgresBackend, sock_split::ReadStream}; +use utils::{bin_ser::BeSer, lsn::Lsn}; // See: https://www.postgresql.org/docs/13/protocol-replication.html const HOT_STANDBY_FEEDBACK_TAG_BYTE: u8 = b'h'; @@ -60,13 +58,6 @@ pub struct StandbyReply { pub reply_requested: bool, } -/// A network connection that's speaking the replication protocol. -pub struct ReplicationConn { - /// This is an `Option` because we will spawn a background thread that will - /// `take` it from us. - stream_in: Option, -} - /// Scope guard to unregister replication connection from timeline struct ReplicationConnGuard { replica: usize, // replica internal ID assigned by timeline @@ -79,230 +70,274 @@ impl Drop for ReplicationConnGuard { } } -impl ReplicationConn { - /// Create a new `ReplicationConn` - pub fn new(pgb: &mut PostgresBackend) -> Self { - Self { - stream_in: pgb.take_stream_in(), +impl SafekeeperPostgresHandler { + /// Wrapper around handle_start_replication_guts handling result. Error is + /// handled here while we're still in walsender ttid span; with API + /// extension, this can probably be moved into postgres_backend. + pub async fn handle_start_replication( + &mut self, + pgb: &mut PostgresBackend, + start_pos: Lsn, + ) -> Result<(), QueryError> { + if let Err(end) = self.handle_start_replication_guts(pgb, start_pos).await { + // Log the result and probably send it to the client, closing the stream. + pgb.handle_copy_stream_end(end).await; } - } - - /// Handle incoming messages from the network. - /// This is spawned into the background by `handle_start_replication`. - fn background_thread( - mut stream_in: ReadStream, - replica_guard: Arc, - ) -> anyhow::Result<()> { - let replica_id = replica_guard.replica; - let timeline = &replica_guard.timeline; - - let mut state = ReplicaState::new(); - // Wait for replica's feedback. - while let Some(msg) = FeMessage::read(&mut stream_in)? { - match &msg { - FeMessage::CopyData(m) => { - // There's three possible data messages that the client is supposed to send here: - // `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`. - - match m.first().cloned() { - Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => { - // Note: deserializing is on m[1..] because we skip the tag byte. - state.hs_feedback = HotStandbyFeedback::des(&m[1..]) - .context("failed to deserialize HotStandbyFeedback")?; - timeline.update_replica_state(replica_id, state); - } - Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => { - let _reply = StandbyReply::des(&m[1..]) - .context("failed to deserialize StandbyReply")?; - // This must be a regular postgres replica, - // because pageserver doesn't send this type of messages to safekeeper. - // Currently this is not implemented, so this message is ignored. - - warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet."); - // timeline.update_replica_state(replica_id, Some(state)); - } - Some(NEON_STATUS_UPDATE_TAG_BYTE) => { - // Note: deserializing is on m[9..] because we skip the tag byte and len bytes. - let buf = Bytes::copy_from_slice(&m[9..]); - let reply = ReplicationFeedback::parse(buf); - - trace!("ReplicationFeedback is {:?}", reply); - // Only pageserver sends ReplicationFeedback, so set the flag. - // This replica is the source of information to resend to compute. - state.pageserver_feedback = Some(reply); - - timeline.update_replica_state(replica_id, state); - } - _ => warn!("unexpected message {:?}", msg), - } - } - FeMessage::Sync => {} - FeMessage::CopyFail => { - // Shutdown the connection, because rust-postgres client cannot be dropped - // when connection is alive. - let _ = stream_in.shutdown(Shutdown::Both); - anyhow::bail!("Copy failed"); - } - _ => { - // We only handle `CopyData`, 'Sync', 'CopyFail' messages. Anything else is ignored. - info!("unexpected message {:?}", msg); - } - } - } - Ok(()) } - /// - /// Handle START_REPLICATION replication command - /// - pub fn run( + pub async fn handle_start_replication_guts( &mut self, - spg: &mut SafekeeperPostgresHandler, pgb: &mut PostgresBackend, - mut start_pos: Lsn, - ) -> Result<(), QueryError> { - let _enter = info_span!("WAL sender", ttid = %spg.ttid).entered(); - - let tli = GlobalTimelines::get(spg.ttid)?; - - // spawn the background thread which receives HotStandbyFeedback messages. - let bg_timeline = Arc::clone(&tli); - let bg_stream_in = self.stream_in.take().unwrap(); - let bg_timeline_id = spg.timeline_id.unwrap(); + start_pos: Lsn, + ) -> Result<(), CopyStreamHandlerEnd> { + let appname = self.appname.clone(); + let tli = GlobalTimelines::get(self.ttid)?; let state = ReplicaState::new(); // This replica_id is used below to check if it's time to stop replication. - let replica_id = bg_timeline.add_replica(state); + let replica_id = tli.add_replica(state); // Use a guard object to remove our entry from the timeline, when the background // thread and us have both finished using it. - let replica_guard = Arc::new(ReplicationConnGuard { + let _guard = Arc::new(ReplicationConnGuard { replica: replica_id, - timeline: bg_timeline, + timeline: tli.clone(), }); - let bg_replica_guard = Arc::clone(&replica_guard); - // TODO: here we got two threads, one for writing WAL and one for receiving - // feedback. If one of them fails, we should shutdown the other one too. - let _ = thread::Builder::new() - .name("HotStandbyFeedback thread".into()) - .spawn(move || { - let _enter = - info_span!("HotStandbyFeedback thread", timeline = %bg_timeline_id).entered(); - if let Err(err) = Self::background_thread(bg_stream_in, bg_replica_guard) { - error!("Replication background thread failed: {}", err); + // Walproposer gets special handling: safekeeper must give proposer all + // local WAL till the end, whether committed or not (walproposer will + // hang otherwise). That's because walproposer runs the consensus and + // synchronizes safekeepers on the most advanced one. + // + // There is a small risk of this WAL getting concurrently garbaged if + // another compute rises which collects majority and starts fixing log + // on this safekeeper itself. That's ok as (old) proposer will never be + // able to commit such WAL. + let stop_pos: Option = if self.is_walproposer_recovery() { + let wal_end = tli.get_flush_lsn(); + Some(wal_end) + } else { + None + }; + let end_pos = stop_pos.unwrap_or(Lsn::INVALID); + + info!( + "starting streaming from {:?} till {:?}", + start_pos, stop_pos + ); + + // switch to copy + pgb.write_message(&BeMessage::CopyBothResponse).await?; + + let (_, persisted_state) = tli.get_state(); + let wal_reader = WalReader::new( + self.conf.workdir.clone(), + self.conf.timeline_dir(&tli.ttid), + &persisted_state, + start_pos, + self.conf.wal_backup_enabled, + )?; + + // Split to concurrently receive and send data; replies are generally + // not synchronized with sends, so this avoids deadlocks. + let reader = pgb.split().context("START_REPLICATION split")?; + + let mut sender = WalSender { + pgb, + tli: tli.clone(), + appname, + start_pos, + end_pos, + stop_pos, + commit_lsn_watch_rx: tli.get_commit_lsn_watch_rx(), + replica_id, + wal_reader, + send_buf: [0; MAX_SEND_SIZE], + }; + let mut reply_reader = ReplyReader { + reader, + tli, + replica_id, + feedback: ReplicaState::new(), + }; + + let res = tokio::select! { + // todo: add read|write .context to these errors + r = sender.run() => r, + r = reply_reader.run() => r, + }; + // Join pg backend back. + pgb.unsplit(reply_reader.reader)?; + + res + } +} + +/// A half driving sending WAL. +struct WalSender<'a> { + pgb: &'a mut PostgresBackend, + tli: Arc, + appname: Option, + // Position since which we are sending next chunk. + start_pos: Lsn, + // WAL up to this position is known to be locally available. + end_pos: Lsn, + // If present, terminate after reaching this position; used by walproposer + // in recovery. + stop_pos: Option, + commit_lsn_watch_rx: Receiver, + replica_id: usize, + wal_reader: WalReader, + // buffer for readling WAL into to send it + send_buf: [u8; MAX_SEND_SIZE], +} + +impl WalSender<'_> { + /// Send WAL until + /// - an error occurs + /// - if we are streaming to walproposer, we've streamed until stop_pos + /// (recovery finished) + /// - receiver is caughtup and there is no computes + /// + /// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ? + /// convenience. + async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> { + loop { + // If we are streaming to walproposer, check it is time to stop. + if let Some(stop_pos) = self.stop_pos { + if self.start_pos >= stop_pos { + // recovery finished + return Err(CopyStreamHandlerEnd::ServerInitiated(format!( + "ending streaming to walproposer at {}, recovery finished", + self.start_pos + ))); } - })?; - - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; - - runtime.block_on(async move { - let (inmem_state, persisted_state) = tli.get_state(); - // add persisted_state.timeline_start_lsn == Lsn(0) check - - // Walproposer gets special handling: safekeeper must give proposer all - // local WAL till the end, whether committed or not (walproposer will - // hang otherwise). That's because walproposer runs the consensus and - // synchronizes safekeepers on the most advanced one. - // - // There is a small risk of this WAL getting concurrently garbaged if - // another compute rises which collects majority and starts fixing log - // on this safekeeper itself. That's ok as (old) proposer will never be - // able to commit such WAL. - let stop_pos: Option = if spg.is_walproposer_recovery() { - let wal_end = tli.get_flush_lsn(); - Some(wal_end) } else { - None - }; + // Wait for the next portion if it is not there yet, or just + // update our end of WAL available for sending value, we + // communicate it to the receiver. + self.wait_wal().await?; + } - info!("Start replication from {:?} till {:?}", start_pos, stop_pos); + // try to send as much as available, capped by MAX_SEND_SIZE + let mut send_size = self + .end_pos + .checked_sub(self.start_pos) + .context("reading wal without waiting for it first")? + .0 as usize; + send_size = min(send_size, self.send_buf.len()); + let send_buf = &mut self.send_buf[..send_size]; + // read wal into buffer + send_size = self.wal_reader.read(send_buf).await?; + let send_buf = &send_buf[..send_size]; - // switch to copy - pgb.write_message(&BeMessage::CopyBothResponse)?; - - let mut end_pos = stop_pos.unwrap_or(inmem_state.commit_lsn); - - let mut wal_reader = WalReader::new( - spg.conf.workdir.clone(), - spg.conf.timeline_dir(&tli.ttid), - &persisted_state, - start_pos, - spg.conf.wal_backup_enabled, - )?; - - // buffer for wal sending, limited by MAX_SEND_SIZE - let mut send_buf = vec![0u8; MAX_SEND_SIZE]; - - // watcher for commit_lsn updates - let mut commit_lsn_watch_rx = tli.get_commit_lsn_watch_rx(); - - loop { - if let Some(stop_pos) = stop_pos { - if start_pos >= stop_pos { - break; /* recovery finished */ - } - end_pos = stop_pos; - } else { - /* Wait until we have some data to stream */ - let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?; - - if let Some(lsn) = lsn { - end_pos = lsn; - } else { - // TODO: also check once in a while whether we are walsender - // to right pageserver. - if tli.should_walsender_stop(replica_id) { - // Shut down, timeline is suspended. - return Err(QueryError::from(io::Error::new( - io::ErrorKind::ConnectionAborted, - format!("end streaming to {:?}", spg.appname), - ))); - } - - // timeout expired: request pageserver status - pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive { - sent_ptr: end_pos.0, - timestamp: get_current_timestamp(), - request_reply: true, - }))?; - continue; - } - } - - let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize; - let send_size = min(send_size, send_buf.len()); - - let send_buf = &mut send_buf[..send_size]; - - // read wal into buffer - let send_size = wal_reader.read(send_buf).await?; - let send_buf = &send_buf[..send_size]; - - // Write some data to the network socket. - pgb.write_message(&BeMessage::XLogData(XLogDataBody { - wal_start: start_pos.0, - wal_end: end_pos.0, + // and send it + self.pgb + .write_message(&BeMessage::XLogData(XLogDataBody { + wal_start: self.start_pos.0, + wal_end: self.end_pos.0, timestamp: get_current_timestamp(), data: send_buf, })) - .context("Failed to send XLogData")?; + .await?; - start_pos += send_size as u64; - trace!("sent WAL up to {}", start_pos); + trace!( + "sent {} bytes of WAL {}-{}", + send_size, + self.start_pos, + self.start_pos + send_size as u64 + ); + self.start_pos += send_size as u64; + } + } + + /// wait until we have WAL to stream, sending keepalives and checking for + /// exit in the meanwhile + async fn wait_wal(&mut self) -> Result<(), CopyStreamHandlerEnd> { + loop { + if let Some(lsn) = wait_for_lsn(&mut self.commit_lsn_watch_rx, self.start_pos).await? { + self.end_pos = lsn; + return Ok(()); } + // Timed out waiting for WAL, check for termination and send KA + if self.tli.should_walsender_stop(self.replica_id) { + // Terminate if there is nothing more to send. + // TODO close the stream properly + return Err(CopyStreamHandlerEnd::ServerInitiated(format!( + "ending streaming to {:?} at {}, receiver is caughtup and there is no computes", + self.appname, self.start_pos, + ))); + } + self.pgb + .write_message(&BeMessage::KeepAlive(WalSndKeepAlive { + sent_ptr: self.end_pos.0, + timestamp: get_current_timestamp(), + request_reply: true, + })) + .await?; + } + } +} - Ok(()) - }) +/// A half driving receiving replies. +struct ReplyReader { + reader: PostgresBackendReader, + tli: Arc, + replica_id: usize, + feedback: ReplicaState, +} + +impl ReplyReader { + async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> { + loop { + let msg = self.reader.read_copy_message().await?; + self.handle_feedback(&msg)? + } + } + + fn handle_feedback(&mut self, msg: &Bytes) -> anyhow::Result<()> { + match msg.first().cloned() { + Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => { + // Note: deserializing is on m[1..] because we skip the tag byte. + self.feedback.hs_feedback = HotStandbyFeedback::des(&msg[1..]) + .context("failed to deserialize HotStandbyFeedback")?; + self.tli + .update_replica_state(self.replica_id, self.feedback); + } + Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => { + let _reply = + StandbyReply::des(&msg[1..]).context("failed to deserialize StandbyReply")?; + // This must be a regular postgres replica, + // because pageserver doesn't send this type of messages to safekeeper. + // Currently we just ignore this, tracking progress for them is not supported. + } + Some(NEON_STATUS_UPDATE_TAG_BYTE) => { + // pageserver sends this. + // Note: deserializing is on m[9..] because we skip the tag byte and len bytes. + let buf = Bytes::copy_from_slice(&msg[9..]); + let reply = ReplicationFeedback::parse(buf); + + trace!("ReplicationFeedback is {:?}", reply); + // Only pageserver sends ReplicationFeedback, so set the flag. + // This replica is the source of information to resend to compute. + self.feedback.pageserver_feedback = Some(reply); + + self.tli + .update_replica_state(self.replica_id, self.feedback); + } + _ => warn!("unexpected message {:?}", msg), + } + Ok(()) } } const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1); -// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn. +/// Wait until we have commit_lsn > lsn or timeout expires. Returns +/// - Ok(Some(commit_lsn)) if needed lsn is successfully observed; +/// - Ok(None) if timeout expired; +/// - Err in case of error (if watch channel is in trouble, shouldn't happen). async fn wait_for_lsn(rx: &mut Receiver, lsn: Lsn) -> anyhow::Result> { let commit_lsn: Lsn = *rx.borrow(); if commit_lsn > lsn { diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index 43c395574f..d5b849019e 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -1,4 +1,4 @@ -//! This module implements Timeline lifecycle management and has all neccessary code +//! This module implements Timeline lifecycle management and has all necessary code //! to glue together SafeKeeper and all other background services. use anyhow::{bail, Result}; @@ -518,7 +518,7 @@ impl Timeline { /// Register compute connection, starting timeline-related activity if it is /// not running yet. - pub fn on_compute_connect(&self) -> Result<()> { + pub async fn on_compute_connect(&self) -> Result<()> { if self.is_cancelled() { bail!(TimelineError::Cancelled(self.ttid)); } @@ -532,7 +532,7 @@ impl Timeline { // Wake up wal backup launcher, if offloading not started yet. if is_wal_backup_action_pending { // Can fail only if channel to a static thread got closed, which is not normal at all. - self.wal_backup_launcher_tx.blocking_send(self.ttid)?; + self.wal_backup_launcher_tx.send(self.ttid).await?; } Ok(()) } @@ -549,6 +549,11 @@ impl Timeline { // Wake up wal backup launcher, if it is time to stop the offloading. if is_wal_backup_action_pending { // Can fail only if channel to a static thread got closed, which is not normal at all. + // + // Note: this is blocking_send because on_compute_disconnect is called in Drop, there is + // no async Drop and we use current thread runtimes. With current thread rt spawning + // task in drop impl is racy, as thread along with runtime might finish before the task. + // This should be switched send.await when/if we go to full async. self.wal_backup_launcher_tx.blocking_send(self.ttid)?; } Ok(()) diff --git a/safekeeper/src/timelines_global_map.rs b/safekeeper/src/timelines_global_map.rs index 66e0145042..fcf5ab6302 100644 --- a/safekeeper/src/timelines_global_map.rs +++ b/safekeeper/src/timelines_global_map.rs @@ -161,7 +161,7 @@ impl GlobalTimelines { /// Create a new timeline with the given id. If the timeline already exists, returns /// an existing timeline. - pub fn create( + pub async fn create( ttid: TenantTimelineId, server_info: ServerInfo, commit_lsn: Lsn, @@ -189,28 +189,20 @@ impl GlobalTimelines { // Take a lock and finish the initialization holding this mutex. No other threads // can interfere with creation after we will insert timeline into the map. - let mut shared_state = timeline.write_shared_state(); + { + let mut shared_state = timeline.write_shared_state(); - // We can get a race condition here in case of concurrent create calls, but only - // in theory. create() will return valid timeline on the next try. - TIMELINES_STATE - .lock() - .unwrap() - .try_insert(timeline.clone())?; + // We can get a race condition here in case of concurrent create calls, but only + // in theory. create() will return valid timeline on the next try. + TIMELINES_STATE + .lock() + .unwrap() + .try_insert(timeline.clone())?; - // Write the new timeline to the disk and start background workers. - // Bootstrap is transactional, so if it fails, the timeline will be deleted, - // and the state on disk should remain unchanged. - match timeline.bootstrap(&mut shared_state) { - Ok(_) => { - // We are done with bootstrap, release the lock, return the timeline. - drop(shared_state); - timeline - .wal_backup_launcher_tx - .blocking_send(timeline.ttid)?; - Ok(timeline) - } - Err(e) => { + // Write the new timeline to the disk and start background workers. + // Bootstrap is transactional, so if it fails, the timeline will be deleted, + // and the state on disk should remain unchanged. + if let Err(e) = timeline.bootstrap(&mut shared_state) { // Note: the most likely reason for bootstrap failure is that the timeline // directory already exists on disk. This happens when timeline is corrupted // and wasn't loaded from disk on startup because of that. We want to preserve @@ -222,9 +214,13 @@ impl GlobalTimelines { // Timeline failed to bootstrap, it cannot be used. Remove it from the map. TIMELINES_STATE.lock().unwrap().timelines.remove(&ttid); - Err(e) + return Err(e); } + // We are done with bootstrap, release the lock, return the timeline. + // {} block forces release before .await } + timeline.wal_backup_launcher_tx.send(timeline.ttid).await?; + Ok(timeline) } /// Get a timeline from the global map. If it's not present, it doesn't exist on disk, @@ -244,7 +240,7 @@ impl GlobalTimelines { } } - /// Returns all timelines. This is used for background timeline proccesses. + /// Returns all timelines. This is used for background timeline processes. pub fn get_all() -> Vec> { let global_lock = TIMELINES_STATE.lock().unwrap(); global_lock diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index fc971ca753..798b9abaf3 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -191,7 +191,7 @@ async fn wal_backup_launcher_main_loop( .map(|c| GenericRemoteStorage::from_config(c).expect("failed to create remote storage")) }); - // Presense in this map means launcher is aware s3 offloading is needed for + // Presence in this map means launcher is aware s3 offloading is needed for // the timeline, but task is started only if it makes sense for to offload // from this safekeeper. let mut tasks: HashMap = HashMap::new(); @@ -467,7 +467,7 @@ async fn backup_object(source_file: &Path, target_file: &RemotePath, size: usize pub async fn read_object( file_path: &RemotePath, offset: u64, -) -> anyhow::Result>> { +) -> anyhow::Result>> { let storage = REMOTE_STORAGE .get() .context("Failed to get remote storage")? diff --git a/safekeeper/src/wal_service.rs b/safekeeper/src/wal_service.rs index 3ca651d060..8d63d604ad 100644 --- a/safekeeper/src/wal_service.rs +++ b/safekeeper/src/wal_service.rs @@ -2,50 +2,65 @@ //! WAL service listens for client connections and //! receive WAL from wal_proposer and send it to WAL receivers //! -use regex::Regex; -use std::net::{TcpListener, TcpStream}; -use std::thread; +use anyhow::{Context, Result}; +use nix::unistd::gettid; +use postgres_backend::QueryError; +use std::{future, thread}; +use tokio::net::TcpStream; use tracing::*; -use utils::postgres_backend_async::QueryError; use crate::handler::SafekeeperPostgresHandler; use crate::SafeKeeperConf; -use utils::postgres_backend::{AuthType, PostgresBackend}; +use postgres_backend::{AuthType, PostgresBackend}; /// Accept incoming TCP connections and spawn them into a background thread. -pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> ! { - loop { - match listener.accept() { - Ok((socket, peer_addr)) => { - debug!("accepted connection from {}", peer_addr); - let conf = conf.clone(); +pub fn thread_main(conf: SafeKeeperConf, pg_listener: std::net::TcpListener) { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .context("create runtime") + // todo catch error in main thread + .expect("failed to create runtime"); - let _ = thread::Builder::new() - .name("WAL service thread".into()) - .spawn(move || { - if let Err(err) = handle_socket(socket, conf) { - error!("connection handler exited: {}", err); - } - }) - .unwrap(); + runtime + .block_on(async move { + // Tokio's from_std won't do this for us, per its comment. + pg_listener.set_nonblocking(true)?; + let listener = tokio::net::TcpListener::from_std(pg_listener)?; + + loop { + match listener.accept().await { + Ok((socket, peer_addr)) => { + debug!("accepted connection from {}", peer_addr); + let conf = conf.clone(); + + let _ = thread::Builder::new() + .name("WAL service thread".into()) + .spawn(move || { + if let Err(err) = handle_socket(socket, conf) { + error!("connection handler exited: {}", err); + } + }) + .unwrap(); + } + Err(e) => error!("Failed to accept connection: {}", e), + } } - Err(e) => error!("Failed to accept connection: {}", e), - } - } -} - -// Get unique thread id (Rust internal), with ThreadId removed for shorter printing -fn get_tid() -> u64 { - let tids = format!("{:?}", thread::current().id()); - let r = Regex::new(r"ThreadId\((\d+)\)").unwrap(); - let caps = r.captures(&tids).unwrap(); - caps.get(1).unwrap().as_str().parse().unwrap() + #[allow(unreachable_code)] // hint compiler the closure return type + Ok::<(), anyhow::Error>(()) + }) + .expect("listener failed") } /// This is run by `thread_main` above, inside a background thread. /// fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> { - let _enter = info_span!("", tid = ?get_tid()).entered(); + let _enter = info_span!("", tid = %gettid()).entered(); + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + let local = tokio::task::LocalSet::new(); socket.set_nodelay(true)?; @@ -54,9 +69,13 @@ fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryErr Some(_) => AuthType::NeonJWT, }; let mut conn_handler = SafekeeperPostgresHandler::new(conf); - let pgbackend = PostgresBackend::new(socket, auth_type, None, false)?; - // libpq replication protocol between safekeeper and replicas/pagers - pgbackend.run(&mut conn_handler)?; + let pgbackend = PostgresBackend::new(socket, auth_type, None)?; + // libpq protocol between safekeeper and walproposer / pageserver + // We don't use shutdown. + local.block_on( + &runtime, + pgbackend.run(&mut conn_handler, future::pending::<()>), + )?; Ok(()) } diff --git a/safekeeper/src/wal_storage.rs b/safekeeper/src/wal_storage.rs index 561104bd27..e83f72a3cf 100644 --- a/safekeeper/src/wal_storage.rs +++ b/safekeeper/src/wal_storage.rs @@ -461,7 +461,7 @@ pub struct WalReader { timeline_dir: PathBuf, wal_seg_size: usize, pos: Lsn, - wal_segment: Option>>, + wal_segment: Option>>, // S3 will be used to read WAL if LSN is not available locally enable_remote_read: bool, @@ -528,7 +528,7 @@ impl WalReader { } /// Open WAL segment at the current position of the reader. - async fn open_segment(&self) -> Result>> { + async fn open_segment(&self) -> Result>> { let xlogoff = self.pos.segment_offset(self.wal_seg_size); let segno = self.pos.segment_number(self.wal_seg_size); let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size); diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index c4b3d057f8..611361eaf1 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2049,8 +2049,10 @@ class NeonPageserver(PgProtocol): ".*Connection aborted: connection error: error communicating with the server: Broken pipe.*", ".*Connection aborted: connection error: error communicating with the server: Transport endpoint is not connected.*", ".*Connection aborted: connection error: error communicating with the server: Connection reset by peer.*", + # FIXME: replication patch for tokio_postgres regards any but CopyDone/CopyData message in CopyBoth stream as unexpected + ".*Connection aborted: connection error: unexpected message from server*", ".*kill_and_wait_impl.*: wait successful.*", - ".*Replication stream finished: db error: ERROR: Socket IO error: end streaming to Some.*", + ".*Replication stream finished: db error:.*ending streaming to Some*", ".*query handler for 'pagestream.*failed: Broken pipe.*", # pageserver notices compute shut down ".*query handler for 'pagestream.*failed: Connection reset by peer.*", # pageserver notices compute shut down # safekeeper connection can fail with this, in the window between timeline creation diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index 9e3b0ec02f..cea2125f4f 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -1106,8 +1106,8 @@ def test_delete_force(neon_env_builder: NeonEnvBuilder, auth_enabled: bool): # FIXME: are these expected? env.pageserver.allowed_errors.extend( [ - ".*Failed to process query for timeline .*: Timeline .* was not found in global map.*", - ".*Failed to process query for timeline .*: Timeline .* was cancelled and cannot be used anymore.*", + ".*Timeline .* was not found in global map.*", + ".*Timeline .* was cancelled and cannot be used anymore.*", ] ) diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index bd21095fff..28e8e4149c 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -14,14 +14,19 @@ publish = false ### BEGIN HAKARI SECTION [dependencies] anyhow = { version = "1", features = ["backtrace"] } +byteorder = { version = "1" } bytes = { version = "1", features = ["serde"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde"] } clap = { version = "4", features = ["derive", "string"] } crossbeam-utils = { version = "0.8" } +digest = { version = "0.10", features = ["mac", "std"] } either = { version = "1" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } futures = { version = "0.3" } +futures-channel = { version = "0.3", features = ["sink"] } +futures-core = { version = "0.3" } futures-executor = { version = "0.3" } +futures-sink = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } hashbrown = { version = "0.12", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] }