From 0fbe657b2f268351dc5daabee09754a578be3948 Mon Sep 17 00:00:00 2001 From: Alexey Kondratov Date: Wed, 13 Apr 2022 00:02:06 +0300 Subject: [PATCH 01/19] Fix remote e2e tests after repository rename (#1434) Also start them after release build instead of debug. It saves 3-5 minutes and we anyway use release mode in Docker images. --- .circleci/config.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index e96964558b..9d26d5d558 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -672,7 +672,7 @@ jobs: --data \ "{ \"state\": \"pending\", - \"context\": \"zenith-remote-ci\", + \"context\": \"neon-cloud-e2e\", \"description\": \"[$REMOTE_REPO] Remote CI job is about to start\" }" - run: @@ -688,7 +688,7 @@ jobs: "{ \"ref\": \"main\", \"inputs\": { - \"ci_job_name\": \"zenith-remote-ci\", + \"ci_job_name\": \"neon-cloud-e2e\", \"commit_hash\": \"$CIRCLE_SHA1\", \"remote_repo\": \"$LOCAL_REPO\" } @@ -828,11 +828,11 @@ workflows: - remote-ci-trigger: # Context passes credentials for gh api context: CI_ACCESS_TOKEN - remote_repo: "zenithdb/console" + remote_repo: "neondatabase/cloud" requires: # XXX: Successful build doesn't mean everything is OK, but # the job to be triggered takes so much time to complete (~22 min) # that it's better not to wait for the commented-out steps - - build-zenith-debug + - build-zenith-release # - pg_regress-tests-release # - other-tests-release From 4af87f3d6097661c99cbf5b400c1af6c44819e43 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Wed, 13 Apr 2022 03:00:32 +0300 Subject: [PATCH 02/19] [proxy] Add SCRAM auth mechanism implementation (#1050) * [proxy] Add SCRAM auth * [proxy] Implement some tests for SCRAM * Refactoring + test fixes * Hide SCRAM mechanism behind `#[cfg(test)]` Currently we only use it in tests, so we hide all relevant module behind `#[cfg(test)]` to prevent "unused item" warnings. --- Cargo.lock | 35 +++- proxy/Cargo.toml | 11 +- proxy/src/auth.rs | 88 +++------- proxy/src/auth/credentials.rs | 70 ++++++++ proxy/src/auth/flow.rs | 102 ++++++++++++ proxy/src/main.rs | 39 +++-- proxy/src/parse.rs | 18 +++ proxy/src/proxy.rs | 229 ++++++++++++++++++++------ proxy/src/sasl.rs | 47 ++++++ proxy/src/sasl/channel_binding.rs | 85 ++++++++++ proxy/src/sasl/messages.rs | 67 ++++++++ proxy/src/sasl/stream.rs | 70 ++++++++ proxy/src/scram.rs | 59 +++++++ proxy/src/scram/exchange.rs | 134 ++++++++++++++++ proxy/src/scram/key.rs | 33 ++++ proxy/src/scram/messages.rs | 232 +++++++++++++++++++++++++++ proxy/src/scram/password.rs | 48 ++++++ proxy/src/scram/secret.rs | 116 ++++++++++++++ proxy/src/scram/signature.rs | 66 ++++++++ zenith_utils/src/postgres_backend.rs | 3 +- zenith_utils/src/pq_proto.rs | 36 ++++- 21 files changed, 1446 insertions(+), 142 deletions(-) create mode 100644 proxy/src/auth/credentials.rs create mode 100644 proxy/src/auth/flow.rs create mode 100644 proxy/src/parse.rs create mode 100644 proxy/src/sasl.rs create mode 100644 proxy/src/sasl/channel_binding.rs create mode 100644 proxy/src/sasl/messages.rs create mode 100644 proxy/src/sasl/stream.rs create mode 100644 proxy/src/scram.rs create mode 100644 proxy/src/scram/exchange.rs create mode 100644 proxy/src/scram/key.rs create mode 100644 proxy/src/scram/messages.rs create mode 100644 proxy/src/scram/password.rs create mode 100644 proxy/src/scram/secret.rs create mode 100644 proxy/src/scram/signature.rs diff --git a/Cargo.lock b/Cargo.lock index 1a9e261281..7df1c4ab7a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1907,12 +1907,15 @@ name = "proxy" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", + "base64 0.13.0", "bytes", "clap 3.0.14", "fail", "futures", "hashbrown", "hex", + "hmac 0.10.1", "hyper", "lazy_static", "md5", @@ -1921,16 +1924,20 @@ dependencies = [ "rand", "rcgen", "reqwest", + "routerify 2.2.0", + "rstest", "rustls 0.19.1", "scopeguard", "serde", "serde_json", + "sha2", "socket2", "thiserror", "tokio", "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "tokio-postgres-rustls", "tokio-rustls 0.22.0", + "tokio-stream", "workspace_hack", "zenith_metrics", "zenith_utils", @@ -2130,6 +2137,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "routerify" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c6bb49594c791cadb5ccfa5f36d41b498d40482595c199d10cd318800280bd9" +dependencies = [ + "http", + "hyper", + "lazy_static", + "percent-encoding", + "regex", +] + [[package]] name = "routerify" version = "3.0.0" @@ -2143,6 +2163,19 @@ dependencies = [ "regex", ] +[[package]] +name = "rstest" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d912f35156a3f99a66ee3e11ac2e0b3f34ac85a07e05263d05a7e2c8810d616f" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "rustc_version", + "syn", +] + [[package]] name = "rusoto_core" version = "0.47.0" @@ -3450,7 +3483,7 @@ dependencies = [ "postgres 0.19.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "rand", - "routerify", + "routerify 3.0.0", "rustls 0.19.1", "rustls-split", "serde", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index dc20695884..56b6dd7e20 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -5,12 +5,14 @@ edition = "2021" [dependencies] anyhow = "1.0" +base64 = "0.13.0" bytes = { version = "1.0.1", features = ['serde'] } clap = "3.0" fail = "0.5.0" futures = "0.3.13" hashbrown = "0.11.2" hex = "0.4.3" +hmac = "0.10.1" hyper = "0.14" lazy_static = "1.4.0" md5 = "0.7.0" @@ -18,20 +20,25 @@ parking_lot = "0.11.2" pin-project-lite = "0.2.7" rand = "0.8.3" reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] } +routerify = "2" rustls = "0.19.1" scopeguard = "1.1.0" serde = "1" serde_json = "1" +sha2 = "0.9.8" socket2 = "0.4.4" -thiserror = "1.0" +thiserror = "1.0.30" tokio = { version = "1.17", features = ["macros"] } tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" } tokio-rustls = "0.22.0" +tokio-stream = "0.1.8" zenith_utils = { path = "../zenith_utils" } zenith_metrics = { path = "../zenith_metrics" } workspace_hack = { version = "0.1", path = "../workspace_hack" } [dev-dependencies] -tokio-postgres-rustls = "0.8.0" +async-trait = "0.1" rcgen = "0.8.14" +rstest = "0.12" +tokio-postgres-rustls = "0.8.0" diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index e8fe65c081..bda14d67a1 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,14 +1,24 @@ +mod credentials; + +#[cfg(test)] +mod flow; + use crate::compute::DatabaseInfo; use crate::config::ProxyConfig; use crate::cplane_api::{self, CPlaneApi}; use crate::error::UserFacingError; use crate::stream::PqStream; use crate::waiters; -use std::collections::HashMap; +use std::io; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; +pub use credentials::ClientCredentials; + +#[cfg(test)] +pub use flow::*; + /// Common authentication error. #[derive(Debug, Error)] pub enum AuthErrorImpl { @@ -16,13 +26,17 @@ pub enum AuthErrorImpl { #[error(transparent)] Console(#[from] cplane_api::AuthError), + #[cfg(test)] + #[error(transparent)] + Sasl(#[from] crate::sasl::Error), + /// For passwords that couldn't be processed by [`parse_password`]. #[error("Malformed password message")] MalformedPassword, /// Errors produced by [`PqStream`]. #[error(transparent)] - Io(#[from] std::io::Error), + Io(#[from] io::Error), } impl AuthErrorImpl { @@ -67,70 +81,6 @@ impl UserFacingError for AuthError { } } -#[derive(Debug, Error)] -pub enum ClientCredsParseError { - #[error("Parameter `{0}` is missing in startup packet")] - MissingKey(&'static str), -} - -impl UserFacingError for ClientCredsParseError {} - -/// Various client credentials which we use for authentication. -#[derive(Debug, PartialEq, Eq)] -pub struct ClientCredentials { - pub user: String, - pub dbname: String, -} - -impl TryFrom> for ClientCredentials { - type Error = ClientCredsParseError; - - fn try_from(mut value: HashMap) -> Result { - let mut get_param = |key| { - value - .remove(key) - .ok_or(ClientCredsParseError::MissingKey(key)) - }; - - let user = get_param("user")?; - let db = get_param("database")?; - - Ok(Self { user, dbname: db }) - } -} - -impl ClientCredentials { - /// Use credentials to authenticate the user. - pub async fn authenticate( - self, - config: &ProxyConfig, - client: &mut PqStream, - ) -> Result { - fail::fail_point!("proxy-authenticate", |_| { - Err(AuthError::auth_failed("failpoint triggered")) - }); - - use crate::config::ClientAuthMethod::*; - use crate::config::RouterConfig::*; - match &config.router_config { - Static { host, port } => handle_static(host.clone(), *port, client, self).await, - Dynamic(Mixed) => { - if self.user.ends_with("@zenith") { - handle_existing_user(config, client, self).await - } else { - handle_new_user(config, client).await - } - } - Dynamic(Password) => handle_existing_user(config, client, self).await, - Dynamic(Link) => handle_new_user(config, client).await, - } - } -} - -fn new_psql_session_id() -> String { - hex::encode(rand::random::<[u8; 8]>()) -} - async fn handle_static( host: String, port: u16, @@ -169,7 +119,7 @@ async fn handle_existing_user( let md5_salt = rand::random(); client - .write_message(&Be::AuthenticationMD5Password(&md5_salt)) + .write_message(&Be::AuthenticationMD5Password(md5_salt)) .await?; // Read client's password hash @@ -213,6 +163,10 @@ async fn handle_new_user( Ok(db_info) } +fn new_psql_session_id() -> String { + hex::encode(rand::random::<[u8; 8]>()) +} + fn parse_password(bytes: &[u8]) -> Option<&str> { std::str::from_utf8(bytes).ok()?.strip_suffix('\0') } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs new file mode 100644 index 0000000000..7c8ba28622 --- /dev/null +++ b/proxy/src/auth/credentials.rs @@ -0,0 +1,70 @@ +//! User credentials used in authentication. + +use super::AuthError; +use crate::compute::DatabaseInfo; +use crate::config::ProxyConfig; +use crate::error::UserFacingError; +use crate::stream::PqStream; +use std::collections::HashMap; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite}; + +#[derive(Debug, Error)] +pub enum ClientCredsParseError { + #[error("Parameter `{0}` is missing in startup packet")] + MissingKey(&'static str), +} + +impl UserFacingError for ClientCredsParseError {} + +/// Various client credentials which we use for authentication. +#[derive(Debug, PartialEq, Eq)] +pub struct ClientCredentials { + pub user: String, + pub dbname: String, +} + +impl TryFrom> for ClientCredentials { + type Error = ClientCredsParseError; + + fn try_from(mut value: HashMap) -> Result { + let mut get_param = |key| { + value + .remove(key) + .ok_or(ClientCredsParseError::MissingKey(key)) + }; + + let user = get_param("user")?; + let db = get_param("database")?; + + Ok(Self { user, dbname: db }) + } +} + +impl ClientCredentials { + /// Use credentials to authenticate the user. + pub async fn authenticate( + self, + config: &ProxyConfig, + client: &mut PqStream, + ) -> Result { + fail::fail_point!("proxy-authenticate", |_| { + Err(AuthError::auth_failed("failpoint triggered")) + }); + + use crate::config::ClientAuthMethod::*; + use crate::config::RouterConfig::*; + match &config.router_config { + Static { host, port } => super::handle_static(host.clone(), *port, client, self).await, + Dynamic(Mixed) => { + if self.user.ends_with("@zenith") { + super::handle_existing_user(config, client, self).await + } else { + super::handle_new_user(config, client).await + } + } + Dynamic(Password) => super::handle_existing_user(config, client, self).await, + Dynamic(Link) => super::handle_new_user(config, client).await, + } + } +} diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs new file mode 100644 index 0000000000..0fafaa2f47 --- /dev/null +++ b/proxy/src/auth/flow.rs @@ -0,0 +1,102 @@ +//! Main authentication flow. + +use super::{AuthError, AuthErrorImpl}; +use crate::stream::PqStream; +use crate::{sasl, scram}; +use std::io; +use tokio::io::{AsyncRead, AsyncWrite}; +use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; + +/// Every authentication selector is supposed to implement this trait. +pub trait AuthMethod { + /// Any authentication selector should provide initial backend message + /// containing auth method name and parameters, e.g. md5 salt. + fn first_message(&self) -> BeMessage<'_>; +} + +/// Initial state of [`AuthFlow`]. +pub struct Begin; + +/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. +pub struct Scram<'a>(pub &'a scram::ServerSecret); + +impl AuthMethod for Scram<'_> { + #[inline(always)] + fn first_message(&self) -> BeMessage<'_> { + Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + } +} + +/// Use password-based auth in [`AuthFlow`]. +pub struct Md5( + /// Salt for client. + pub [u8; 4], +); + +impl AuthMethod for Md5 { + #[inline(always)] + fn first_message(&self) -> BeMessage<'_> { + Be::AuthenticationMD5Password(self.0) + } +} + +/// This wrapper for [`PqStream`] performs client authentication. +#[must_use] +pub struct AuthFlow<'a, Stream, State> { + /// The underlying stream which implements libpq's protocol. + stream: &'a mut PqStream, + /// State might contain ancillary data (see [`AuthFlow::begin`]). + state: State, +} + +/// Initial state of the stream wrapper. +impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { + /// Create a new wrapper for client authentication. + pub fn new(stream: &'a mut PqStream) -> Self { + Self { + stream, + state: Begin, + } + } + + /// Move to the next step by sending auth method's name & params to client. + pub async fn begin(self, method: M) -> io::Result> { + self.stream.write_message(&method.first_message()).await?; + + Ok(AuthFlow { + stream: self.stream, + state: method, + }) + } +} + +/// Stream wrapper for handling simple MD5 password auth. +impl AuthFlow<'_, S, Md5> { + /// Perform user authentication. Raise an error in case authentication failed. + #[allow(unused)] + pub async fn authenticate(self) -> Result<(), AuthError> { + unimplemented!("MD5 auth flow is yet to be implemented"); + } +} + +/// Stream wrapper for handling [SCRAM](crate::scram) auth. +impl AuthFlow<'_, S, Scram<'_>> { + /// Perform user authentication. Raise an error in case authentication failed. + pub async fn authenticate(self) -> Result<(), AuthError> { + // Initial client message contains the chosen auth method's name. + let msg = self.stream.read_password_message().await?; + let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?; + + // Currently, the only supported SASL method is SCRAM. + if !scram::METHODS.contains(&sasl.method) { + return Err(AuthErrorImpl::auth_failed("method not supported").into()); + } + + let secret = self.state.0; + sasl::SaslStream::new(self.stream, sasl.message) + .authenticate(scram::Exchange::new(secret, rand::random, None)) + .await?; + + Ok(()) + } +} diff --git a/proxy/src/main.rs b/proxy/src/main.rs index bd99d0a639..862152bb7b 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -1,19 +1,8 @@ -/// -/// Postgres protocol proxy/router. -/// -/// This service listens psql port and can check auth via external service -/// (control plane API in our case) and can create new databases and accounts -/// in somewhat transparent manner (again via communication with control plane API). -/// -use anyhow::{bail, Context}; -use clap::{App, Arg}; -use config::ProxyConfig; -use futures::FutureExt; -use std::future::Future; -use tokio::{net::TcpListener, task::JoinError}; -use zenith_utils::GIT_VERSION; - -use crate::config::{ClientAuthMethod, RouterConfig}; +//! Postgres protocol proxy/router. +//! +//! This service listens psql port and can check auth via external service +//! (control plane API in our case) and can create new databases and accounts +//! in somewhat transparent manner (again via communication with control plane API). mod auth; mod cancellation; @@ -27,6 +16,24 @@ mod proxy; mod stream; mod waiters; +// Currently SCRAM is only used in tests +#[cfg(test)] +mod parse; +#[cfg(test)] +mod sasl; +#[cfg(test)] +mod scram; + +use anyhow::{bail, Context}; +use clap::{App, Arg}; +use config::ProxyConfig; +use futures::FutureExt; +use std::future::Future; +use tokio::{net::TcpListener, task::JoinError}; +use zenith_utils::GIT_VERSION; + +use crate::config::{ClientAuthMethod, RouterConfig}; + /// Flattens `Result>` into `Result`. async fn flatten_err( f: impl Future, JoinError>>, diff --git a/proxy/src/parse.rs b/proxy/src/parse.rs new file mode 100644 index 0000000000..8a05ff9c82 --- /dev/null +++ b/proxy/src/parse.rs @@ -0,0 +1,18 @@ +//! Small parsing helpers. + +use std::convert::TryInto; +use std::ffi::CStr; + +pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> { + let pos = bytes.iter().position(|&x| x == 0)?; + let (cstr, other) = bytes.split_at(pos + 1); + // SAFETY: we've already checked that there's a terminator + Some((unsafe { CStr::from_bytes_with_nul_unchecked(cstr) }, other)) +} + +pub fn split_at_const(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> { + (bytes.len() >= N).then(|| { + let (head, tail) = bytes.split_at(N); + (head.try_into().unwrap(), tail) + }) +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 81581b5cf1..5b662f4c69 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -119,7 +119,6 @@ async fn handshake( // We can't perform TLS handshake without a config let enc = tls.is_some(); stream.write_message(&Be::EncryptionResponse(enc)).await?; - if let Some(tls) = tls.take() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. @@ -219,32 +218,14 @@ impl Client { #[cfg(test)] mod tests { use super::*; - - use tokio::io::DuplexStream; + use crate::{auth, scram}; + use async_trait::async_trait; + use rstest::rstest; use tokio_postgres::config::SslMode; use tokio_postgres::tls::{MakeTlsConnect, NoTls}; use tokio_postgres_rustls::MakeRustlsConnect; - async fn dummy_proxy( - client: impl AsyncRead + AsyncWrite + Unpin, - tls: Option, - ) -> anyhow::Result<()> { - let cancel_map = CancelMap::default(); - - // TODO: add some infra + tests for credentials - let (mut stream, _creds) = handshake(client, tls, &cancel_map) - .await? - .context("no stream")?; - - stream - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())? - .write_message(&BeMessage::ReadyForQuery) - .await?; - - Ok(()) - } - + /// Generate a set of TLS certificates: CA + server. fn generate_certs( hostname: &str, ) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> { @@ -262,19 +243,115 @@ mod tests { )) } + struct ClientConfig<'a> { + config: rustls::ClientConfig, + hostname: &'a str, + } + + impl ClientConfig<'_> { + fn make_tls_connect( + self, + ) -> anyhow::Result> { + let mut mk = MakeRustlsConnect::new(self.config); + let tls = MakeTlsConnect::::make_tls_connect(&mut mk, self.hostname)?; + Ok(tls) + } + } + + /// Generate TLS certificates and build rustls configs for client and server. + fn generate_tls_config( + hostname: &str, + ) -> anyhow::Result<(ClientConfig<'_>, Arc)> { + let (ca, cert, key) = generate_certs(hostname)?; + + let server_config = { + let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(vec![cert], key)?; + config.into() + }; + + let client_config = { + let mut config = rustls::ClientConfig::new(); + config.root_store.add(&ca)?; + ClientConfig { config, hostname } + }; + + Ok((client_config, server_config)) + } + + #[async_trait] + trait TestAuth: Sized { + async fn authenticate( + self, + _stream: &mut PqStream>, + ) -> anyhow::Result<()> { + Ok(()) + } + } + + struct NoAuth; + impl TestAuth for NoAuth {} + + struct Scram(scram::ServerSecret); + + impl Scram { + fn new(password: &str) -> anyhow::Result { + let salt = rand::random::<[u8; 16]>(); + let secret = scram::ServerSecret::build(password, &salt, 256) + .context("failed to generate scram secret")?; + Ok(Scram(secret)) + } + + fn mock(user: &str) -> Self { + let salt = rand::random::<[u8; 32]>(); + Scram(scram::ServerSecret::mock(user, &salt)) + } + } + + #[async_trait] + impl TestAuth for Scram { + async fn authenticate( + self, + stream: &mut PqStream>, + ) -> anyhow::Result<()> { + auth::AuthFlow::new(stream) + .begin(auth::Scram(&self.0)) + .await? + .authenticate() + .await?; + + Ok(()) + } + } + + /// A dummy proxy impl which performs a handshake and reports auth success. + async fn dummy_proxy( + client: impl AsyncRead + AsyncWrite + Unpin + Send, + tls: Option, + auth: impl TestAuth + Send, + ) -> anyhow::Result<()> { + let cancel_map = CancelMap::default(); + let (mut stream, _creds) = handshake(client, tls, &cancel_map) + .await? + .context("handshake failed")?; + + auth.authenticate(&mut stream).await?; + + stream + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())? + .write_message(&BeMessage::ReadyForQuery) + .await?; + + Ok(()) + } + #[tokio::test] async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let server_config = { - let (_ca, cert, key) = generate_certs("localhost")?; - - let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(vec![cert], key)?; - config - }; - - let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into()))); + let (_, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); let client_err = tokio_postgres::Config::new() .user("john_doe") @@ -301,30 +378,14 @@ mod tests { async fn handshake_tls() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let (ca, cert, key) = generate_certs("localhost")?; - - let server_config = { - let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(vec![cert], key)?; - config - }; - - let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into()))); - - let client_config = { - let mut config = rustls::ClientConfig::new(); - config.root_store.add(&ca)?; - config - }; - - let mut mk = MakeRustlsConnect::new(client_config); - let tls = MakeTlsConnect::::make_tls_connect(&mut mk, "localhost")?; + let (client_config, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); let (_client, _conn) = tokio_postgres::Config::new() .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Require) - .connect_raw(server, tls) + .connect_raw(server, client_config.make_tls_connect()?) .await?; proxy.await? @@ -334,7 +395,7 @@ mod tests { async fn handshake_raw() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let proxy = tokio::spawn(dummy_proxy(client, None)); + let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth)); let (_client, _conn) = tokio_postgres::Config::new() .user("john_doe") @@ -350,7 +411,7 @@ mod tests { async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let proxy = tokio::spawn(dummy_proxy(client, None)); + let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth)); let client_err = tokio_postgres::Config::new() .ssl_mode(SslMode::Disable) @@ -391,4 +452,66 @@ mod tests { Ok(()) } + + #[rstest] + #[case("password_foo")] + #[case("pwd-bar")] + #[case("")] + #[tokio::test] + async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { + let (client, server) = tokio::io::duplex(1024); + + let (client_config, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy( + client, + Some(server_config), + Scram::new(password)?, + )); + + let (_client, _conn) = tokio_postgres::Config::new() + .user("user") + .dbname("db") + .password(password) + .ssl_mode(SslMode::Require) + .connect_raw(server, client_config.make_tls_connect()?) + .await?; + + proxy.await? + } + + #[tokio::test] + async fn scram_auth_mock() -> anyhow::Result<()> { + let (client, server) = tokio::io::duplex(1024); + + let (client_config, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy( + client, + Some(server_config), + Scram::mock("user"), + )); + + use rand::{distributions::Alphanumeric, Rng}; + let password: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(rand::random::() as usize) + .map(char::from) + .collect(); + + let _client_err = tokio_postgres::Config::new() + .user("user") + .dbname("db") + .password(&password) // no password will match the mocked secret + .ssl_mode(SslMode::Require) + .connect_raw(server, client_config.make_tls_connect()?) + .await + .err() // -> Option + .context("client shouldn't be able to connect")?; + + let _server_err = proxy + .await? + .err() // -> Option + .context("server shouldn't accept client")?; + + Ok(()) + } } diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs new file mode 100644 index 0000000000..70a4d9946a --- /dev/null +++ b/proxy/src/sasl.rs @@ -0,0 +1,47 @@ +//! Simple Authentication and Security Layer. +//! +//! RFC: . +//! +//! Reference implementation: +//! * +//! * + +mod channel_binding; +mod messages; +mod stream; + +use std::io; +use thiserror::Error; + +pub use channel_binding::ChannelBinding; +pub use messages::FirstMessage; +pub use stream::SaslStream; + +/// Fine-grained auth errors help in writing tests. +#[derive(Error, Debug)] +pub enum Error { + #[error("Failed to authenticate client: {0}")] + AuthenticationFailed(&'static str), + + #[error("Channel binding failed: {0}")] + ChannelBindingFailed(&'static str), + + #[error("Unsupported channel binding method: {0}")] + ChannelBindingBadMethod(Box), + + #[error("Bad client message")] + BadClientMessage, + + #[error(transparent)] + Io(#[from] io::Error), +} + +/// A convenient result type for SASL exchange. +pub type Result = std::result::Result; + +/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait. +pub trait Mechanism: Sized { + /// Produce a server challenge to be sent to the client. + /// This is how this method is called in PostgreSQL (`libpq/sasl.h`). + fn exchange(self, input: &str) -> Result<(Option, String)>; +} diff --git a/proxy/src/sasl/channel_binding.rs b/proxy/src/sasl/channel_binding.rs new file mode 100644 index 0000000000..776adabe55 --- /dev/null +++ b/proxy/src/sasl/channel_binding.rs @@ -0,0 +1,85 @@ +//! Definition and parser for channel binding flag (a part of the `GS2` header). + +/// Channel binding flag (possibly with params). +#[derive(Debug, PartialEq, Eq)] +pub enum ChannelBinding { + /// Client doesn't support channel binding. + NotSupportedClient, + /// Client thinks server doesn't support channel binding. + NotSupportedServer, + /// Client wants to use this type of channel binding. + Required(T), +} + +impl ChannelBinding { + pub fn and_then(self, f: impl FnOnce(T) -> Result) -> Result, E> { + use ChannelBinding::*; + Ok(match self { + NotSupportedClient => NotSupportedClient, + NotSupportedServer => NotSupportedServer, + Required(x) => Required(f(x)?), + }) + } +} + +impl<'a> ChannelBinding<&'a str> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + use ChannelBinding::*; + Some(match input { + "n" => NotSupportedClient, + "y" => NotSupportedServer, + other => Required(other.strip_prefix("p=")?), + }) + } +} + +impl ChannelBinding { + /// Encode channel binding data as base64 for subsequent checks. + pub fn encode( + &self, + get_cbind_data: impl FnOnce(&T) -> Result, + ) -> Result, E> { + use ChannelBinding::*; + Ok(match self { + NotSupportedClient => { + // base64::encode("n,,") + "biws".into() + } + NotSupportedServer => { + // base64::encode("y,,") + "eSws".into() + } + Required(mode) => { + let msg = format!( + "p={mode},,{data}", + mode = mode, + data = get_cbind_data(mode)? + ); + base64::encode(msg).into() + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn channel_binding_encode() -> anyhow::Result<()> { + use ChannelBinding::*; + + let cases = [ + (NotSupportedClient, base64::encode("n,,")), + (NotSupportedServer, base64::encode("y,,")), + (Required("foo"), base64::encode("p=foo,,bar")), + ]; + + for (cb, input) in cases { + assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input); + } + + Ok(()) + } +} diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs new file mode 100644 index 0000000000..b1ae8cc426 --- /dev/null +++ b/proxy/src/sasl/messages.rs @@ -0,0 +1,67 @@ +//! Definitions for SASL messages. + +use crate::parse::{split_at_const, split_cstr}; +use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage}; + +/// SASL-specific payload of [`PasswordMessage`](zenith_utils::pq_proto::FeMessage::PasswordMessage). +#[derive(Debug)] +pub struct FirstMessage<'a> { + /// Authentication method, e.g. `"SCRAM-SHA-256"`. + pub method: &'a str, + /// Initial client message. + pub message: &'a str, +} + +impl<'a> FirstMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(bytes: &'a [u8]) -> Option { + let (method_cstr, tail) = split_cstr(bytes)?; + let method = method_cstr.to_str().ok()?; + + let (len_bytes, bytes) = split_at_const(tail)?; + let len = u32::from_be_bytes(*len_bytes) as usize; + if len != bytes.len() { + return None; + } + + let message = std::str::from_utf8(bytes).ok()?; + Some(Self { method, message }) + } +} + +/// A single SASL message. +/// This struct is deliberately decoupled from lower-level +/// [`BeAuthenticationSaslMessage`](zenith_utils::pq_proto::BeAuthenticationSaslMessage). +#[derive(Debug)] +pub(super) enum ServerMessage { + /// We expect to see more steps. + Continue(T), + /// This is the final step. + Final(T), +} + +impl<'a> ServerMessage<&'a str> { + pub(super) fn to_reply(&self) -> BeMessage<'a> { + use BeAuthenticationSaslMessage::*; + BeMessage::AuthenticationSasl(match self { + ServerMessage::Continue(s) => Continue(s.as_bytes()), + ServerMessage::Final(s) => Final(s.as_bytes()), + }) + } +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_sasl_first_message() { + let proto = "SCRAM-SHA-256"; + let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4"; + let sasl_len = (sasl.len() as u32).to_be_bytes(); + let bytes = [proto.as_bytes(), &[0], sasl_len.as_ref(), sasl.as_bytes()].concat(); + + let password = FirstMessage::parse(&bytes).unwrap(); + assert_eq!(password.method, proto); + assert_eq!(password.message, sasl); + } +} diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs new file mode 100644 index 0000000000..03649b8d11 --- /dev/null +++ b/proxy/src/sasl/stream.rs @@ -0,0 +1,70 @@ +//! Abstraction for the string-oriented SASL protocols. + +use super::{messages::ServerMessage, Mechanism}; +use crate::stream::PqStream; +use std::io; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// Abstracts away all peculiarities of the libpq's protocol. +pub struct SaslStream<'a, S> { + /// The underlying stream. + stream: &'a mut PqStream, + /// Current password message we received from client. + current: bytes::Bytes, + /// First SASL message produced by client. + first: Option<&'a str>, +} + +impl<'a, S> SaslStream<'a, S> { + pub fn new(stream: &'a mut PqStream, first: &'a str) -> Self { + Self { + stream, + current: bytes::Bytes::new(), + first: Some(first), + } + } +} + +impl SaslStream<'_, S> { + // Receive a new SASL message from the client. + async fn recv(&mut self) -> io::Result<&str> { + if let Some(first) = self.first.take() { + return Ok(first); + } + + self.current = self.stream.read_password_message().await?; + let s = std::str::from_utf8(&self.current) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + + Ok(s) + } +} + +impl SaslStream<'_, S> { + // Send a SASL message to the client. + async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { + self.stream.write_message(&msg.to_reply()).await?; + Ok(()) + } +} + +impl SaslStream<'_, S> { + /// Perform SASL message exchange according to the underlying algorithm + /// until user is either authenticated or denied access. + pub async fn authenticate(mut self, mut mechanism: impl Mechanism) -> super::Result<()> { + loop { + let input = self.recv().await?; + let (moved, reply) = mechanism.exchange(input)?; + match moved { + Some(moved) => { + self.send(&ServerMessage::Continue(&reply)).await?; + mechanism = moved; + } + None => { + self.send(&ServerMessage::Final(&reply)).await?; + return Ok(()); + } + } + } + } +} diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs new file mode 100644 index 0000000000..f007f3e0b6 --- /dev/null +++ b/proxy/src/scram.rs @@ -0,0 +1,59 @@ +//! Salted Challenge Response Authentication Mechanism. +//! +//! RFC: . +//! +//! Reference implementation: +//! * +//! * + +mod exchange; +mod key; +mod messages; +mod password; +mod secret; +mod signature; + +pub use secret::*; + +pub use exchange::Exchange; +pub use secret::ServerSecret; + +use hmac::{Hmac, Mac, NewMac}; +use sha2::{Digest, Sha256}; + +// TODO: add SCRAM-SHA-256-PLUS +/// A list of supported SCRAM methods. +pub const METHODS: &[&str] = &["SCRAM-SHA-256"]; + +/// Decode base64 into array without any heap allocations +fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N]> { + let mut bytes = [0u8; N]; + + let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?; + if size != N { + return None; + } + + Some(bytes) +} + +/// This function essentially is `Hmac(sha256, key, input)`. +/// Further reading: . +fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator) -> [u8; 32] { + let mut mac = Hmac::::new_varkey(key).expect("bad key size"); + parts.into_iter().for_each(|s| mac.update(s)); + + // TODO: maybe newer `hmac` et al already migrated to regular arrays? + let mut result = [0u8; 32]; + result.copy_from_slice(mac.finalize().into_bytes().as_slice()); + result +} + +fn sha256<'a>(parts: impl IntoIterator) -> [u8; 32] { + let mut hasher = Sha256::new(); + parts.into_iter().for_each(|s| hasher.update(s)); + + let mut result = [0u8; 32]; + result.copy_from_slice(hasher.finalize().as_slice()); + result +} diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs new file mode 100644 index 0000000000..5a986b965a --- /dev/null +++ b/proxy/src/scram/exchange.rs @@ -0,0 +1,134 @@ +//! Implementation of the SCRAM authentication algorithm. + +use super::messages::{ + ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN, +}; +use super::secret::ServerSecret; +use super::signature::SignatureBuilder; +use crate::sasl::{self, ChannelBinding, Error as SaslError}; + +/// The only channel binding mode we currently support. +#[derive(Debug)] +struct TlsServerEndPoint; + +impl std::fmt::Display for TlsServerEndPoint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "tls-server-end-point") + } +} + +impl std::str::FromStr for TlsServerEndPoint { + type Err = sasl::Error; + + fn from_str(s: &str) -> Result { + match s { + "tls-server-end-point" => Ok(TlsServerEndPoint), + _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())), + } + } +} + +#[derive(Debug)] +enum ExchangeState { + /// Waiting for [`ClientFirstMessage`]. + Initial, + /// Waiting for [`ClientFinalMessage`]. + SaltSent { + cbind_flag: ChannelBinding, + client_first_message_bare: String, + server_first_message: OwnedServerFirstMessage, + }, +} + +/// Server's side of SCRAM auth algorithm. +#[derive(Debug)] +pub struct Exchange<'a> { + state: ExchangeState, + secret: &'a ServerSecret, + nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], + cert_digest: Option<&'a [u8]>, +} + +impl<'a> Exchange<'a> { + pub fn new( + secret: &'a ServerSecret, + nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], + cert_digest: Option<&'a [u8]>, + ) -> Self { + Self { + state: ExchangeState::Initial, + secret, + nonce, + cert_digest, + } + } +} + +impl sasl::Mechanism for Exchange<'_> { + fn exchange(mut self, input: &str) -> sasl::Result<(Option, String)> { + use ExchangeState::*; + match &self.state { + Initial => { + let client_first_message = + ClientFirstMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + + let server_first_message = client_first_message.build_server_first_message( + &(self.nonce)(), + &self.secret.salt_base64, + self.secret.iterations, + ); + let msg = server_first_message.as_str().to_owned(); + + self.state = SaltSent { + cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?, + client_first_message_bare: client_first_message.bare.to_owned(), + server_first_message, + }; + + Ok((Some(self), msg)) + } + SaltSent { + cbind_flag, + client_first_message_bare, + server_first_message, + } => { + let client_final_message = + ClientFinalMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + + let channel_binding = cbind_flag.encode(|_| { + self.cert_digest + .map(base64::encode) + .ok_or(SaslError::ChannelBindingFailed("no cert digest provided")) + })?; + + // This might've been caused by a MITM attack + if client_final_message.channel_binding != channel_binding { + return Err(SaslError::ChannelBindingFailed("data mismatch")); + } + + if client_final_message.nonce != server_first_message.nonce() { + return Err(SaslError::AuthenticationFailed("bad nonce")); + } + + let signature_builder = SignatureBuilder { + client_first_message_bare, + server_first_message: server_first_message.as_str(), + client_final_message_without_proof: client_final_message.without_proof, + }; + + let client_key = signature_builder + .build(&self.secret.stored_key) + .derive_client_key(&client_final_message.proof); + + if client_key.sha256() != self.secret.stored_key { + return Err(SaslError::AuthenticationFailed("keys don't match")); + } + + let msg = client_final_message + .build_server_final_message(signature_builder, &self.secret.server_key); + + Ok((None, msg)) + } + } + } +} diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs new file mode 100644 index 0000000000..1c13471bc3 --- /dev/null +++ b/proxy/src/scram/key.rs @@ -0,0 +1,33 @@ +//! Tools for client/server/stored key management. + +/// Faithfully taken from PostgreSQL. +pub const SCRAM_KEY_LEN: usize = 32; + +/// One of the keys derived from the [password](super::password::SaltedPassword). +/// We use the same structure for all keys, i.e. +/// `ClientKey`, `StoredKey`, and `ServerKey`. +#[derive(Default, Debug, PartialEq, Eq)] +#[repr(transparent)] +pub struct ScramKey { + bytes: [u8; SCRAM_KEY_LEN], +} + +impl ScramKey { + pub fn sha256(&self) -> Self { + super::sha256([self.as_ref()]).into() + } +} + +impl From<[u8; SCRAM_KEY_LEN]> for ScramKey { + #[inline(always)] + fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self { + Self { bytes } + } +} + +impl AsRef<[u8]> for ScramKey { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + &self.bytes + } +} diff --git a/proxy/src/scram/messages.rs b/proxy/src/scram/messages.rs new file mode 100644 index 0000000000..f6e6133adf --- /dev/null +++ b/proxy/src/scram/messages.rs @@ -0,0 +1,232 @@ +//! Definitions for SCRAM messages. + +use super::base64_decode_array; +use super::key::{ScramKey, SCRAM_KEY_LEN}; +use super::signature::SignatureBuilder; +use crate::sasl::ChannelBinding; +use std::fmt; +use std::ops::Range; + +/// Faithfully taken from PostgreSQL. +pub const SCRAM_RAW_NONCE_LEN: usize = 18; + +/// Although we ignore all extensions, we still have to validate the message. +fn validate_sasl_extensions<'a>(parts: impl Iterator) -> Option<()> { + for mut chars in parts.map(|s| s.chars()) { + let attr = chars.next()?; + if !('a'..'z').contains(&attr) && !('A'..'Z').contains(&attr) { + return None; + } + let eq = chars.next()?; + if eq != '=' { + return None; + } + } + + Some(()) +} + +#[derive(Debug)] +pub struct ClientFirstMessage<'a> { + /// `client-first-message-bare`. + pub bare: &'a str, + /// Channel binding mode. + pub cbind_flag: ChannelBinding<&'a str>, + /// (Client username)[]. + pub username: &'a str, + /// Client nonce. + pub nonce: &'a str, +} + +impl<'a> ClientFirstMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + let mut parts = input.split(','); + + let cbind_flag = ChannelBinding::parse(parts.next()?)?; + + // PG doesn't support authorization identity, + // so we don't bother defining GS2 header type + let authzid = parts.next()?; + if !authzid.is_empty() { + return None; + } + + // Unfortunately, `parts.as_str()` is unstable + let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1; + let (_, bare) = input.split_at(pos); + + // In theory, these might be preceded by "reserved-mext" (i.e. "m=") + let username = parts.next()?.strip_prefix("n=")?; + let nonce = parts.next()?.strip_prefix("r=")?; + + // Validate but ignore auth extensions + validate_sasl_extensions(parts)?; + + Some(Self { + bare, + cbind_flag, + username, + nonce, + }) + } + + /// Build a response to [`ClientFirstMessage`]. + pub fn build_server_first_message( + &self, + nonce: &[u8; SCRAM_RAW_NONCE_LEN], + salt_base64: &str, + iterations: u32, + ) -> OwnedServerFirstMessage { + use std::fmt::Write; + + let mut message = String::new(); + write!(&mut message, "r={}", self.nonce).unwrap(); + base64::encode_config_buf(nonce, base64::STANDARD, &mut message); + let combined_nonce = 2..message.len(); + write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap(); + + // This design guarantees that it's impossible to create a + // server-first-message without receiving a client-first-message + OwnedServerFirstMessage { + message, + nonce: combined_nonce, + } + } +} + +#[derive(Debug)] +pub struct ClientFinalMessage<'a> { + /// `client-final-message-without-proof`. + pub without_proof: &'a str, + /// Channel binding data (base64). + pub channel_binding: &'a str, + /// Combined client & server nonce. + pub nonce: &'a str, + /// Client auth proof. + pub proof: [u8; SCRAM_KEY_LEN], +} + +impl<'a> ClientFinalMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + let (without_proof, proof) = input.rsplit_once(',')?; + + let mut parts = without_proof.split(','); + let channel_binding = parts.next()?.strip_prefix("c=")?; + let nonce = parts.next()?.strip_prefix("r=")?; + + // Validate but ignore auth extensions + validate_sasl_extensions(parts)?; + + let proof = base64_decode_array(proof.strip_prefix("p=")?)?; + + Some(Self { + without_proof, + channel_binding, + nonce, + proof, + }) + } + + /// Build a response to [`ClientFinalMessage`]. + pub fn build_server_final_message( + &self, + signature_builder: SignatureBuilder, + server_key: &ScramKey, + ) -> String { + let mut buf = String::from("v="); + base64::encode_config_buf( + signature_builder.build(server_key), + base64::STANDARD, + &mut buf, + ); + + buf + } +} + +/// We need to keep a convenient representation of this +/// message for the next authentication step. +pub struct OwnedServerFirstMessage { + /// Owned `server-first-message`. + message: String, + /// Slice into `message`. + nonce: Range, +} + +impl OwnedServerFirstMessage { + /// Extract combined nonce from the message. + #[inline(always)] + pub fn nonce(&self) -> &str { + &self.message[self.nonce.clone()] + } + + /// Get reference to a text representation of the message. + #[inline(always)] + pub fn as_str(&self) -> &str { + &self.message + } +} + +impl fmt::Debug for OwnedServerFirstMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServerFirstMessage") + .field("message", &self.as_str()) + .field("nonce", &self.nonce()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_client_first_message() { + use ChannelBinding::*; + + // (Almost) real strings captured during debug sessions + let cases = [ + (NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"), + (NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"), + ( + Required("tls-server-end-point"), + "p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju", + ), + ]; + + for (cb, input) in cases { + let msg = ClientFirstMessage::parse(input).unwrap(); + + assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju"); + assert_eq!(msg.username, "pepe"); + assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju"); + assert_eq!(msg.cbind_flag, cb); + } + } + + #[test] + fn parse_client_final_message() { + let input = [ + "c=eSws", + "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU", + "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=", + ] + .join(","); + + let msg = ClientFinalMessage::parse(&input).unwrap(); + assert_eq!( + msg.without_proof, + "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU" + ); + assert_eq!( + msg.nonce, + "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU" + ); + assert_eq!( + base64::encode(msg.proof), + "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=" + ); + } +} diff --git a/proxy/src/scram/password.rs b/proxy/src/scram/password.rs new file mode 100644 index 0000000000..656780d853 --- /dev/null +++ b/proxy/src/scram/password.rs @@ -0,0 +1,48 @@ +//! Password hashing routines. + +use super::key::ScramKey; + +pub const SALTED_PASSWORD_LEN: usize = 32; + +/// Salted hashed password is essential for [key](super::key) derivation. +#[repr(transparent)] +pub struct SaltedPassword { + bytes: [u8; SALTED_PASSWORD_LEN], +} + +impl SaltedPassword { + /// See `scram-common.c : scram_SaltedPassword` for details. + /// Further reading: (see `PBKDF2`). + pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword { + let one = 1_u32.to_be_bytes(); // magic + + let mut current = super::hmac_sha256(password, [salt, &one]); + let mut result = current; + for _ in 1..iterations { + current = super::hmac_sha256(password, [current.as_ref()]); + // TODO: result = current.zip(result).map(|(x, y)| x ^ y), issue #80094 + for (i, x) in current.iter().enumerate() { + result[i] ^= x; + } + } + + result.into() + } + + /// Derive `ClientKey` from a salted hashed password. + pub fn client_key(&self) -> ScramKey { + super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into() + } + + /// Derive `ServerKey` from a salted hashed password. + pub fn server_key(&self) -> ScramKey { + super::hmac_sha256(&self.bytes, [b"Server Key".as_ref()]).into() + } +} + +impl From<[u8; SALTED_PASSWORD_LEN]> for SaltedPassword { + #[inline(always)] + fn from(bytes: [u8; SALTED_PASSWORD_LEN]) -> Self { + Self { bytes } + } +} diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs new file mode 100644 index 0000000000..e8d180bcdd --- /dev/null +++ b/proxy/src/scram/secret.rs @@ -0,0 +1,116 @@ +//! Tools for SCRAM server secret management. + +use super::base64_decode_array; +use super::key::ScramKey; + +/// Server secret is produced from [password](super::password::SaltedPassword) +/// and is used throughout the authentication process. +#[derive(Debug)] +pub struct ServerSecret { + /// Number of iterations for `PBKDF2` function. + pub iterations: u32, + /// Salt used to hash user's password. + pub salt_base64: String, + /// Hashed `ClientKey`. + pub stored_key: ScramKey, + /// Used by client to verify server's signature. + pub server_key: ScramKey, +} + +impl ServerSecret { + pub fn parse(input: &str) -> Option { + // SCRAM-SHA-256$:$: + let s = input.strip_prefix("SCRAM-SHA-256$")?; + let (params, keys) = s.split_once('$')?; + + let ((iterations, salt), (stored_key, server_key)) = + params.split_once(':').zip(keys.split_once(':'))?; + + let secret = ServerSecret { + iterations: iterations.parse().ok()?, + salt_base64: salt.to_owned(), + stored_key: base64_decode_array(stored_key)?.into(), + server_key: base64_decode_array(server_key)?.into(), + }; + + Some(secret) + } + + /// To avoid revealing information to an attacker, we use a + /// mocked server secret even if the user doesn't exist. + /// See `auth-scram.c : mock_scram_secret` for details. + pub fn mock(user: &str, nonce: &[u8; 32]) -> Self { + // Refer to `auth-scram.c : scram_mock_salt`. + let mocked_salt = super::sha256([user.as_bytes(), nonce]); + + Self { + iterations: 4096, + salt_base64: base64::encode(&mocked_salt), + stored_key: ScramKey::default(), + server_key: ScramKey::default(), + } + } + + /// Build a new server secret from the prerequisites. + /// XXX: We only use this function in tests. + #[cfg(test)] + pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option { + // TODO: implement proper password normalization required by the RFC + if !password.is_ascii() { + return None; + } + + let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations); + + Some(Self { + iterations, + salt_base64: base64::encode(&salt), + stored_key: password.client_key().sha256(), + server_key: password.server_key(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_scram_secret() { + let iterations = 4096; + let salt = "+/tQQax7twvwTj64mjBsxQ=="; + let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns="; + let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI="; + + let secret = format!( + "SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}", + iterations = iterations, + salt = salt, + stored_key = stored_key, + server_key = server_key, + ); + + let parsed = ServerSecret::parse(&secret).unwrap(); + assert_eq!(parsed.iterations, iterations); + assert_eq!(parsed.salt_base64, salt); + + assert_eq!(base64::encode(parsed.stored_key), stored_key); + assert_eq!(base64::encode(parsed.server_key), server_key); + } + + #[test] + fn build_scram_secret() { + let salt = b"salt"; + let secret = ServerSecret::build("password", salt, 4096).unwrap(); + assert_eq!(secret.iterations, 4096); + assert_eq!(secret.salt_base64, base64::encode(salt)); + assert_eq!( + base64::encode(secret.stored_key.as_ref()), + "lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ=" + ); + assert_eq!( + base64::encode(secret.server_key.as_ref()), + "ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw=" + ); + } +} diff --git a/proxy/src/scram/signature.rs b/proxy/src/scram/signature.rs new file mode 100644 index 0000000000..1c2811d757 --- /dev/null +++ b/proxy/src/scram/signature.rs @@ -0,0 +1,66 @@ +//! Tools for client/server signature management. + +use super::key::{ScramKey, SCRAM_KEY_LEN}; + +/// A collection of message parts needed to derive the client's signature. +#[derive(Debug)] +pub struct SignatureBuilder<'a> { + pub client_first_message_bare: &'a str, + pub server_first_message: &'a str, + pub client_final_message_without_proof: &'a str, +} + +impl SignatureBuilder<'_> { + pub fn build(&self, key: &ScramKey) -> Signature { + let parts = [ + self.client_first_message_bare.as_bytes(), + b",", + self.server_first_message.as_bytes(), + b",", + self.client_final_message_without_proof.as_bytes(), + ]; + + super::hmac_sha256(key.as_ref(), parts).into() + } +} + +/// A computed value which, when xored with `ClientProof`, +/// produces `ClientKey` that we need for authentication. +#[derive(Debug)] +#[repr(transparent)] +pub struct Signature { + bytes: [u8; SCRAM_KEY_LEN], +} + +impl Signature { + /// Derive `ClientKey` from client's signature and proof. + pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey { + // This is how the proof is calculated: + // + // 1. sha256(ClientKey) -> StoredKey + // 2. hmac_sha256(StoredKey, [messages...]) -> ClientSignature + // 3. ClientKey ^ ClientSignature -> ClientProof + // + // Step 3 implies that we can restore ClientKey from the proof + // by xoring the latter with the ClientSignature. Afterwards we + // can check that the presumed ClientKey meets our expectations. + let mut signature = self.bytes; + for (i, x) in proof.iter().enumerate() { + signature[i] ^= x; + } + + signature.into() + } +} + +impl From<[u8; SCRAM_KEY_LEN]> for Signature { + fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self { + Self { bytes } + } +} + +impl AsRef<[u8]> for Signature { + fn as_ref(&self) -> &[u8] { + &self.bytes + } +} diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index 83792f2aca..f984fb4417 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -375,9 +375,8 @@ impl PostgresBackend { } AuthType::MD5 => { rand::thread_rng().fill(&mut self.md5_salt); - let md5_salt = self.md5_salt; self.write_message(&BeMessage::AuthenticationMD5Password( - &md5_salt, + self.md5_salt, ))?; self.state = ProtoState::Authentication; } diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index cb69418c07..403e176b14 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -401,7 +401,8 @@ fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result { #[derive(Debug)] pub enum BeMessage<'a> { AuthenticationOk, - AuthenticationMD5Password(&'a [u8; 4]), + AuthenticationMD5Password([u8; 4]), + AuthenticationSasl(BeAuthenticationSaslMessage<'a>), AuthenticationCleartextPassword, BackendKeyData(CancelKeyData), BindComplete, @@ -429,6 +430,13 @@ pub enum BeMessage<'a> { KeepAlive(WalSndKeepAlive), } +#[derive(Debug)] +pub enum BeAuthenticationSaslMessage<'a> { + Methods(&'a [&'a str]), + Continue(&'a [u8]), + Final(&'a [u8]), +} + #[derive(Debug)] pub enum BeParameterStatusMessage<'a> { Encoding(&'a str), @@ -611,6 +619,32 @@ impl<'a> BeMessage<'a> { .unwrap(); // write into BytesMut can't fail } + BeMessage::AuthenticationSasl(msg) => { + buf.put_u8(b'R'); + write_body(buf, |buf| { + use BeAuthenticationSaslMessage::*; + match msg { + Methods(methods) => { + buf.put_i32(10); // Specifies that SASL auth method is used. + for method in methods.iter() { + write_cstr(method.as_bytes(), buf)?; + } + buf.put_u8(0); // zero terminator for the list + } + Continue(extra) => { + buf.put_i32(11); // Continue SASL auth. + buf.put_slice(extra); + } + Final(extra) => { + buf.put_i32(12); // Send final SASL message. + buf.put_slice(extra); + } + } + Ok::<_, io::Error>(()) + }) + .unwrap() + } + BeMessage::BackendKeyData(key_data) => { buf.put_u8(b'K'); write_body(buf, |buf| { From 9b7a8e67a4ccd0957afd46d857d81374126fb255 Mon Sep 17 00:00:00 2001 From: Dmitry Rodionov Date: Tue, 12 Apr 2022 23:57:33 +0300 Subject: [PATCH 03/19] fix deadlock in upload_timeline_checkpoint It originated from the fact that we were calling to fetch_full_index without releasing the read guard, and fetch_full_index tries to acquire read again. For plain mutex it is already a deeadlock, for RW lock deadlock was achieved by an attempt to acquire write access later in the code while still having active read guard up in the stack This is sort of a bandaid because Kirill plans to change this code during removal of an archiving mechanism --- .../src/remote_storage/storage_sync/upload.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/pageserver/src/remote_storage/storage_sync/upload.rs b/pageserver/src/remote_storage/storage_sync/upload.rs index f955e04474..7b6d58a661 100644 --- a/pageserver/src/remote_storage/storage_sync/upload.rs +++ b/pageserver/src/remote_storage/storage_sync/upload.rs @@ -1,6 +1,6 @@ //! Timeline synchronization logic to compress and upload to the remote storage all new timeline files from the checkpoints. -use std::{borrow::Cow, collections::BTreeSet, path::PathBuf, sync::Arc}; +use std::{collections::BTreeSet, path::PathBuf, sync::Arc}; use tracing::{debug, error, warn}; @@ -46,13 +46,21 @@ pub(super) async fn upload_timeline_checkpoint< let index_read = index.read().await; let remote_timeline = match index_read.timeline_entry(&sync_id) { - None => None, + None => { + drop(index_read); + None + } Some(entry) => match entry.inner() { - TimelineIndexEntryInner::Full(remote_timeline) => Some(Cow::Borrowed(remote_timeline)), + TimelineIndexEntryInner::Full(remote_timeline) => { + let r = Some(remote_timeline.clone()); + drop(index_read); + r + } TimelineIndexEntryInner::Description(_) => { + drop(index_read); debug!("Found timeline description for the given ids, downloading the full index"); match fetch_full_index(remote_assets.as_ref(), &timeline_dir, sync_id).await { - Ok(remote_timeline) => Some(Cow::Owned(remote_timeline)), + Ok(remote_timeline) => Some(remote_timeline), Err(e) => { error!("Failed to download full timeline index: {:?}", e); sync_queue::push(SyncTask::new( @@ -82,7 +90,6 @@ pub(super) async fn upload_timeline_checkpoint< let already_uploaded_files = remote_timeline .map(|timeline| timeline.stored_files(&timeline_dir)) .unwrap_or_default(); - drop(index_read); match try_upload_checkpoint( config, From 20414c4b16143e1757816c1cd015c01c5343b28d Mon Sep 17 00:00:00 2001 From: Dmitry Rodionov Date: Wed, 13 Apr 2022 00:20:55 +0300 Subject: [PATCH 04/19] defuse possible deadlock in download_timeline too --- .../src/remote_storage/storage_sync/download.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pageserver/src/remote_storage/storage_sync/download.rs b/pageserver/src/remote_storage/storage_sync/download.rs index 773b4a12e5..e5aa74452b 100644 --- a/pageserver/src/remote_storage/storage_sync/download.rs +++ b/pageserver/src/remote_storage/storage_sync/download.rs @@ -1,6 +1,6 @@ //! Timeline synchrnonization logic to put files from archives on remote storage into pageserver's local directory. -use std::{borrow::Cow, collections::BTreeSet, path::PathBuf, sync::Arc}; +use std::{collections::BTreeSet, path::PathBuf, sync::Arc}; use anyhow::{ensure, Context}; use tokio::fs; @@ -64,11 +64,16 @@ pub(super) async fn download_timeline< let remote_timeline = match index_read.timeline_entry(&sync_id) { None => { error!("Cannot download: no timeline is present in the index for given id"); + drop(index_read); return DownloadedTimeline::Abort; } Some(index_entry) => match index_entry.inner() { - TimelineIndexEntryInner::Full(remote_timeline) => Cow::Borrowed(remote_timeline), + TimelineIndexEntryInner::Full(remote_timeline) => { + let cloned = remote_timeline.clone(); + drop(index_read); + cloned + } TimelineIndexEntryInner::Description(_) => { // we do not check here for awaits_download because it is ok // to call this function while the download is in progress @@ -84,7 +89,7 @@ pub(super) async fn download_timeline< ) .await { - Ok(remote_timeline) => Cow::Owned(remote_timeline), + Ok(remote_timeline) => remote_timeline, Err(e) => { error!("Failed to download full timeline index: {:?}", e); From 87020f81265b14db527177b075e78752becb24cc Mon Sep 17 00:00:00 2001 From: Arthur Petukhovsky Date: Wed, 13 Apr 2022 10:59:29 +0300 Subject: [PATCH 05/19] Fix CI staging deploy (#1499) - Remove stopped safekeeper from inventory - Fix github pages address after neon rename --- .circleci/ansible/staging.hosts | 1 - .circleci/config.yml | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.circleci/ansible/staging.hosts b/.circleci/ansible/staging.hosts index f6b7bf009f..69f058c2b9 100644 --- a/.circleci/ansible/staging.hosts +++ b/.circleci/ansible/staging.hosts @@ -5,7 +5,6 @@ zenith-us-stage-ps-2 console_region_id=27 [safekeepers] zenith-us-stage-sk-1 console_region_id=27 zenith-us-stage-sk-2 console_region_id=27 -zenith-us-stage-sk-3 console_region_id=27 zenith-us-stage-sk-4 console_region_id=27 [storage:children] diff --git a/.circleci/config.yml b/.circleci/config.yml index 9d26d5d558..f05e64072a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -405,7 +405,7 @@ jobs: - run: name: Build coverage report command: | - COMMIT_URL=https://github.com/zenithdb/zenith/commit/$CIRCLE_SHA1 + COMMIT_URL=https://github.com/neondatabase/neon/commit/$CIRCLE_SHA1 scripts/coverage \ --dir=/tmp/zenith/coverage report \ @@ -416,8 +416,8 @@ jobs: name: Upload coverage report command: | LOCAL_REPO=$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME - REPORT_URL=https://zenithdb.github.io/zenith-coverage-data/$CIRCLE_SHA1 - COMMIT_URL=https://github.com/zenithdb/zenith/commit/$CIRCLE_SHA1 + REPORT_URL=https://neondatabase.github.io/zenith-coverage-data/$CIRCLE_SHA1 + COMMIT_URL=https://github.com/neondatabase/neon/commit/$CIRCLE_SHA1 scripts/git-upload \ --repo=https://$VIP_VAP_ACCESS_TOKEN@github.com/zenithdb/zenith-coverage-data.git \ @@ -593,7 +593,7 @@ jobs: name: Setup helm v3 command: | curl -s https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash - helm repo add zenithdb https://zenithdb.github.io/helm-charts + helm repo add zenithdb https://neondatabase.github.io/helm-charts - run: name: Re-deploy proxy command: | @@ -643,7 +643,7 @@ jobs: name: Setup helm v3 command: | curl -s https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash - helm repo add zenithdb https://zenithdb.github.io/helm-charts + helm repo add zenithdb https://neondatabase.github.io/helm-charts - run: name: Re-deploy proxy command: | From 58d5136a615f2c42e26ad78c16eb5fff965335df Mon Sep 17 00:00:00 2001 From: Daniil Date: Wed, 13 Apr 2022 17:16:25 +0300 Subject: [PATCH 06/19] compute_tools: check writability handler (#941) --- Cargo.lock | 1 + compute_tools/Cargo.toml | 1 + compute_tools/src/bin/zenith_ctl.rs | 2 ++ compute_tools/src/checker.rs | 46 +++++++++++++++++++++++++++++ compute_tools/src/http_api.rs | 13 ++++++-- compute_tools/src/lib.rs | 1 + 6 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 compute_tools/src/checker.rs diff --git a/Cargo.lock b/Cargo.lock index 7df1c4ab7a..0584b9d6d2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -346,6 +346,7 @@ dependencies = [ "serde_json", "tar", "tokio", + "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "workspace_hack", ] diff --git a/compute_tools/Cargo.toml b/compute_tools/Cargo.toml index 56047093f1..fc52ce4e83 100644 --- a/compute_tools/Cargo.toml +++ b/compute_tools/Cargo.toml @@ -17,4 +17,5 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1" tar = "0.4" tokio = { version = "1.17", features = ["macros", "rt", "rt-multi-thread"] } +tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" } workspace_hack = { version = "0.1", path = "../workspace_hack" } diff --git a/compute_tools/src/bin/zenith_ctl.rs b/compute_tools/src/bin/zenith_ctl.rs index 49ba653fa1..372afbc633 100644 --- a/compute_tools/src/bin/zenith_ctl.rs +++ b/compute_tools/src/bin/zenith_ctl.rs @@ -38,6 +38,7 @@ use clap::Arg; use log::info; use postgres::{Client, NoTls}; +use compute_tools::checker::create_writablity_check_data; use compute_tools::config; use compute_tools::http_api::launch_http_server; use compute_tools::logger::*; @@ -128,6 +129,7 @@ fn run_compute(state: &Arc>) -> Result { handle_roles(&read_state.spec, &mut client)?; handle_databases(&read_state.spec, &mut client)?; + create_writablity_check_data(&mut client)?; // 'Close' connection drop(client); diff --git a/compute_tools/src/checker.rs b/compute_tools/src/checker.rs new file mode 100644 index 0000000000..63da6ea23e --- /dev/null +++ b/compute_tools/src/checker.rs @@ -0,0 +1,46 @@ +use std::sync::{Arc, RwLock}; + +use anyhow::{anyhow, Result}; +use log::error; +use postgres::Client; +use tokio_postgres::NoTls; + +use crate::zenith::ComputeState; + +pub fn create_writablity_check_data(client: &mut Client) -> Result<()> { + let query = " + CREATE TABLE IF NOT EXISTS health_check ( + id serial primary key, + updated_at timestamptz default now() + ); + INSERT INTO health_check VALUES (1, now()) + ON CONFLICT (id) DO UPDATE + SET updated_at = now();"; + let result = client.simple_query(query)?; + if result.len() < 2 { + return Err(anyhow::format_err!("executed {} queries", result.len())); + } + Ok(()) +} + +pub async fn check_writability(state: &Arc>) -> Result<()> { + let connstr = state.read().unwrap().connstr.clone(); + let (client, connection) = tokio_postgres::connect(&connstr, NoTls).await?; + if client.is_closed() { + return Err(anyhow!("connection to postgres closed")); + } + tokio::spawn(async move { + if let Err(e) = connection.await { + error!("connection error: {}", e); + } + }); + + let result = client + .simple_query("UPDATE health_check SET updated_at = now() WHERE id = 1;") + .await?; + + if result.len() != 1 { + return Err(anyhow!("statement can't be executed")); + } + Ok(()) +} diff --git a/compute_tools/src/http_api.rs b/compute_tools/src/http_api.rs index 02fab08a6e..7e1a876044 100644 --- a/compute_tools/src/http_api.rs +++ b/compute_tools/src/http_api.rs @@ -11,7 +11,7 @@ use log::{error, info}; use crate::zenith::*; // Service function to handle all available routes. -fn routes(req: Request, state: Arc>) -> Response { +async fn routes(req: Request, state: Arc>) -> Response { match (req.method(), req.uri().path()) { // Timestamp of the last Postgres activity in the plain text. (&Method::GET, "/last_activity") => { @@ -29,6 +29,15 @@ fn routes(req: Request, state: Arc>) -> Response { + info!("serving /check_writability GET request"); + let res = crate::checker::check_writability(&state).await; + match res { + Ok(_) => Response::new(Body::from("true")), + Err(e) => Response::new(Body::from(e.to_string())), + } + } + // Return the `404 Not Found` for any other routes. _ => { let mut not_found = Response::new(Body::from("404 Not Found")); @@ -48,7 +57,7 @@ async fn serve(state: Arc>) { async move { Ok::<_, Infallible>(service_fn(move |req: Request| { let state = state.clone(); - async move { Ok::<_, Infallible>(routes(req, state)) } + async move { Ok::<_, Infallible>(routes(req, state).await) } })) } }); diff --git a/compute_tools/src/lib.rs b/compute_tools/src/lib.rs index 592011d95e..ffb9700a49 100644 --- a/compute_tools/src/lib.rs +++ b/compute_tools/src/lib.rs @@ -2,6 +2,7 @@ //! Various tools and helpers to handle cluster / compute node (Postgres) //! configuration. //! +pub mod checker; pub mod config; pub mod http_api; #[macro_use] From 1fd08107cab279c8fd0a0a042a5a04ec58a4fe0d Mon Sep 17 00:00:00 2001 From: Dhammika Pathirana Date: Mon, 11 Apr 2022 13:59:26 -0700 Subject: [PATCH 07/19] Add ps compaction_threshold config Signed-off-by: Dhammika Pathirana Add ps compaction_threadhold knob for (#707) (#1484) --- pageserver/src/config.rs | 22 +++++++++++++++++++++- pageserver/src/layered_repository.rs | 8 +++----- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 0d5cac8b4f..067073cd9b 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -36,8 +36,8 @@ pub mod defaults { // Target file size, when creating image and delta layers. // This parameter determines L1 layer file size. pub const DEFAULT_COMPACTION_TARGET_SIZE: u64 = 128 * 1024 * 1024; - pub const DEFAULT_COMPACTION_PERIOD: &str = "1 s"; + pub const DEFAULT_COMPACTION_THRESHOLD: usize = 10; pub const DEFAULT_GC_HORIZON: u64 = 64 * 1024 * 1024; pub const DEFAULT_GC_PERIOD: &str = "100 s"; @@ -65,6 +65,7 @@ pub mod defaults { #checkpoint_distance = {DEFAULT_CHECKPOINT_DISTANCE} # in bytes #compaction_target_size = {DEFAULT_COMPACTION_TARGET_SIZE} # in bytes #compaction_period = '{DEFAULT_COMPACTION_PERIOD}' +#compaction_threshold = '{DEFAULT_COMPACTION_THRESHOLD}' #gc_period = '{DEFAULT_GC_PERIOD}' #gc_horizon = {DEFAULT_GC_HORIZON} @@ -107,6 +108,9 @@ pub struct PageServerConf { // How often to check if there's compaction work to be done. pub compaction_period: Duration, + // Level0 delta layer threshold for compaction. + pub compaction_threshold: usize, + pub gc_horizon: u64, pub gc_period: Duration, @@ -162,6 +166,7 @@ struct PageServerConfigBuilder { compaction_target_size: BuilderValue, compaction_period: BuilderValue, + compaction_threshold: BuilderValue, gc_horizon: BuilderValue, gc_period: BuilderValue, @@ -198,6 +203,7 @@ impl Default for PageServerConfigBuilder { compaction_target_size: Set(DEFAULT_COMPACTION_TARGET_SIZE), compaction_period: Set(humantime::parse_duration(DEFAULT_COMPACTION_PERIOD) .expect("cannot parse default compaction period")), + compaction_threshold: Set(DEFAULT_COMPACTION_THRESHOLD), gc_horizon: Set(DEFAULT_GC_HORIZON), gc_period: Set(humantime::parse_duration(DEFAULT_GC_PERIOD) .expect("cannot parse default gc period")), @@ -241,6 +247,10 @@ impl PageServerConfigBuilder { self.compaction_period = BuilderValue::Set(compaction_period) } + pub fn compaction_threshold(&mut self, compaction_threshold: usize) { + self.compaction_threshold = BuilderValue::Set(compaction_threshold) + } + pub fn gc_horizon(&mut self, gc_horizon: u64) { self.gc_horizon = BuilderValue::Set(gc_horizon) } @@ -313,6 +323,9 @@ impl PageServerConfigBuilder { compaction_period: self .compaction_period .ok_or(anyhow::anyhow!("missing compaction_period"))?, + compaction_threshold: self + .compaction_threshold + .ok_or(anyhow::anyhow!("missing compaction_threshold"))?, gc_horizon: self .gc_horizon .ok_or(anyhow::anyhow!("missing gc_horizon"))?, @@ -453,6 +466,9 @@ impl PageServerConf { builder.compaction_target_size(parse_toml_u64(key, item)?) } "compaction_period" => builder.compaction_period(parse_toml_duration(key, item)?), + "compaction_threshold" => { + builder.compaction_threshold(parse_toml_u64(key, item)? as usize) + } "gc_horizon" => builder.gc_horizon(parse_toml_u64(key, item)?), "gc_period" => builder.gc_period(parse_toml_duration(key, item)?), "wait_lsn_timeout" => builder.wait_lsn_timeout(parse_toml_duration(key, item)?), @@ -590,6 +606,7 @@ impl PageServerConf { checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE, compaction_target_size: 4 * 1024 * 1024, compaction_period: Duration::from_secs(10), + compaction_threshold: defaults::DEFAULT_COMPACTION_THRESHOLD, gc_horizon: defaults::DEFAULT_GC_HORIZON, gc_period: Duration::from_secs(10), wait_lsn_timeout: Duration::from_secs(60), @@ -662,6 +679,7 @@ checkpoint_distance = 111 # in bytes compaction_target_size = 111 # in bytes compaction_period = '111 s' +compaction_threshold = 2 gc_period = '222 s' gc_horizon = 222 @@ -700,6 +718,7 @@ id = 10 checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE, compaction_target_size: defaults::DEFAULT_COMPACTION_TARGET_SIZE, compaction_period: humantime::parse_duration(defaults::DEFAULT_COMPACTION_PERIOD)?, + compaction_threshold: defaults::DEFAULT_COMPACTION_THRESHOLD, gc_horizon: defaults::DEFAULT_GC_HORIZON, gc_period: humantime::parse_duration(defaults::DEFAULT_GC_PERIOD)?, wait_lsn_timeout: humantime::parse_duration(defaults::DEFAULT_WAIT_LSN_TIMEOUT)?, @@ -745,6 +764,7 @@ id = 10 checkpoint_distance: 111, compaction_target_size: 111, compaction_period: Duration::from_secs(111), + compaction_threshold: 2, gc_horizon: 222, gc_period: Duration::from_secs(222), wait_lsn_timeout: Duration::from_secs(111), diff --git a/pageserver/src/layered_repository.rs b/pageserver/src/layered_repository.rs index 5e93e3389b..e178ba5222 100644 --- a/pageserver/src/layered_repository.rs +++ b/pageserver/src/layered_repository.rs @@ -1680,13 +1680,11 @@ impl LayeredTimeline { fn compact_level0(&self, target_file_size: u64) -> Result<()> { let layers = self.layers.lock().unwrap(); - // We compact or "shuffle" the level-0 delta layers when 10 have - // accumulated. - static COMPACT_THRESHOLD: usize = 10; - let level0_deltas = layers.get_level0_deltas()?; - if level0_deltas.len() < COMPACT_THRESHOLD { + // We compact or "shuffle" the level-0 delta layers when they've + // accumulated over the compaction threshold. + if level0_deltas.len() < self.conf.compaction_threshold { return Ok(()); } drop(layers); From 49da76237bd073f3f5857d6476e7a2827115cadb Mon Sep 17 00:00:00 2001 From: Dmitry Rodionov Date: Wed, 13 Apr 2022 18:56:27 +0300 Subject: [PATCH 08/19] remove noisy debug log message --- pageserver/src/layered_repository/block_io.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/pageserver/src/layered_repository/block_io.rs b/pageserver/src/layered_repository/block_io.rs index 2eba0aa403..d027b2f0e7 100644 --- a/pageserver/src/layered_repository/block_io.rs +++ b/pageserver/src/layered_repository/block_io.rs @@ -198,7 +198,6 @@ impl BlockWriter for BlockBuf { assert!(buf.len() == PAGE_SZ); let blknum = self.blocks.len(); self.blocks.push(buf); - tracing::info!("buffered block {}", blknum); Ok(blknum as u32) } } From 1d36c5a39e97006daa63b3cb2af0dee3cf1ee3e4 Mon Sep 17 00:00:00 2001 From: Dmitry Rodionov Date: Wed, 13 Apr 2022 19:19:44 +0300 Subject: [PATCH 09/19] reenable s3 on staging pagservers by default After deadlockk fix in https://github.com/neondatabase/neon/pull/1496 s3 seems to work normally. There is one more discovered issue but it is not a blocker so can be fixed separately. --- .circleci/ansible/deploy.yaml | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/.circleci/ansible/deploy.yaml b/.circleci/ansible/deploy.yaml index 2112102aa7..508843812a 100644 --- a/.circleci/ansible/deploy.yaml +++ b/.circleci/ansible/deploy.yaml @@ -63,21 +63,18 @@ tags: - pageserver - # It seems that currently S3 integration does not play well - # even with fresh pageserver without a burden of old data. - # TODO: turn this back on once the issue is solved. - # - name: update remote storage (s3) config - # lineinfile: - # path: /storage/pageserver/data/pageserver.toml - # line: "{{ item }}" - # loop: - # - "[remote_storage]" - # - "bucket_name = '{{ bucket_name }}'" - # - "bucket_region = '{{ bucket_region }}'" - # - "prefix_in_bucket = '{{ inventory_hostname }}'" - # become: true - # tags: - # - pageserver + - name: update remote storage (s3) config + lineinfile: + path: /storage/pageserver/data/pageserver.toml + line: "{{ item }}" + loop: + - "[remote_storage]" + - "bucket_name = '{{ bucket_name }}'" + - "bucket_region = '{{ bucket_region }}'" + - "prefix_in_bucket = '{{ inventory_hostname }}'" + become: true + tags: + - pageserver - name: upload systemd service definition ansible.builtin.template: From a0781f229c5574ab4fdae6b63175b7da8846921d Mon Sep 17 00:00:00 2001 From: Dhammika Pathirana Date: Wed, 13 Apr 2022 14:08:42 -0700 Subject: [PATCH 10/19] Add ps compact command Signed-off-by: Dhammika Pathirana Add ps compact command to api (#707) (#1484) --- pageserver/src/page_service.rs | 20 ++++++++++++++++++++ pageserver/src/repository.rs | 6 ++++-- test_runner/fixtures/compare_fixtures.py | 3 +++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index e7a4117b3e..c09b032e48 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -713,6 +713,26 @@ impl postgres_backend::Handler for PageServerHandler { Some(result.elapsed.as_millis().to_string().as_bytes()), ]))? .write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else if query_string.starts_with("compact ") { + // Run compaction immediately on given timeline. + // FIXME This is just for tests. Don't expect this to be exposed to + // the users or the api. + + // compact + let re = Regex::new(r"^compact ([[:xdigit:]]+)\s([[:xdigit:]]+)($|\s)?").unwrap(); + + let caps = re + .captures(query_string) + .with_context(|| format!("Invalid compact: '{}'", query_string))?; + + let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?; + let timelineid = ZTimelineId::from_str(caps.get(2).unwrap().as_str())?; + let timeline = tenant_mgr::get_timeline_for_tenant_load(tenantid, timelineid) + .context("Couldn't load timeline")?; + timeline.tline.compact()?; + + pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; } else if query_string.starts_with("checkpoint ") { // Run checkpoint immediately on given timeline. diff --git a/pageserver/src/repository.rs b/pageserver/src/repository.rs index 02334d3229..eda9a3168d 100644 --- a/pageserver/src/repository.rs +++ b/pageserver/src/repository.rs @@ -252,8 +252,10 @@ pub trait Repository: Send + Sync { checkpoint_before_gc: bool, ) -> Result; - /// perform one compaction iteration. - /// this function is periodically called by compactor thread. + /// Perform one compaction iteration. + /// This function is periodically called by compactor thread. + /// Also it can be explicitly requested per timeline through page server + /// api's 'compact' command. fn compaction_iteration(&self) -> Result<()>; /// detaches locally available timeline by stopping all threads and removing all the data. diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index 750b02c894..598ee10f8e 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -87,6 +87,9 @@ class ZenithCompare(PgCompare): def flush(self): self.pscur.execute(f"do_gc {self.env.initial_tenant.hex} {self.timeline} 0") + def compact(self): + self.pscur.execute(f"compact {self.env.initial_tenant.hex} {self.timeline}") + def report_peak_memory_use(self) -> None: self.zenbenchmark.record("peak_mem", self.zenbenchmark.get_peak_mem(self.env.pageserver) / 1024, From cdf04b6a9fb2d5d225d12a2a74fae6c6eec26da6 Mon Sep 17 00:00:00 2001 From: Arthur Petukhovsky Date: Thu, 14 Apr 2022 09:31:35 +0300 Subject: [PATCH 11/19] Fix control file updates in safekeeper (#1452) Now control_file::Storage implements Deref for read-only access to the state. All updates should clone the state before modifying and persisting. --- walkeeper/src/control_file.rs | 57 ++++++++++++--- walkeeper/src/safekeeper.rs | 126 ++++++++++++++++++++-------------- walkeeper/src/timeline.rs | 16 ++--- 3 files changed, 127 insertions(+), 72 deletions(-) diff --git a/walkeeper/src/control_file.rs b/walkeeper/src/control_file.rs index 8b4e618661..7cc53edeb0 100644 --- a/walkeeper/src/control_file.rs +++ b/walkeeper/src/control_file.rs @@ -6,6 +6,7 @@ use lazy_static::lazy_static; use std::fs::{self, File, OpenOptions}; use std::io::{Read, Write}; +use std::ops::Deref; use std::path::{Path, PathBuf}; use tracing::*; @@ -37,8 +38,10 @@ lazy_static! { .expect("Failed to register safekeeper_persist_control_file_seconds histogram vec"); } -pub trait Storage { - /// Persist safekeeper state on disk. +/// Storage should keep actual state inside of it. It should implement Deref +/// trait to access state fields and have persist method for updating that state. +pub trait Storage: Deref { + /// Persist safekeeper state on disk and update internal state. fn persist(&mut self, s: &SafeKeeperState) -> Result<()>; } @@ -48,19 +51,47 @@ pub struct FileStorage { timeline_dir: PathBuf, conf: SafeKeeperConf, persist_control_file_seconds: Histogram, + + /// Last state persisted to disk. + state: SafeKeeperState, } impl FileStorage { - pub fn new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> FileStorage { + pub fn restore_new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> Result { let timeline_dir = conf.timeline_dir(zttid); let tenant_id = zttid.tenant_id.to_string(); let timeline_id = zttid.timeline_id.to_string(); - FileStorage { + + let state = Self::load_control_file_conf(conf, zttid)?; + + Ok(FileStorage { timeline_dir, conf: conf.clone(), persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS .with_label_values(&[&tenant_id, &timeline_id]), - } + state, + }) + } + + pub fn create_new( + zttid: &ZTenantTimelineId, + conf: &SafeKeeperConf, + state: SafeKeeperState, + ) -> Result { + let timeline_dir = conf.timeline_dir(zttid); + let tenant_id = zttid.tenant_id.to_string(); + let timeline_id = zttid.timeline_id.to_string(); + + let mut store = FileStorage { + timeline_dir, + conf: conf.clone(), + persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS + .with_label_values(&[&tenant_id, &timeline_id]), + state: state.clone(), + }; + + store.persist(&state)?; + Ok(store) } // Check the magic/version in the on-disk data and deserialize it, if possible. @@ -141,6 +172,14 @@ impl FileStorage { } } +impl Deref for FileStorage { + type Target = SafeKeeperState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + impl Storage for FileStorage { // persists state durably to underlying storage // for description see https://lwn.net/Articles/457667/ @@ -201,6 +240,9 @@ impl Storage for FileStorage { .and_then(|f| f.sync_all()) .context("failed to sync control file directory")?; } + + // update internal state + self.state = s.clone(); Ok(()) } } @@ -228,7 +270,7 @@ mod test { ) -> Result<(FileStorage, SafeKeeperState)> { fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir"); Ok(( - FileStorage::new(zttid, conf), + FileStorage::restore_new(zttid, conf)?, FileStorage::load_control_file_conf(conf, zttid)?, )) } @@ -239,8 +281,7 @@ mod test { ) -> Result<(FileStorage, SafeKeeperState)> { fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir"); let state = SafeKeeperState::empty(); - let mut storage = FileStorage::new(zttid, conf); - storage.persist(&state)?; + let storage = FileStorage::create_new(zttid, conf, state.clone())?; Ok((storage, state)) } diff --git a/walkeeper/src/safekeeper.rs b/walkeeper/src/safekeeper.rs index 1e23d87b34..22a8481e45 100644 --- a/walkeeper/src/safekeeper.rs +++ b/walkeeper/src/safekeeper.rs @@ -210,6 +210,7 @@ pub struct SafekeeperMemState { pub s3_wal_lsn: Lsn, // TODO: keep only persistent version pub peer_horizon_lsn: Lsn, pub remote_consistent_lsn: Lsn, + pub proposer_uuid: PgUuid, } impl SafeKeeperState { @@ -502,9 +503,8 @@ pub struct SafeKeeper { epoch_start_lsn: Lsn, pub inmem: SafekeeperMemState, // in memory part - pub s: SafeKeeperState, // persistent part + pub state: CTRL, // persistent state storage - pub control_store: CTRL, pub wal_store: WAL, } @@ -516,14 +516,14 @@ where // constructor pub fn new( ztli: ZTimelineId, - control_store: CTRL, + state: CTRL, mut wal_store: WAL, - state: SafeKeeperState, ) -> Result> { if state.timeline_id != ZTimelineId::from([0u8; 16]) && ztli != state.timeline_id { bail!("Calling SafeKeeper::new with inconsistent ztli ({}) and SafeKeeperState.server.timeline_id ({})", ztli, state.timeline_id); } + // initialize wal_store, if state is already initialized wal_store.init_storage(&state)?; Ok(SafeKeeper { @@ -535,23 +535,25 @@ where s3_wal_lsn: state.s3_wal_lsn, peer_horizon_lsn: state.peer_horizon_lsn, remote_consistent_lsn: state.remote_consistent_lsn, + proposer_uuid: state.proposer_uuid, }, - s: state, - control_store, + state, wal_store, }) } /// Get history of term switches for the available WAL fn get_term_history(&self) -> TermHistory { - self.s + self.state .acceptor_state .term_history .up_to(self.wal_store.flush_lsn()) } pub fn get_epoch(&self) -> Term { - self.s.acceptor_state.get_epoch(self.wal_store.flush_lsn()) + self.state + .acceptor_state + .get_epoch(self.wal_store.flush_lsn()) } /// Process message from proposer and possibly form reply. Concurrent @@ -587,46 +589,47 @@ where ); } /* Postgres upgrade is not treated as fatal error */ - if msg.pg_version != self.s.server.pg_version - && self.s.server.pg_version != UNKNOWN_SERVER_VERSION + if msg.pg_version != self.state.server.pg_version + && self.state.server.pg_version != UNKNOWN_SERVER_VERSION { info!( "incompatible server version {}, expected {}", - msg.pg_version, self.s.server.pg_version + msg.pg_version, self.state.server.pg_version ); } - if msg.tenant_id != self.s.tenant_id { + if msg.tenant_id != self.state.tenant_id { bail!( "invalid tenant ID, got {}, expected {}", msg.tenant_id, - self.s.tenant_id + self.state.tenant_id ); } - if msg.ztli != self.s.timeline_id { + if msg.ztli != self.state.timeline_id { bail!( "invalid timeline ID, got {}, expected {}", msg.ztli, - self.s.timeline_id + self.state.timeline_id ); } // set basic info about server, if not yet // TODO: verify that is doesn't change after - self.s.server.system_id = msg.system_id; - self.s.server.wal_seg_size = msg.wal_seg_size; - self.control_store - .persist(&self.s) - .context("failed to persist shared state")?; + { + let mut state = self.state.clone(); + state.server.system_id = msg.system_id; + state.server.wal_seg_size = msg.wal_seg_size; + self.state.persist(&state)?; + } // pass wal_seg_size to read WAL and find flush_lsn - self.wal_store.init_storage(&self.s)?; + self.wal_store.init_storage(&self.state)?; info!( "processed greeting from proposer {:?}, sending term {:?}", - msg.proposer_id, self.s.acceptor_state.term + msg.proposer_id, self.state.acceptor_state.term ); Ok(Some(AcceptorProposerMessage::Greeting(AcceptorGreeting { - term: self.s.acceptor_state.term, + term: self.state.acceptor_state.term, }))) } @@ -637,17 +640,19 @@ where ) -> Result> { // initialize with refusal let mut resp = VoteResponse { - term: self.s.acceptor_state.term, + term: self.state.acceptor_state.term, vote_given: false as u64, flush_lsn: self.wal_store.flush_lsn(), - truncate_lsn: self.s.peer_horizon_lsn, + truncate_lsn: self.state.peer_horizon_lsn, term_history: self.get_term_history(), }; - if self.s.acceptor_state.term < msg.term { - self.s.acceptor_state.term = msg.term; + if self.state.acceptor_state.term < msg.term { + let mut state = self.state.clone(); + state.acceptor_state.term = msg.term; // persist vote before sending it out - self.control_store.persist(&self.s)?; - resp.term = self.s.acceptor_state.term; + self.state.persist(&state)?; + + resp.term = self.state.acceptor_state.term; resp.vote_given = true as u64; } info!("processed VoteRequest for term {}: {:?}", msg.term, &resp); @@ -656,9 +661,10 @@ where /// Bump our term if received a note from elected proposer with higher one fn bump_if_higher(&mut self, term: Term) -> Result<()> { - if self.s.acceptor_state.term < term { - self.s.acceptor_state.term = term; - self.control_store.persist(&self.s)?; + if self.state.acceptor_state.term < term { + let mut state = self.state.clone(); + state.acceptor_state.term = term; + self.state.persist(&state)?; } Ok(()) } @@ -666,9 +672,9 @@ where /// Form AppendResponse from current state. fn append_response(&self) -> AppendResponse { let ar = AppendResponse { - term: self.s.acceptor_state.term, + term: self.state.acceptor_state.term, flush_lsn: self.wal_store.flush_lsn(), - commit_lsn: self.s.commit_lsn, + commit_lsn: self.state.commit_lsn, // will be filled by the upper code to avoid bothering safekeeper hs_feedback: HotStandbyFeedback::empty(), zenith_feedback: ZenithFeedback::empty(), @@ -681,7 +687,7 @@ where info!("received ProposerElected {:?}", msg); self.bump_if_higher(msg.term)?; // If our term is higher, ignore the message (next feedback will inform the compute) - if self.s.acceptor_state.term > msg.term { + if self.state.acceptor_state.term > msg.term { return Ok(None); } @@ -692,8 +698,11 @@ where self.wal_store.truncate_wal(msg.start_streaming_at)?; // and now adopt term history from proposer - self.s.acceptor_state.term_history = msg.term_history.clone(); - self.control_store.persist(&self.s)?; + { + let mut state = self.state.clone(); + state.acceptor_state.term_history = msg.term_history.clone(); + self.state.persist(&state)?; + } info!("start receiving WAL since {:?}", msg.start_streaming_at); @@ -715,13 +724,13 @@ where // Also note that commit_lsn can reach epoch_start_lsn earlier // that we receive new epoch_start_lsn, and we still need to sync // control file in this case. - if commit_lsn == self.epoch_start_lsn && self.s.commit_lsn != commit_lsn { + if commit_lsn == self.epoch_start_lsn && self.state.commit_lsn != commit_lsn { self.persist_control_file()?; } // We got our first commit_lsn, which means we should sync // everything to disk, to initialize the state. - if self.s.commit_lsn == Lsn(0) && commit_lsn > Lsn(0) { + if self.state.commit_lsn == Lsn(0) && commit_lsn > Lsn(0) { self.wal_store.flush_wal()?; self.persist_control_file()?; } @@ -731,10 +740,12 @@ where /// Persist in-memory state to the disk. fn persist_control_file(&mut self) -> Result<()> { - self.s.commit_lsn = self.inmem.commit_lsn; - self.s.peer_horizon_lsn = self.inmem.peer_horizon_lsn; + let mut state = self.state.clone(); - self.control_store.persist(&self.s) + state.commit_lsn = self.inmem.commit_lsn; + state.peer_horizon_lsn = self.inmem.peer_horizon_lsn; + state.proposer_uuid = self.inmem.proposer_uuid; + self.state.persist(&state) } /// Handle request to append WAL. @@ -744,13 +755,13 @@ where msg: &AppendRequest, require_flush: bool, ) -> Result> { - if self.s.acceptor_state.term < msg.h.term { + if self.state.acceptor_state.term < msg.h.term { bail!("got AppendRequest before ProposerElected"); } // If our term is higher, immediately refuse the message. - if self.s.acceptor_state.term > msg.h.term { - let resp = AppendResponse::term_only(self.s.acceptor_state.term); + if self.state.acceptor_state.term > msg.h.term { + let resp = AppendResponse::term_only(self.state.acceptor_state.term); return Ok(Some(AcceptorProposerMessage::AppendResponse(resp))); } @@ -758,8 +769,7 @@ where // processing the message. self.epoch_start_lsn = msg.h.epoch_start_lsn; - // TODO: don't update state without persisting to disk - self.s.proposer_uuid = msg.h.proposer_uuid; + self.inmem.proposer_uuid = msg.h.proposer_uuid; // do the job if !msg.wal_data.is_empty() { @@ -790,7 +800,7 @@ where // Update truncate and commit LSN in control file. // To avoid negative impact on performance of extra fsync, do it only // when truncate_lsn delta exceeds WAL segment size. - if self.s.peer_horizon_lsn + (self.s.server.wal_seg_size as u64) + if self.state.peer_horizon_lsn + (self.state.server.wal_seg_size as u64) < self.inmem.peer_horizon_lsn { self.persist_control_file()?; @@ -829,6 +839,8 @@ where #[cfg(test)] mod tests { + use std::ops::Deref; + use super::*; use crate::wal_storage::Storage; @@ -844,6 +856,14 @@ mod tests { } } + impl Deref for InMemoryState { + type Target = SafeKeeperState; + + fn deref(&self) -> &Self::Target { + &self.persisted_state + } + } + struct DummyWalStore { lsn: Lsn, } @@ -879,7 +899,7 @@ mod tests { }; let wal_store = DummyWalStore { lsn: Lsn(0) }; let ztli = ZTimelineId::from([0u8; 16]); - let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::empty()).unwrap(); + let mut sk = SafeKeeper::new(ztli, storage, wal_store).unwrap(); // check voting for 1 is ok let vote_request = ProposerAcceptorMessage::VoteRequest(VoteRequest { term: 1 }); @@ -890,11 +910,11 @@ mod tests { } // reboot... - let state = sk.control_store.persisted_state.clone(); + let state = sk.state.persisted_state.clone(); let storage = InMemoryState { - persisted_state: state.clone(), + persisted_state: state, }; - sk = SafeKeeper::new(ztli, storage, sk.wal_store, state).unwrap(); + sk = SafeKeeper::new(ztli, storage, sk.wal_store).unwrap(); // and ensure voting second time for 1 is not ok vote_resp = sk.process_msg(&vote_request); @@ -911,7 +931,7 @@ mod tests { }; let wal_store = DummyWalStore { lsn: Lsn(0) }; let ztli = ZTimelineId::from([0u8; 16]); - let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::empty()).unwrap(); + let mut sk = SafeKeeper::new(ztli, storage, wal_store).unwrap(); let mut ar_hdr = AppendRequestHeader { term: 1, diff --git a/walkeeper/src/timeline.rs b/walkeeper/src/timeline.rs index a76ef77615..a2941a9a5c 100644 --- a/walkeeper/src/timeline.rs +++ b/walkeeper/src/timeline.rs @@ -21,7 +21,6 @@ use crate::broker::SafekeeperInfo; use crate::callmemaybe::{CallmeEvent, SubscriptionStateKey}; use crate::control_file; -use crate::control_file::Storage as cf_storage; use crate::safekeeper::{ AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, SafeKeeperState, SafekeeperMemState, @@ -98,10 +97,9 @@ impl SharedState { peer_ids: Vec, ) -> Result { let state = SafeKeeperState::new(zttid, peer_ids); - let control_store = control_file::FileStorage::new(zttid, conf); + let control_store = control_file::FileStorage::create_new(zttid, conf, state)?; let wal_store = wal_storage::PhysicalStorage::new(zttid, conf); - let mut sk = SafeKeeper::new(zttid.timeline_id, control_store, wal_store, state)?; - sk.control_store.persist(&sk.s)?; + let sk = SafeKeeper::new(zttid.timeline_id, control_store, wal_store)?; Ok(Self { notified_commit_lsn: Lsn(0), @@ -116,18 +114,14 @@ impl SharedState { /// Restore SharedState from control file. /// If file doesn't exist, bails out. fn restore(conf: &SafeKeeperConf, zttid: &ZTenantTimelineId) -> Result { - let state = control_file::FileStorage::load_control_file_conf(conf, zttid) - .context("failed to load from control file")?; - - let control_store = control_file::FileStorage::new(zttid, conf); - + let control_store = control_file::FileStorage::restore_new(zttid, conf)?; let wal_store = wal_storage::PhysicalStorage::new(zttid, conf); info!("timeline {} restored", zttid.timeline_id); Ok(Self { notified_commit_lsn: Lsn(0), - sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store, state)?, + sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store)?, replicas: Vec::new(), active: false, num_computes: 0, @@ -419,7 +413,7 @@ impl Timeline { pub fn get_state(&self) -> (SafekeeperMemState, SafeKeeperState) { let shared_state = self.mutex.lock().unwrap(); - (shared_state.sk.inmem.clone(), shared_state.sk.s.clone()) + (shared_state.sk.inmem.clone(), shared_state.sk.state.clone()) } /// Prepare public safekeeper info for reporting. From 570db6f1681b80e50dbc2d156d037b99ca742099 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Thu, 14 Apr 2022 11:28:38 +0300 Subject: [PATCH 12/19] Update README for Zenith -> Neon renaming. There's a lot of renaming left to do in the code and docs, but this is a start. Our binaries and many other things are still called "zenith", but I didn't change those in the README, because otherwise the examples won't work. I added a brief note at the top of the README to explain that we're in the process of renaming, until we've renamed everything. --- README.md | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index c8acf526b9..f99785e683 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,22 @@ -# Zenith +# Neon -Zenith is a serverless open source alternative to AWS Aurora Postgres. It separates storage and compute and substitutes PostgreSQL storage layer by redistributing data across a cluster of nodes. +Neon is a serverless open source alternative to AWS Aurora Postgres. It separates storage and compute and substitutes PostgreSQL storage layer by redistributing data across a cluster of nodes. + +The project used to be called "Zenith". Many of the commands and code comments +still refer to "zenith", but we are in the process of renaming things. ## Architecture overview -A Zenith installation consists of compute nodes and Zenith storage engine. +A Neon installation consists of compute nodes and Neon storage engine. -Compute nodes are stateless PostgreSQL nodes, backed by Zenith storage engine. +Compute nodes are stateless PostgreSQL nodes, backed by Neon storage engine. -Zenith storage engine consists of two major components: +Neon storage engine consists of two major components: - Pageserver. Scalable storage backend for compute nodes. - WAL service. The service that receives WAL from compute node and ensures that it is stored durably. Pageserver consists of: -- Repository - Zenith storage implementation. +- Repository - Neon storage implementation. - WAL receiver - service that receives WAL from WAL service and stores it in the repository. - Page service - service that communicates with compute nodes and responds with pages from the repository. - WAL redo - service that builds pages from base images and WAL records on Page service request. @@ -35,10 +38,10 @@ To run the `psql` client, install the `postgresql-client` package or modify `PAT To run the integration tests or Python scripts (not required to use the code), install Python (3.7 or higher), and install python3 packages using `./scripts/pysync` (requires poetry) in the project directory. -2. Build zenith and patched postgres +2. Build neon and patched postgres ```sh -git clone --recursive https://github.com/zenithdb/zenith.git -cd zenith +git clone --recursive https://github.com/neondatabase/neon.git +cd neon make -j5 ``` @@ -126,7 +129,7 @@ INSERT 0 1 ## Running tests ```sh -git clone --recursive https://github.com/zenithdb/zenith.git +git clone --recursive https://github.com/neondatabase/neon.git make # builds also postgres and installs it to ./tmp_install ./scripts/pytest ``` @@ -141,14 +144,14 @@ To view your `rustdoc` documentation in a browser, try running `cargo doc --no-d ### Postgres-specific terms -Due to Zenith's very close relation with PostgreSQL internals, there are numerous specific terms used. +Due to Neon's very close relation with PostgreSQL internals, there are numerous specific terms used. Same applies to certain spelling: i.e. we use MB to denote 1024 * 1024 bytes, while MiB would be technically more correct, it's inconsistent with what PostgreSQL code and its documentation use. To get more familiar with this aspect, refer to: -- [Zenith glossary](/docs/glossary.md) +- [Neon glossary](/docs/glossary.md) - [PostgreSQL glossary](https://www.postgresql.org/docs/13/glossary.html) -- Other PostgreSQL documentation and sources (Zenith fork sources can be found [here](https://github.com/zenithdb/postgres)) +- Other PostgreSQL documentation and sources (Neon fork sources can be found [here](https://github.com/neondatabase/postgres)) ## Join the development From 19954dfd8abe154b0db17d7eb45a04acec35cbaf Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Thu, 14 Apr 2022 13:31:37 +0300 Subject: [PATCH 13/19] Refactor proxy options test to not rely on the 'schema' argument. It was the only test that used the 'schema' argument to the connect() function. I'm about to refactor the option handling and will remove the special 'schema' argument altogether, so rewrite the test to not use it. --- test_runner/batch_others/test_proxy.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/test_runner/batch_others/test_proxy.py b/test_runner/batch_others/test_proxy.py index d2039f9758..a6f828f829 100644 --- a/test_runner/batch_others/test_proxy.py +++ b/test_runner/batch_others/test_proxy.py @@ -5,11 +5,14 @@ def test_proxy_select_1(static_proxy): static_proxy.safe_psql("select 1;") -@pytest.mark.xfail # Proxy eats the extra connection options +# Pass extra options to the server. +# +# Currently, proxy eats the extra connection options, so this fails. +# See https://github.com/neondatabase/neon/issues/1287 +@pytest.mark.xfail def test_proxy_options(static_proxy): - schema_name = "tmp_schema_1" - with static_proxy.connect(schema=schema_name) as conn: + with static_proxy.connect(options="-cproxytest.option=value") as conn: with conn.cursor() as cur: - cur.execute("SHOW search_path;") - search_path = cur.fetchall()[0][0] - assert schema_name == search_path + cur.execute("SHOW proxytest.option;") + value = cur.fetchall()[0][0] + assert value == 'value' From a009fe912a292c0df4479c98c4bb5d62c91e7b68 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Thu, 14 Apr 2022 13:31:40 +0300 Subject: [PATCH 14/19] Refactor connection option handling in python tests The PgProtocol.connect() function took extra options for username, database, etc. Remove those options, and have a generic way for each subclass of PgProtocol to provide some default options, with the capability override them in the connect() call. --- test_runner/batch_others/test_createuser.py | 2 +- .../batch_others/test_parallel_copy.py | 5 + test_runner/batch_others/test_wal_acceptor.py | 2 +- .../batch_pg_regress/test_isolation.py | 6 +- .../batch_pg_regress/test_pg_regress.py | 6 +- .../batch_pg_regress/test_zenith_regress.py | 6 +- test_runner/fixtures/zenith_fixtures.py | 128 ++++++++---------- 7 files changed, 69 insertions(+), 86 deletions(-) diff --git a/test_runner/batch_others/test_createuser.py b/test_runner/batch_others/test_createuser.py index efb2af3f07..f4bbbc8a7a 100644 --- a/test_runner/batch_others/test_createuser.py +++ b/test_runner/batch_others/test_createuser.py @@ -28,4 +28,4 @@ def test_createuser(zenith_simple_env: ZenithEnv): pg2 = env.postgres.create_start('test_createuser2') # Test that you can connect to new branch as a new user - assert pg2.safe_psql('select current_user', username='testuser') == [('testuser', )] + assert pg2.safe_psql('select current_user', user='testuser') == [('testuser', )] diff --git a/test_runner/batch_others/test_parallel_copy.py b/test_runner/batch_others/test_parallel_copy.py index 4b7cc58d42..a44acecf21 100644 --- a/test_runner/batch_others/test_parallel_copy.py +++ b/test_runner/batch_others/test_parallel_copy.py @@ -19,6 +19,11 @@ async def copy_test_data_to_table(pg: Postgres, worker_id: int, table_name: str) copy_input = repeat_bytes(buf.read(), 5000) pg_conn = await pg.connect_async() + + # PgProtocol.connect_async sets statement_timeout to 2 minutes. + # That's not enough for this test, on a slow system in debug mode. + await pg_conn.execute("SET statement_timeout='300s'") + await pg_conn.copy_to_table(table_name, source=copy_input) diff --git a/test_runner/batch_others/test_wal_acceptor.py b/test_runner/batch_others/test_wal_acceptor.py index 8f87ff041f..dffcd7cc61 100644 --- a/test_runner/batch_others/test_wal_acceptor.py +++ b/test_runner/batch_others/test_wal_acceptor.py @@ -379,7 +379,7 @@ class ProposerPostgres(PgProtocol): tenant_id: uuid.UUID, listen_addr: str, port: int): - super().__init__(host=listen_addr, port=port, username='zenith_admin') + super().__init__(host=listen_addr, port=port, user='zenith_admin', dbname='postgres') self.pgdata_dir: str = pgdata_dir self.pg_bin: PgBin = pg_bin diff --git a/test_runner/batch_pg_regress/test_isolation.py b/test_runner/batch_pg_regress/test_isolation.py index ddafc3815b..cde56d9b88 100644 --- a/test_runner/batch_pg_regress/test_isolation.py +++ b/test_runner/batch_pg_regress/test_isolation.py @@ -35,9 +35,9 @@ def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys ] env_vars = { - 'PGPORT': str(pg.port), - 'PGUSER': pg.username, - 'PGHOST': pg.host, + 'PGPORT': str(pg.default_options['port']), + 'PGUSER': pg.default_options['user'], + 'PGHOST': pg.default_options['host'], } # Run the command. diff --git a/test_runner/batch_pg_regress/test_pg_regress.py b/test_runner/batch_pg_regress/test_pg_regress.py index 5199f65216..07d2574f4a 100644 --- a/test_runner/batch_pg_regress/test_pg_regress.py +++ b/test_runner/batch_pg_regress/test_pg_regress.py @@ -35,9 +35,9 @@ def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin, ] env_vars = { - 'PGPORT': str(pg.port), - 'PGUSER': pg.username, - 'PGHOST': pg.host, + 'PGPORT': str(pg.default_options['port']), + 'PGUSER': pg.default_options['user'], + 'PGHOST': pg.default_options['host'], } # Run the command. diff --git a/test_runner/batch_pg_regress/test_zenith_regress.py b/test_runner/batch_pg_regress/test_zenith_regress.py index 31d5b07093..2b57137d16 100644 --- a/test_runner/batch_pg_regress/test_zenith_regress.py +++ b/test_runner/batch_pg_regress/test_zenith_regress.py @@ -40,9 +40,9 @@ def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, c log.info(pg_regress_command) env_vars = { - 'PGPORT': str(pg.port), - 'PGUSER': pg.username, - 'PGHOST': pg.host, + 'PGPORT': str(pg.default_options['port']), + 'PGUSER': pg.default_options['user'], + 'PGHOST': pg.default_options['host'], } # Run the command. diff --git a/test_runner/fixtures/zenith_fixtures.py b/test_runner/fixtures/zenith_fixtures.py index a95809687a..41d1443880 100644 --- a/test_runner/fixtures/zenith_fixtures.py +++ b/test_runner/fixtures/zenith_fixtures.py @@ -27,6 +27,7 @@ from dataclasses import dataclass # Type-related stuff from psycopg2.extensions import connection as PgConnection +from psycopg2.extensions import make_dsn, parse_dsn from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple from typing_extensions import Literal @@ -238,98 +239,69 @@ def port_distributor(worker_base_port): class PgProtocol: """ Reusable connection logic """ - def __init__(self, - host: str, - port: int, - username: Optional[str] = None, - password: Optional[str] = None, - dbname: Optional[str] = None, - schema: Optional[str] = None): - self.host = host - self.port = port - self.username = username - self.password = password - self.dbname = dbname - self.schema = schema + def __init__(self, **kwargs): + self.default_options = kwargs - def connstr(self, - *, - dbname: Optional[str] = None, - schema: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - statement_timeout_ms: Optional[int] = None) -> str: + def connstr(self, **kwargs) -> str: """ Build a libpq connection string for the Postgres instance. """ + return str(make_dsn(**self.conn_options(**kwargs))) - username = username or self.username - password = password or self.password - dbname = dbname or self.dbname or "postgres" - schema = schema or self.schema - res = f'host={self.host} port={self.port} dbname={dbname}' + def conn_options(self, **kwargs): + conn_options = self.default_options.copy() + if 'dsn' in kwargs: + conn_options.update(parse_dsn(kwargs['dsn'])) + conn_options.update(kwargs) - if username: - res = f'{res} user={username}' - - if password: - res = f'{res} password={password}' - - if schema: - res = f"{res} options='-c search_path={schema}'" - - if statement_timeout_ms: - res = f"{res} options='-c statement_timeout={statement_timeout_ms}'" - - return res + # Individual statement timeout in seconds. 2 minutes should be + # enough for our tests, but if you need a longer, you can + # change it by calling "SET statement_timeout" after + # connecting. + if 'options' in conn_options: + conn_options['options'] = f"-cstatement_timeout=120s " + conn_options['options'] + else: + conn_options['options'] = "-cstatement_timeout=120s" + return conn_options # autocommit=True here by default because that's what we need most of the time - def connect( - self, - *, - autocommit=True, - dbname: Optional[str] = None, - schema: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - # individual statement timeout in seconds, 2 minutes should be enough for our tests - statement_timeout: Optional[int] = 120 - ) -> PgConnection: + def connect(self, autocommit=True, **kwargs) -> PgConnection: """ Connect to the node. Returns psycopg2's connection object. This method passes all extra params to connstr. """ + conn = psycopg2.connect(**self.conn_options(**kwargs)) - conn = psycopg2.connect( - self.connstr(dbname=dbname, - schema=schema, - username=username, - password=password, - statement_timeout_ms=statement_timeout * - 1000 if statement_timeout else None)) # WARNING: this setting affects *all* tests! conn.autocommit = autocommit return conn - async def connect_async(self, - *, - dbname: str = 'postgres', - username: Optional[str] = None, - password: Optional[str] = None) -> asyncpg.Connection: + async def connect_async(self, **kwargs) -> asyncpg.Connection: """ Connect to the node from async python. Returns asyncpg's connection object. """ - conn = await asyncpg.connect( - host=self.host, - port=self.port, - database=dbname, - user=username or self.username, - password=password, - ) - return conn + # asyncpg takes slightly different options than psycopg2. Try + # to convert the defaults from the psycopg2 format. + + # The psycopg2 option 'dbname' is called 'database' is asyncpg + conn_options = self.conn_options(**kwargs) + if 'dbname' in conn_options: + conn_options['database'] = conn_options.pop('dbname') + + # Convert options='-c=' to server_settings + if 'options' in conn_options: + options = conn_options.pop('options') + for match in re.finditer('-c(\w*)=(\w*)', options): + key = match.group(1) + val = match.group(2) + if 'server_options' in conn_options: + conn_options['server_settings'].update({key: val}) + else: + conn_options['server_settings'] = {key: val} + return await asyncpg.connect(**conn_options) def safe_psql(self, query: str, **kwargs: Any) -> List[Any]: """ @@ -1149,10 +1121,10 @@ class ZenithPageserver(PgProtocol): port: PageserverPort, remote_storage: Optional[RemoteStorage] = None, config_override: Optional[str] = None): - super().__init__(host='localhost', port=port.pg, username='zenith_admin') + super().__init__(host='localhost', port=port.pg, user='zenith_admin') self.env = env self.running = False - self.service_port = port # do not shadow PgProtocol.port which is just int + self.service_port = port self.remote_storage = remote_storage self.config_override = config_override @@ -1291,7 +1263,7 @@ def pg_bin(test_output_dir: str) -> PgBin: class VanillaPostgres(PgProtocol): def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int): - super().__init__(host='localhost', port=port) + super().__init__(host='localhost', port=port, dbname='postgres') self.pgdatadir = pgdatadir self.pg_bin = pg_bin self.running = False @@ -1335,8 +1307,14 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]: class ZenithProxy(PgProtocol): def __init__(self, port: int): - super().__init__(host="127.0.0.1", username="pytest", password="pytest", port=port) + super().__init__(host="127.0.0.1", + user="pytest", + password="pytest", + port=port, + dbname='postgres') self.http_port = 7001 + self.host = "127.0.0.1" + self.port = port self._popen: Optional[subprocess.Popen[bytes]] = None def start_static(self, addr="127.0.0.1:5432") -> None: @@ -1380,13 +1358,13 @@ def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]: class Postgres(PgProtocol): """ An object representing a running postgres daemon. """ def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int): - super().__init__(host='localhost', port=port, username='zenith_admin') - + super().__init__(host='localhost', port=port, user='zenith_admin', dbname='postgres') self.env = env self.running = False self.node_name: Optional[str] = None # dubious, see asserts below self.pgdata_dir: Optional[str] = None # Path to computenode PGDATA self.tenant_id = tenant_id + self.port = port # path to conf is /pgdatadirs/tenants///postgresql.conf def create( From 4a8c66345267bfb11882a10d0260e2aacec6d112 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Thu, 14 Apr 2022 13:31:42 +0300 Subject: [PATCH 15/19] Refactor pgbench tests. - Remove batch_others/test_pgbench.py. It was a quick check that pgbench works, without actually recording any performance numbers, but that doesn't seem very interesting anymore. Remove it to avoid confusing it with the actual pgbench benchmarks - Run pgbench with "-n" and "-S" options, for two different workloads: simple-updates, and SELECT-only. Previously, we would only run it with the "default" TPCB-like workload. That's more or less the same as the simple-update (-n) workload, but I think the simple-upload workload is more relevant for testing storage performance. The SELECT-only workload is a new thing to measure. - Merge test_perf_pgbench.py and test_perf_pgbench_remote.py. I added a new "remote" implementation of the PgCompare class, which allows running the same tests against an already-running Postgres instance. - Make the PgBenchRunResult.parse_from_output function more flexible. pgbench can print different lines depending on the command-line options, but the parsing function expected a particular set of lines. --- .github/workflows/benchmarking.yml | 13 +- test_runner/batch_others/test_pgbench.py | 14 -- test_runner/fixtures/benchmark_fixture.py | 145 ++++++++---------- test_runner/fixtures/compare_fixtures.py | 49 +++++- test_runner/fixtures/zenith_fixtures.py | 68 ++++++-- test_runner/performance/test_perf_pgbench.py | 116 ++++++++++++-- .../performance/test_perf_pgbench_remote.py | 124 --------------- 7 files changed, 279 insertions(+), 250 deletions(-) delete mode 100644 test_runner/batch_others/test_pgbench.py delete mode 100644 test_runner/performance/test_perf_pgbench_remote.py diff --git a/.github/workflows/benchmarking.yml b/.github/workflows/benchmarking.yml index 36df35297d..72041c9d02 100644 --- a/.github/workflows/benchmarking.yml +++ b/.github/workflows/benchmarking.yml @@ -26,7 +26,7 @@ jobs: runs-on: [self-hosted, zenith-benchmarker] env: - PG_BIN: "/usr/pgsql-13/bin" + POSTGRES_DISTRIB_DIR: "/usr/pgsql-13" steps: - name: Checkout zenith repo @@ -51,7 +51,7 @@ jobs: echo Poetry poetry --version echo Pgbench - $PG_BIN/pgbench --version + $POSTGRES_DISTRIB_DIR/bin/pgbench --version # FIXME cluster setup is skipped due to various changes in console API # for now pre created cluster is used. When API gain some stability @@ -66,7 +66,7 @@ jobs: echo "Starting cluster" # wake up the cluster - $PG_BIN/psql $BENCHMARK_CONNSTR -c "SELECT 1" + $POSTGRES_DISTRIB_DIR/bin/psql $BENCHMARK_CONNSTR -c "SELECT 1" - name: Run benchmark # pgbench is installed system wide from official repo @@ -83,8 +83,11 @@ jobs: # sudo yum install postgresql13-contrib # actual binaries are located in /usr/pgsql-13/bin/ env: - TEST_PG_BENCH_TRANSACTIONS_MATRIX: "5000,10000,20000" - TEST_PG_BENCH_SCALES_MATRIX: "10,15" + # The pgbench test runs two tests of given duration against each scale. + # So the total runtime with these parameters is 2 * 2 * 300 = 1200, or 20 minutes. + # Plus time needed to initialize the test databases. + TEST_PG_BENCH_DURATIONS_MATRIX: "300" + TEST_PG_BENCH_SCALES_MATRIX: "10,100" PLATFORM: "zenith-staging" BENCHMARK_CONNSTR: "${{ secrets.BENCHMARK_STAGING_CONNSTR }}" REMOTE_ENV: "1" # indicate to test harness that we do not have zenith binaries locally diff --git a/test_runner/batch_others/test_pgbench.py b/test_runner/batch_others/test_pgbench.py deleted file mode 100644 index 09713023bc..0000000000 --- a/test_runner/batch_others/test_pgbench.py +++ /dev/null @@ -1,14 +0,0 @@ -from fixtures.zenith_fixtures import ZenithEnv -from fixtures.log_helper import log - - -def test_pgbench(zenith_simple_env: ZenithEnv, pg_bin): - env = zenith_simple_env - env.zenith_cli.create_branch("test_pgbench", "empty") - pg = env.postgres.create_start('test_pgbench') - log.info("postgres is running on 'test_pgbench' branch") - - connstr = pg.connstr() - - pg_bin.run_capture(['pgbench', '-i', connstr]) - pg_bin.run_capture(['pgbench'] + '-c 10 -T 5 -P 1 -M prepared'.split() + [connstr]) diff --git a/test_runner/fixtures/benchmark_fixture.py b/test_runner/fixtures/benchmark_fixture.py index 480eb3f891..a904233e98 100644 --- a/test_runner/fixtures/benchmark_fixture.py +++ b/test_runner/fixtures/benchmark_fixture.py @@ -17,7 +17,7 @@ import warnings from contextlib import contextmanager # Type-related stuff -from typing import Iterator +from typing import Iterator, Optional """ This file contains fixtures for micro-benchmarks. @@ -51,17 +51,12 @@ in the test initialization, or measure disk usage after the test query. @dataclasses.dataclass class PgBenchRunResult: - scale: int number_of_clients: int number_of_threads: int number_of_transactions_actually_processed: int latency_average: float - latency_stddev: float - tps_including_connection_time: float - tps_excluding_connection_time: float - init_duration: float - init_start_timestamp: int - init_end_timestamp: int + latency_stddev: Optional[float] + tps: float run_duration: float run_start_timestamp: int run_end_timestamp: int @@ -69,56 +64,67 @@ class PgBenchRunResult: # TODO progress @classmethod - def parse_from_output( + def parse_from_stdout( cls, - out: 'subprocess.CompletedProcess[str]', - init_duration: float, - init_start_timestamp: int, - init_end_timestamp: int, + stdout: str, run_duration: float, run_start_timestamp: int, run_end_timestamp: int, ): - stdout_lines = out.stdout.splitlines() + stdout_lines = stdout.splitlines() + + latency_stddev = None + # we know significant parts of these values from test input # but to be precise take them from output - # scaling factor: 5 - assert "scaling factor" in stdout_lines[1] - scale = int(stdout_lines[1].split()[-1]) - # number of clients: 1 - assert "number of clients" in stdout_lines[3] - number_of_clients = int(stdout_lines[3].split()[-1]) - # number of threads: 1 - assert "number of threads" in stdout_lines[4] - number_of_threads = int(stdout_lines[4].split()[-1]) - # number of transactions actually processed: 1000/1000 - assert "number of transactions actually processed" in stdout_lines[6] - number_of_transactions_actually_processed = int(stdout_lines[6].split("/")[1]) - # latency average = 19.894 ms - assert "latency average" in stdout_lines[7] - latency_average = stdout_lines[7].split()[-2] - # latency stddev = 3.387 ms - assert "latency stddev" in stdout_lines[8] - latency_stddev = stdout_lines[8].split()[-2] - # tps = 50.219689 (including connections establishing) - assert "(including connections establishing)" in stdout_lines[9] - tps_including_connection_time = stdout_lines[9].split()[2] - # tps = 50.264435 (excluding connections establishing) - assert "(excluding connections establishing)" in stdout_lines[10] - tps_excluding_connection_time = stdout_lines[10].split()[2] + for line in stdout.splitlines(): + # scaling factor: 5 + if line.startswith("scaling factor:"): + scale = int(line.split()[-1]) + # number of clients: 1 + if line.startswith("number of clients: "): + number_of_clients = int(line.split()[-1]) + # number of threads: 1 + if line.startswith("number of threads: "): + number_of_threads = int(line.split()[-1]) + # number of transactions actually processed: 1000/1000 + # OR + # number of transactions actually processed: 1000 + if line.startswith("number of transactions actually processed"): + if "/" in line: + number_of_transactions_actually_processed = int(line.split("/")[1]) + else: + number_of_transactions_actually_processed = int(line.split()[-1]) + # latency average = 19.894 ms + if line.startswith("latency average"): + latency_average = float(line.split()[-2]) + # latency stddev = 3.387 ms + # (only printed with some options) + if line.startswith("latency stddev"): + latency_stddev = float(line.split()[-2]) + + # Get the TPS without initial connection time. The format + # of the tps lines changed in pgbench v14, but we accept + # either format: + # + # pgbench v13 and below: + # tps = 50.219689 (including connections establishing) + # tps = 50.264435 (excluding connections establishing) + # + # pgbench v14: + # initial connection time = 3.858 ms + # tps = 309.281539 (without initial connection time) + if (line.startswith("tps = ") and ("(excluding connections establishing)" in line + or "(without initial connection time)")): + tps = float(line.split()[2]) return cls( - scale=scale, number_of_clients=number_of_clients, number_of_threads=number_of_threads, number_of_transactions_actually_processed=number_of_transactions_actually_processed, - latency_average=float(latency_average), - latency_stddev=float(latency_stddev), - tps_including_connection_time=float(tps_including_connection_time), - tps_excluding_connection_time=float(tps_excluding_connection_time), - init_duration=init_duration, - init_start_timestamp=init_start_timestamp, - init_end_timestamp=init_end_timestamp, + latency_average=latency_average, + latency_stddev=latency_stddev, + tps=tps, run_duration=run_duration, run_start_timestamp=run_start_timestamp, run_end_timestamp=run_end_timestamp, @@ -187,60 +193,41 @@ class ZenithBenchmarker: report=MetricReport.LOWER_IS_BETTER, ) - def record_pg_bench_result(self, pg_bench_result: PgBenchRunResult): - self.record("scale", pg_bench_result.scale, '', MetricReport.TEST_PARAM) - self.record("number_of_clients", + def record_pg_bench_result(self, prefix: str, pg_bench_result: PgBenchRunResult): + self.record(f"{prefix}.number_of_clients", pg_bench_result.number_of_clients, '', MetricReport.TEST_PARAM) - self.record("number_of_threads", + self.record(f"{prefix}.number_of_threads", pg_bench_result.number_of_threads, '', MetricReport.TEST_PARAM) self.record( - "number_of_transactions_actually_processed", + f"{prefix}.number_of_transactions_actually_processed", pg_bench_result.number_of_transactions_actually_processed, '', # thats because this is predefined by test matrix and doesnt change across runs report=MetricReport.TEST_PARAM, ) - self.record("latency_average", + self.record(f"{prefix}.latency_average", pg_bench_result.latency_average, unit="ms", report=MetricReport.LOWER_IS_BETTER) - self.record("latency_stddev", - pg_bench_result.latency_stddev, - unit="ms", - report=MetricReport.LOWER_IS_BETTER) - self.record("tps_including_connection_time", - pg_bench_result.tps_including_connection_time, - '', - report=MetricReport.HIGHER_IS_BETTER) - self.record("tps_excluding_connection_time", - pg_bench_result.tps_excluding_connection_time, - '', - report=MetricReport.HIGHER_IS_BETTER) - self.record("init_duration", - pg_bench_result.init_duration, - unit="s", - report=MetricReport.LOWER_IS_BETTER) - self.record("init_start_timestamp", - pg_bench_result.init_start_timestamp, - '', - MetricReport.TEST_PARAM) - self.record("init_end_timestamp", - pg_bench_result.init_end_timestamp, - '', - MetricReport.TEST_PARAM) - self.record("run_duration", + if pg_bench_result.latency_stddev is not None: + self.record(f"{prefix}.latency_stddev", + pg_bench_result.latency_stddev, + unit="ms", + report=MetricReport.LOWER_IS_BETTER) + self.record(f"{prefix}.tps", pg_bench_result.tps, '', report=MetricReport.HIGHER_IS_BETTER) + self.record(f"{prefix}.run_duration", pg_bench_result.run_duration, unit="s", report=MetricReport.LOWER_IS_BETTER) - self.record("run_start_timestamp", + self.record(f"{prefix}.run_start_timestamp", pg_bench_result.run_start_timestamp, '', MetricReport.TEST_PARAM) - self.record("run_end_timestamp", + self.record(f"{prefix}.run_end_timestamp", pg_bench_result.run_end_timestamp, '', MetricReport.TEST_PARAM) diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index 598ee10f8e..3c6a923587 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -2,7 +2,7 @@ import pytest from contextlib import contextmanager from abc import ABC, abstractmethod -from fixtures.zenith_fixtures import PgBin, PgProtocol, VanillaPostgres, ZenithEnv +from fixtures.zenith_fixtures import PgBin, PgProtocol, VanillaPostgres, RemotePostgres, ZenithEnv from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker # Type-related stuff @@ -162,6 +162,48 @@ class VanillaCompare(PgCompare): return self.zenbenchmark.record_duration(out_name) +class RemoteCompare(PgCompare): + """PgCompare interface for a remote postgres instance.""" + def __init__(self, zenbenchmark, remote_pg: RemotePostgres): + self._pg = remote_pg + self._zenbenchmark = zenbenchmark + + # Long-lived cursor, useful for flushing + self.conn = self.pg.connect() + self.cur = self.conn.cursor() + + @property + def pg(self): + return self._pg + + @property + def zenbenchmark(self): + return self._zenbenchmark + + @property + def pg_bin(self): + return self._pg.pg_bin + + def flush(self): + # TODO: flush the remote pageserver + pass + + def report_peak_memory_use(self) -> None: + # TODO: get memory usage from remote pageserver + pass + + def report_size(self) -> None: + # TODO: get storage size from remote pageserver + pass + + @contextmanager + def record_pageserver_writes(self, out_name): + yield # Do nothing + + def record_duration(self, out_name): + return self.zenbenchmark.record_duration(out_name) + + @pytest.fixture(scope='function') def zenith_compare(request, zenbenchmark, pg_bin, zenith_simple_env) -> ZenithCompare: branch_name = request.node.name @@ -173,6 +215,11 @@ def vanilla_compare(zenbenchmark, vanilla_pg) -> VanillaCompare: return VanillaCompare(zenbenchmark, vanilla_pg) +@pytest.fixture(scope='function') +def remote_compare(zenbenchmark, remote_pg) -> RemoteCompare: + return RemoteCompare(zenbenchmark, remote_pg) + + @pytest.fixture(params=["vanilla_compare", "zenith_compare"], ids=["vanilla", "zenith"]) def zenith_with_baseline(request) -> PgCompare: """Parameterized fixture that helps compare zenith against vanilla postgres. diff --git a/test_runner/fixtures/zenith_fixtures.py b/test_runner/fixtures/zenith_fixtures.py index 41d1443880..f8ee39a5a1 100644 --- a/test_runner/fixtures/zenith_fixtures.py +++ b/test_runner/fixtures/zenith_fixtures.py @@ -123,6 +123,22 @@ def pytest_configure(config): top_output_dir = os.path.join(base_dir, DEFAULT_OUTPUT_DIR) mkdir_if_needed(top_output_dir) + # Find the postgres installation. + global pg_distrib_dir + env_postgres_bin = os.environ.get('POSTGRES_DISTRIB_DIR') + if env_postgres_bin: + pg_distrib_dir = env_postgres_bin + else: + pg_distrib_dir = os.path.normpath(os.path.join(base_dir, DEFAULT_POSTGRES_DIR)) + log.info(f'pg_distrib_dir is {pg_distrib_dir}') + if os.getenv("REMOTE_ENV"): + # When testing against a remote server, we only need the client binary. + if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/psql')): + raise Exception('psql not found at "{}"'.format(pg_distrib_dir)) + else: + if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/postgres')): + raise Exception('postgres not found at "{}"'.format(pg_distrib_dir)) + if os.getenv("REMOTE_ENV"): # we are in remote env and do not have zenith binaries locally # this is the case for benchmarks run on self-hosted runner @@ -138,17 +154,6 @@ def pytest_configure(config): if not os.path.exists(os.path.join(zenith_binpath, 'pageserver')): raise Exception('zenith binaries not found at "{}"'.format(zenith_binpath)) - # Find the postgres installation. - global pg_distrib_dir - env_postgres_bin = os.environ.get('POSTGRES_DISTRIB_DIR') - if env_postgres_bin: - pg_distrib_dir = env_postgres_bin - else: - pg_distrib_dir = os.path.normpath(os.path.join(base_dir, DEFAULT_POSTGRES_DIR)) - log.info(f'pg_distrib_dir is {pg_distrib_dir}') - if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/postgres')): - raise Exception('postgres not found at "{}"'.format(pg_distrib_dir)) - def zenfixture(func: Fn) -> Fn: """ @@ -1305,6 +1310,47 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]: yield vanilla_pg +class RemotePostgres(PgProtocol): + def __init__(self, pg_bin: PgBin, remote_connstr: str): + super().__init__(**parse_dsn(remote_connstr)) + self.pg_bin = pg_bin + # The remote server is assumed to be running already + self.running = True + + def configure(self, options: List[str]): + raise Exception('cannot change configuration of remote Posgres instance') + + def start(self): + raise Exception('cannot start a remote Postgres instance') + + def stop(self): + raise Exception('cannot stop a remote Postgres instance') + + def get_subdir_size(self, subdir) -> int: + # TODO: Could use the server's Generic File Acccess functions if superuser. + # See https://www.postgresql.org/docs/14/functions-admin.html#FUNCTIONS-ADMIN-GENFILE + raise Exception('cannot get size of a Postgres instance') + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + # do nothing + pass + + +@pytest.fixture(scope='function') +def remote_pg(test_output_dir: str) -> Iterator[RemotePostgres]: + pg_bin = PgBin(test_output_dir) + + connstr = os.getenv("BENCHMARK_CONNSTR") + if connstr is None: + raise ValueError("no connstr provided, use BENCHMARK_CONNSTR environment variable") + + with RemotePostgres(pg_bin, connstr) as remote_pg: + yield remote_pg + + class ZenithProxy(PgProtocol): def __init__(self, port: int): super().__init__(host="127.0.0.1", diff --git a/test_runner/performance/test_perf_pgbench.py b/test_runner/performance/test_perf_pgbench.py index 5ffce3c0be..d2de76913a 100644 --- a/test_runner/performance/test_perf_pgbench.py +++ b/test_runner/performance/test_perf_pgbench.py @@ -2,29 +2,113 @@ from contextlib import closing from fixtures.zenith_fixtures import PgBin, VanillaPostgres, ZenithEnv from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare -from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker +from fixtures.benchmark_fixture import PgBenchRunResult, MetricReport, ZenithBenchmarker from fixtures.log_helper import log +from pathlib import Path + +import pytest +from datetime import datetime +import calendar +import os +import timeit + + +def utc_now_timestamp() -> int: + return calendar.timegm(datetime.utcnow().utctimetuple()) + + +def init_pgbench(env: PgCompare, cmdline): + # calculate timestamps and durations separately + # timestamp is intended to be used for linking to grafana and logs + # duration is actually a metric and uses float instead of int for timestamp + init_start_timestamp = utc_now_timestamp() + t0 = timeit.default_timer() + with env.record_pageserver_writes('init.pageserver_writes'): + env.pg_bin.run_capture(cmdline) + env.flush() + init_duration = timeit.default_timer() - t0 + init_end_timestamp = utc_now_timestamp() + + env.zenbenchmark.record("init.duration", + init_duration, + unit="s", + report=MetricReport.LOWER_IS_BETTER) + env.zenbenchmark.record("init.start_timestamp", + init_start_timestamp, + '', + MetricReport.TEST_PARAM) + env.zenbenchmark.record("init.end_timestamp", init_end_timestamp, '', MetricReport.TEST_PARAM) + + +def run_pgbench(env: PgCompare, prefix: str, cmdline): + with env.record_pageserver_writes(f'{prefix}.pageserver_writes'): + run_start_timestamp = utc_now_timestamp() + t0 = timeit.default_timer() + out = env.pg_bin.run_capture(cmdline, ) + run_duration = timeit.default_timer() - t0 + run_end_timestamp = utc_now_timestamp() + env.flush() + + stdout = Path(f"{out}.stdout").read_text() + + res = PgBenchRunResult.parse_from_stdout( + stdout=stdout, + run_duration=run_duration, + run_start_timestamp=run_start_timestamp, + run_end_timestamp=run_end_timestamp, + ) + env.zenbenchmark.record_pg_bench_result(prefix, res) + # -# Run a very short pgbench test. +# Initialize a pgbench database, and run pgbench against it. # -# Collects three metrics: +# This makes runs two different pgbench workloads against the same +# initialized database, and 'duration' is the time of each run. So +# the total runtime is 2 * duration, plus time needed to initialize +# the test database. # -# 1. Time to initialize the pgbench database (pgbench -s5 -i) -# 2. Time to run 5000 pgbench transactions -# 3. Disk space used -# -def test_pgbench(zenith_with_baseline: PgCompare): - env = zenith_with_baseline +# Currently, the # of connections is hardcoded at 4 +def run_test_pgbench(env: PgCompare, scale: int, duration: int): - with env.record_pageserver_writes('pageserver_writes'): - with env.record_duration('init'): - env.pg_bin.run_capture(['pgbench', '-s5', '-i', env.pg.connstr()]) - env.flush() + # Record the scale and initialize + env.zenbenchmark.record("scale", scale, '', MetricReport.TEST_PARAM) + init_pgbench(env, ['pgbench', f'-s{scale}', '-i', env.pg.connstr()]) - with env.record_duration('5000_xacts'): - env.pg_bin.run_capture(['pgbench', '-c1', '-t5000', env.pg.connstr()]) - env.flush() + # Run simple-update workload + run_pgbench(env, + "simple-update", + ['pgbench', '-n', '-c4', f'-T{duration}', '-P2', '-Mprepared', env.pg.connstr()]) + + # Run SELECT workload + run_pgbench(env, + "select-only", + ['pgbench', '-S', '-c4', f'-T{duration}', '-P2', '-Mprepared', env.pg.connstr()]) env.report_size() + + +def get_durations_matrix(): + durations = os.getenv("TEST_PG_BENCH_DURATIONS_MATRIX", default="45") + return list(map(int, durations.split(","))) + + +def get_scales_matrix(): + scales = os.getenv("TEST_PG_BENCH_SCALES_MATRIX", default="10") + return list(map(int, scales.split(","))) + + +# Run the pgbench tests against vanilla Postgres and zenith +@pytest.mark.parametrize("scale", get_scales_matrix()) +@pytest.mark.parametrize("duration", get_durations_matrix()) +def test_pgbench(zenith_with_baseline: PgCompare, scale: int, duration: int): + run_test_pgbench(zenith_with_baseline, scale, duration) + + +# Run the pgbench tests against an existing Postgres cluster +@pytest.mark.parametrize("scale", get_scales_matrix()) +@pytest.mark.parametrize("duration", get_durations_matrix()) +@pytest.mark.remote_cluster +def test_pgbench_remote(remote_compare: PgCompare, scale: int, duration: int): + run_test_pgbench(remote_compare, scale, duration) diff --git a/test_runner/performance/test_perf_pgbench_remote.py b/test_runner/performance/test_perf_pgbench_remote.py deleted file mode 100644 index 28472a16c8..0000000000 --- a/test_runner/performance/test_perf_pgbench_remote.py +++ /dev/null @@ -1,124 +0,0 @@ -import dataclasses -import os -import subprocess -from typing import List -from fixtures.benchmark_fixture import PgBenchRunResult, ZenithBenchmarker -import pytest -from datetime import datetime -import calendar -import timeit -import os - - -def utc_now_timestamp() -> int: - return calendar.timegm(datetime.utcnow().utctimetuple()) - - -@dataclasses.dataclass -class PgBenchRunner: - connstr: str - scale: int - transactions: int - pgbench_bin_path: str = "pgbench" - - def invoke(self, args: List[str]) -> 'subprocess.CompletedProcess[str]': - res = subprocess.run([self.pgbench_bin_path, *args], text=True, capture_output=True) - - if res.returncode != 0: - raise RuntimeError(f"pgbench failed. stdout: {res.stdout} stderr: {res.stderr}") - return res - - def init(self, vacuum: bool = True) -> 'subprocess.CompletedProcess[str]': - args = [] - if not vacuum: - args.append("--no-vacuum") - args.extend([f"--scale={self.scale}", "--initialize", self.connstr]) - return self.invoke(args) - - def run(self, jobs: int = 1, clients: int = 1): - return self.invoke([ - f"--transactions={self.transactions}", - f"--jobs={jobs}", - f"--client={clients}", - "--progress=2", # print progress every two seconds - self.connstr, - ]) - - -@pytest.fixture -def connstr(): - res = os.getenv("BENCHMARK_CONNSTR") - if res is None: - raise ValueError("no connstr provided, use BENCHMARK_CONNSTR environment variable") - return res - - -def get_transactions_matrix(): - transactions = os.getenv("TEST_PG_BENCH_TRANSACTIONS_MATRIX") - if transactions is None: - return [10**4, 10**5] - return list(map(int, transactions.split(","))) - - -def get_scales_matrix(): - scales = os.getenv("TEST_PG_BENCH_SCALES_MATRIX") - if scales is None: - return [10, 20] - return list(map(int, scales.split(","))) - - -@pytest.mark.parametrize("scale", get_scales_matrix()) -@pytest.mark.parametrize("transactions", get_transactions_matrix()) -@pytest.mark.remote_cluster -def test_pg_bench_remote_cluster(zenbenchmark: ZenithBenchmarker, - connstr: str, - scale: int, - transactions: int): - """ - The best way is to run same pack of tests both, for local zenith - and against staging, but currently local tests heavily depend on - things available only locally e.g. zenith binaries, pageserver api, etc. - Also separate test allows to run pgbench workload against vanilla postgres - or other systems that support postgres protocol. - - Also now this is more of a liveness test because it stresses pageserver internals, - so we clearly see what goes wrong in more "real" environment. - """ - pg_bin = os.getenv("PG_BIN") - if pg_bin is not None: - pgbench_bin_path = os.path.join(pg_bin, "pgbench") - else: - pgbench_bin_path = "pgbench" - - runner = PgBenchRunner( - connstr=connstr, - scale=scale, - transactions=transactions, - pgbench_bin_path=pgbench_bin_path, - ) - # calculate timestamps and durations separately - # timestamp is intended to be used for linking to grafana and logs - # duration is actually a metric and uses float instead of int for timestamp - init_start_timestamp = utc_now_timestamp() - t0 = timeit.default_timer() - runner.init() - init_duration = timeit.default_timer() - t0 - init_end_timestamp = utc_now_timestamp() - - run_start_timestamp = utc_now_timestamp() - t0 = timeit.default_timer() - out = runner.run() # TODO handle failures - run_duration = timeit.default_timer() - t0 - run_end_timestamp = utc_now_timestamp() - - res = PgBenchRunResult.parse_from_output( - out=out, - init_duration=init_duration, - init_start_timestamp=init_start_timestamp, - init_end_timestamp=init_end_timestamp, - run_duration=run_duration, - run_start_timestamp=run_start_timestamp, - run_end_timestamp=run_end_timestamp, - ) - - zenbenchmark.record_pg_bench_result(res) From 9e4de6bed02e9dc48af5b9d74a7759b0c2702b26 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Tue, 12 Apr 2022 20:29:35 +0300 Subject: [PATCH 16/19] Use RwLock instad of Mutex for layer map lock. For more concurrency --- pageserver/src/layered_repository.rs | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pageserver/src/layered_repository.rs b/pageserver/src/layered_repository.rs index e178ba5222..95df385cfe 100644 --- a/pageserver/src/layered_repository.rs +++ b/pageserver/src/layered_repository.rs @@ -193,7 +193,7 @@ impl Repository for LayeredRepository { Arc::clone(&self.walredo_mgr), self.upload_layers, ); - timeline.layers.lock().unwrap().next_open_layer_at = Some(initdb_lsn); + timeline.layers.write().unwrap().next_open_layer_at = Some(initdb_lsn); let timeline = Arc::new(timeline); let r = timelines.insert( @@ -725,7 +725,7 @@ pub struct LayeredTimeline { tenantid: ZTenantId, timelineid: ZTimelineId, - layers: Mutex, + layers: RwLock, last_freeze_at: AtomicLsn, @@ -997,7 +997,7 @@ impl LayeredTimeline { conf, timelineid, tenantid, - layers: Mutex::new(LayerMap::default()), + layers: RwLock::new(LayerMap::default()), walredo_mgr, @@ -1040,7 +1040,7 @@ impl LayeredTimeline { /// Returns all timeline-related files that were found and loaded. /// fn load_layer_map(&self, disk_consistent_lsn: Lsn) -> anyhow::Result<()> { - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); let mut num_layers = 0; // Scan timeline directory and create ImageFileName and DeltaFilename @@ -1194,7 +1194,7 @@ impl LayeredTimeline { continue; } - let layers = timeline.layers.lock().unwrap(); + let layers = timeline.layers.read().unwrap(); // Check the open and frozen in-memory layers first if let Some(open_layer) = &layers.open_layer { @@ -1276,7 +1276,7 @@ impl LayeredTimeline { /// Get a handle to the latest layer for appending. /// fn get_layer_for_write(&self, lsn: Lsn) -> anyhow::Result> { - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); ensure!(lsn.is_aligned()); @@ -1347,7 +1347,7 @@ impl LayeredTimeline { } else { Some(self.write_lock.lock().unwrap()) }; - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); if let Some(open_layer) = &layers.open_layer { let open_layer_rc = Arc::clone(open_layer); // Does this layer need freezing? @@ -1412,7 +1412,7 @@ impl LayeredTimeline { let timer = self.flush_time_histo.start_timer(); loop { - let layers = self.layers.lock().unwrap(); + let layers = self.layers.read().unwrap(); if let Some(frozen_layer) = layers.frozen_layers.front() { let frozen_layer = Arc::clone(frozen_layer); drop(layers); // to allow concurrent reads and writes @@ -1456,7 +1456,7 @@ impl LayeredTimeline { // Finally, replace the frozen in-memory layer with the new on-disk layers { - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); let l = layers.frozen_layers.pop_front(); // Only one thread may call this function at a time (for this @@ -1612,7 +1612,7 @@ impl LayeredTimeline { lsn: Lsn, threshold: usize, ) -> Result { - let layers = self.layers.lock().unwrap(); + let layers = self.layers.read().unwrap(); for part_range in &partition.ranges { let image_coverage = layers.image_coverage(part_range, lsn)?; @@ -1670,7 +1670,7 @@ impl LayeredTimeline { // FIXME: Do we need to do something to upload it to remote storage here? - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); layers.insert_historic(Arc::new(image_layer)); drop(layers); @@ -1678,7 +1678,7 @@ impl LayeredTimeline { } fn compact_level0(&self, target_file_size: u64) -> Result<()> { - let layers = self.layers.lock().unwrap(); + let layers = self.layers.read().unwrap(); let level0_deltas = layers.get_level0_deltas()?; @@ -1768,7 +1768,7 @@ impl LayeredTimeline { layer_paths.pop().unwrap(); } - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); for l in new_layers { layers.insert_historic(Arc::new(l)); } @@ -1850,7 +1850,7 @@ impl LayeredTimeline { // 2. it doesn't need to be retained for 'retain_lsns'; // 3. newer on-disk image layers cover the layer's whole key range // - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); 'outer: for l in layers.iter_historic_layers() { // This layer is in the process of being flushed to disk. // It will be swapped out of the layer map, replaced with From d5ae9db997711d770b52511f8bbd2eef8067cedc Mon Sep 17 00:00:00 2001 From: bojanserafimov Date: Thu, 14 Apr 2022 10:09:03 -0400 Subject: [PATCH 17/19] Add s3 cost estimate to tests (#1478) --- pageserver/src/layered_repository.rs | 22 ++++++++++++++++- test_runner/fixtures/benchmark_fixture.py | 30 ++++++++++------------- test_runner/fixtures/compare_fixtures.py | 13 ++++++++++ 3 files changed, 47 insertions(+), 18 deletions(-) diff --git a/pageserver/src/layered_repository.rs b/pageserver/src/layered_repository.rs index 95df385cfe..36b081e400 100644 --- a/pageserver/src/layered_repository.rs +++ b/pageserver/src/layered_repository.rs @@ -49,7 +49,8 @@ use crate::CheckpointConfig; use crate::{ZTenantId, ZTimelineId}; use zenith_metrics::{ - register_histogram_vec, register_int_gauge_vec, Histogram, HistogramVec, IntGauge, IntGaugeVec, + register_histogram_vec, register_int_counter, register_int_gauge_vec, Histogram, HistogramVec, + IntCounter, IntGauge, IntGaugeVec, }; use zenith_utils::crashsafe_dir; use zenith_utils::lsn::{AtomicLsn, Lsn, RecordLsn}; @@ -109,6 +110,21 @@ lazy_static! { .expect("failed to define a metric"); } +// Metrics for cloud upload. These metrics reflect data uploaded to cloud storage, +// or in testing they estimate how much we would upload if we did. +lazy_static! { + static ref NUM_PERSISTENT_FILES_CREATED: IntCounter = register_int_counter!( + "pageserver_num_persistent_files_created", + "Number of files created that are meant to be uploaded to cloud storage", + ) + .expect("failed to define a metric"); + static ref PERSISTENT_BYTES_WRITTEN: IntCounter = register_int_counter!( + "pageserver_persistent_bytes_written", + "Total bytes written that are meant to be uploaded to cloud storage", + ) + .expect("failed to define a metric"); +} + /// Parts of the `.zenith/tenants//timelines/` directory prefix. pub const TIMELINES_SEGMENT_NAME: &str = "timelines"; @@ -1524,6 +1540,10 @@ impl LayeredTimeline { &metadata, false, )?; + + NUM_PERSISTENT_FILES_CREATED.inc_by(1); + PERSISTENT_BYTES_WRITTEN.inc_by(new_delta_path.metadata()?.len()); + if self.upload_layers.load(atomic::Ordering::Relaxed) { schedule_timeline_checkpoint_upload( self.tenantid, diff --git a/test_runner/fixtures/benchmark_fixture.py b/test_runner/fixtures/benchmark_fixture.py index a904233e98..0735f16d73 100644 --- a/test_runner/fixtures/benchmark_fixture.py +++ b/test_runner/fixtures/benchmark_fixture.py @@ -236,10 +236,18 @@ class ZenithBenchmarker: """ Fetch the "cumulative # of bytes written" metric from the pageserver """ - # Fetch all the exposed prometheus metrics from page server - all_metrics = pageserver.http_client().get_metrics() - # Use a regular expression to extract the one we're interested in - # + metric_name = r'pageserver_disk_io_bytes{io_operation="write"}' + return self.get_int_counter_value(pageserver, metric_name) + + def get_peak_mem(self, pageserver) -> int: + """ + Fetch the "maxrss" metric from the pageserver + """ + metric_name = r'pageserver_maxrss_kb' + return self.get_int_counter_value(pageserver, metric_name) + + def get_int_counter_value(self, pageserver, metric_name) -> int: + """Fetch the value of given int counter from pageserver metrics.""" # TODO: If we start to collect more of the prometheus metrics in the # performance test suite like this, we should refactor this to load and # parse all the metrics into a more convenient structure in one go. @@ -247,20 +255,8 @@ class ZenithBenchmarker: # The metric should be an integer, as it's a number of bytes. But in general # all prometheus metrics are floats. So to be pedantic, read it as a float # and round to integer. - matches = re.search(r'^pageserver_disk_io_bytes{io_operation="write"} (\S+)$', - all_metrics, - re.MULTILINE) - assert matches - return int(round(float(matches.group(1)))) - - def get_peak_mem(self, pageserver) -> int: - """ - Fetch the "maxrss" metric from the pageserver - """ - # Fetch all the exposed prometheus metrics from page server all_metrics = pageserver.http_client().get_metrics() - # See comment in get_io_writes() - matches = re.search(r'^pageserver_maxrss_kb (\S+)$', all_metrics, re.MULTILINE) + matches = re.search(fr'^{metric_name} (\S+)$', all_metrics, re.MULTILINE) assert matches return int(round(float(matches.group(1)))) diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index 3c6a923587..93912d2da7 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -105,6 +105,19 @@ class ZenithCompare(PgCompare): 'MB', report=MetricReport.LOWER_IS_BETTER) + total_files = self.zenbenchmark.get_int_counter_value( + self.env.pageserver, "pageserver_num_persistent_files_created") + total_bytes = self.zenbenchmark.get_int_counter_value( + self.env.pageserver, "pageserver_persistent_bytes_written") + self.zenbenchmark.record("data_uploaded", + total_bytes / (1024 * 1024), + "MB", + report=MetricReport.LOWER_IS_BETTER) + self.zenbenchmark.record("num_files_uploaded", + total_files, + "", + report=MetricReport.LOWER_IS_BETTER) + def record_pageserver_writes(self, out_name): return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name) From 93e0ac2b7ae84747188d0da98061333b4a52a150 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Thu, 14 Apr 2022 16:17:47 +0300 Subject: [PATCH 18/19] Remove a couple of unused dependencies. Found by "cargo-udeps" --- Cargo.lock | 2 -- pageserver/Cargo.toml | 1 - proxy/Cargo.toml | 1 - 3 files changed, 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0584b9d6d2..5027c4bdc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1551,7 +1551,6 @@ dependencies = [ "tokio-util 0.7.0", "toml_edit", "tracing", - "tracing-futures", "url", "workspace_hack", "zenith_metrics", @@ -1938,7 +1937,6 @@ dependencies = [ "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "tokio-postgres-rustls", "tokio-rustls 0.22.0", - "tokio-stream", "workspace_hack", "zenith_metrics", "zenith_utils", diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index dccdca291c..e92ac0421c 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -37,7 +37,6 @@ toml_edit = { version = "0.13", features = ["easy"] } scopeguard = "1.1.0" const_format = "0.2.21" tracing = "0.1.27" -tracing-futures = "0.2" signal-hook = "0.3.10" url = "2" nix = "0.23" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 56b6dd7e20..be03a2d4a9 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -31,7 +31,6 @@ thiserror = "1.0.30" tokio = { version = "1.17", features = ["macros"] } tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" } tokio-rustls = "0.22.0" -tokio-stream = "0.1.8" zenith_utils = { path = "../zenith_utils" } zenith_metrics = { path = "../zenith_metrics" } From 2cb39a162431716eeb835656c45ca1cff4eab544 Mon Sep 17 00:00:00 2001 From: Dmitry Rodionov Date: Thu, 14 Apr 2022 14:04:45 +0300 Subject: [PATCH 19/19] add missing files, update workspace hack --- Cargo.lock | 8 ++++---- workspace_hack/.gitattributes | 4 ++++ workspace_hack/Cargo.toml | 16 +++++++++++----- workspace_hack/build.rs | 2 ++ 4 files changed, 21 insertions(+), 9 deletions(-) create mode 100644 workspace_hack/.gitattributes create mode 100644 workspace_hack/build.rs diff --git a/Cargo.lock b/Cargo.lock index 5027c4bdc7..3a75687b36 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2112,7 +2112,6 @@ dependencies = [ "serde_urlencoded", "tokio", "tokio-rustls 0.23.2", - "tokio-util 0.6.9", "url", "wasm-bindgen", "wasm-bindgen-futures", @@ -3390,19 +3389,20 @@ dependencies = [ "anyhow", "bytes", "cc", + "chrono", "clap 2.34.0", "either", "hashbrown", + "indexmap", "libc", "log", "memchr", "num-integer", "num-traits", - "proc-macro2", - "quote", + "prost", + "rand", "regex", "regex-syntax", - "reqwest", "scopeguard", "serde", "syn", diff --git a/workspace_hack/.gitattributes b/workspace_hack/.gitattributes new file mode 100644 index 0000000000..3e9dba4b64 --- /dev/null +++ b/workspace_hack/.gitattributes @@ -0,0 +1,4 @@ +# Avoid putting conflict markers in the generated Cargo.toml file, since their presence breaks +# Cargo. +# Also do not check out the file as CRLF on Windows, as that's what hakari needs. +Cargo.toml merge=binary -crlf diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 6e6a0e09d7..84244b3363 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -16,32 +16,38 @@ publish = false [dependencies] anyhow = { version = "1", features = ["backtrace", "std"] } bytes = { version = "1", features = ["serde", "std"] } +chrono = { version = "0.4", features = ["clock", "libc", "oldtime", "serde", "std", "time", "winapi"] } clap = { version = "2", features = ["ansi_term", "atty", "color", "strsim", "suggestions", "vec_map"] } either = { version = "1", features = ["use_std"] } hashbrown = { version = "0.11", features = ["ahash", "inline-more", "raw"] } +indexmap = { version = "1", default-features = false, features = ["std"] } libc = { version = "0.2", features = ["extra_traits", "std"] } log = { version = "0.4", default-features = false, features = ["serde", "std"] } memchr = { version = "2", features = ["std", "use_std"] } num-integer = { version = "0.1", default-features = false, features = ["std"] } num-traits = { version = "0.2", features = ["std"] } +prost = { version = "0.9", features = ["prost-derive", "std"] } +rand = { version = "0.8", features = ["alloc", "getrandom", "libc", "rand_chacha", "rand_hc", "small_rng", "std", "std_rng"] } regex = { version = "1", features = ["aho-corasick", "memchr", "perf", "perf-cache", "perf-dfa", "perf-inline", "perf-literal", "std", "unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] } regex-syntax = { version = "0.6", features = ["unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] } -reqwest = { version = "0.11", default-features = false, features = ["__rustls", "__tls", "blocking", "hyper-rustls", "json", "rustls", "rustls-pemfile", "rustls-tls", "rustls-tls-webpki-roots", "serde_json", "stream", "tokio-rustls", "tokio-util", "webpki-roots"] } scopeguard = { version = "1", features = ["use_std"] } serde = { version = "1", features = ["alloc", "derive", "serde_derive", "std"] } -tokio = { version = "1", features = ["bytes", "fs", "io-util", "libc", "macros", "memchr", "mio", "net", "num_cpus", "once_cell", "process", "rt", "rt-multi-thread", "signal-hook-registry", "sync", "time", "tokio-macros"] } -tracing = { version = "0.1", features = ["attributes", "std", "tracing-attributes"] } +tokio = { version = "1", features = ["bytes", "fs", "io-std", "io-util", "libc", "macros", "memchr", "mio", "net", "num_cpus", "once_cell", "process", "rt", "rt-multi-thread", "signal-hook-registry", "socket2", "sync", "time", "tokio-macros"] } +tracing = { version = "0.1", features = ["attributes", "log", "std", "tracing-attributes"] } tracing-core = { version = "0.1", features = ["lazy_static", "std"] } [build-dependencies] +anyhow = { version = "1", features = ["backtrace", "std"] } +bytes = { version = "1", features = ["serde", "std"] } cc = { version = "1", default-features = false, features = ["jobserver", "parallel"] } clap = { version = "2", features = ["ansi_term", "atty", "color", "strsim", "suggestions", "vec_map"] } either = { version = "1", features = ["use_std"] } +hashbrown = { version = "0.11", features = ["ahash", "inline-more", "raw"] } +indexmap = { version = "1", default-features = false, features = ["std"] } libc = { version = "0.2", features = ["extra_traits", "std"] } log = { version = "0.4", default-features = false, features = ["serde", "std"] } memchr = { version = "2", features = ["std", "use_std"] } -proc-macro2 = { version = "1", features = ["proc-macro"] } -quote = { version = "1", features = ["proc-macro"] } +prost = { version = "0.9", features = ["prost-derive", "std"] } regex = { version = "1", features = ["aho-corasick", "memchr", "perf", "perf-cache", "perf-dfa", "perf-inline", "perf-literal", "std", "unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] } regex-syntax = { version = "0.6", features = ["unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] } serde = { version = "1", features = ["alloc", "derive", "serde_derive", "std"] } diff --git a/workspace_hack/build.rs b/workspace_hack/build.rs new file mode 100644 index 0000000000..92518ef04c --- /dev/null +++ b/workspace_hack/build.rs @@ -0,0 +1,2 @@ +// A build script is required for cargo to consider build dependencies. +fn main() {}