Merge branch 'main' into bojan-get-page-tests

This commit is contained in:
Bojan Serafimov
2022-04-27 13:05:27 -04:00
187 changed files with 5704 additions and 4343 deletions

View File

@@ -12,7 +12,7 @@ use crate::waiters;
use std::io;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
pub use credentials::ClientCredentials;

View File

@@ -48,10 +48,6 @@ impl ClientCredentials {
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<DatabaseInfo, AuthError> {
fail::fail_point!("proxy-authenticate", |_| {
Err(AuthError::auth_failed("failpoint triggered"))
});
use crate::config::ClientAuthMethod::*;
use crate::config::RouterConfig::*;
match &config.router_config {

View File

@@ -5,7 +5,7 @@ use crate::stream::PqStream;
use crate::{sasl, scram};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
/// Every authentication selector is supposed to implement this trait.
pub trait AuthMethod {

View File

@@ -4,7 +4,7 @@ use parking_lot::Mutex;
use std::net::SocketAddr;
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use zenith_utils::pq_proto::CancelKeyData;
use utils::pq_proto::CancelKeyData;
/// Enables serving `CancelRequest`s.
#[derive(Default)]

View File

@@ -1,10 +1,9 @@
use anyhow::{anyhow, bail, ensure, Context};
use rustls::{internal::pemfile, NoClientAuth, ProtocolVersion, ServerConfig};
use anyhow::{bail, ensure, Context};
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
pub type TlsConfig = Arc<ServerConfig>;
pub type TlsConfig = Arc<rustls::ServerConfig>;
#[non_exhaustive]
pub enum ClientAuthMethod {
@@ -61,21 +60,28 @@ pub struct ProxyConfig {
pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfig> {
let key = {
let key_bytes = std::fs::read(key_path).context("SSL key file")?;
let mut keys = pemfile::pkcs8_private_keys(&mut &key_bytes[..])
.map_err(|_| anyhow!("couldn't read TLS keys"))?;
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
.context("couldn't read TLS keys")?;
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
keys.pop().unwrap()
keys.pop().map(rustls::PrivateKey).unwrap()
};
let cert_chain = {
let cert_chain_bytes = std::fs::read(cert_path).context("SSL cert file")?;
pemfile::certs(&mut &cert_chain_bytes[..])
.map_err(|_| anyhow!("couldn't read TLS certificates"))?
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.context("couldn't read TLS certificate chain")?
.into_iter()
.map(rustls::Certificate)
.collect()
};
let mut config = ServerConfig::new(NoClientAuth::new());
config.set_single_cert(cert_chain, key)?;
config.versions = vec![ProtocolVersion::TLSv1_3];
let config = rustls::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&[&rustls::version::TLS13])?
.with_no_client_auth()
.with_single_cert(cert_chain, key)?;
Ok(config.into())
}

View File

@@ -1,10 +1,7 @@
use anyhow::anyhow;
use hyper::{Body, Request, Response, StatusCode};
use std::net::TcpListener;
use zenith_utils::http::endpoint;
use zenith_utils::http::error::ApiError;
use zenith_utils::http::json::json_response;
use zenith_utils::http::{RouterBuilder, RouterService};
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, "")

View File

@@ -30,7 +30,7 @@ use config::ProxyConfig;
use futures::FutureExt;
use std::future::Future;
use tokio::{net::TcpListener, task::JoinError};
use zenith_utils::GIT_VERSION;
use utils::GIT_VERSION;
use crate::config::{ClientAuthMethod, RouterConfig};
@@ -43,7 +43,7 @@ async fn flatten_err(
#[tokio::main]
async fn main() -> anyhow::Result<()> {
zenith_metrics::set_common_metrics_prefix("zenith_proxy");
metrics::set_common_metrics_prefix("zenith_proxy");
let arg_matches = Command::new("Zenith proxy/router")
.version(GIT_VERSION)
.arg(

View File

@@ -5,7 +5,7 @@ use std::{
net::{TcpListener, TcpStream},
thread,
};
use zenith_utils::{
use utils::{
postgres_backend::{self, AuthType, PostgresBackend},
pq_proto::{BeMessage, SINGLE_COL_ROWDESC},
};

View File

@@ -5,10 +5,10 @@ use crate::stream::{MetricsStream, PqStream, Stream};
use anyhow::{bail, Context};
use futures::TryFutureExt;
use lazy_static::lazy_static;
use metrics::{new_common_metric_name, register_int_counter, IntCounter};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_metrics::{new_common_metric_name, register_int_counter, IntCounter};
use zenith_utils::pq_proto::{BeMessage as Be, *};
use utils::pq_proto::{BeMessage as Be, *};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";
@@ -265,14 +265,24 @@ mod tests {
let (ca, cert, key) = generate_certs(hostname)?;
let server_config = {
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(vec![cert], key)?;
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert], key)?;
config.into()
};
let client_config = {
let mut config = rustls::ClientConfig::new();
config.root_store.add(&ca)?;
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&ca)?;
store
})
.with_no_client_auth();
ClientConfig { config, hostname }
};

View File

@@ -1,9 +1,9 @@
//! Definitions for SASL messages.
use crate::parse::{split_at_const, split_cstr};
use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage};
use utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage};
/// SASL-specific payload of [`PasswordMessage`](zenith_utils::pq_proto::FeMessage::PasswordMessage).
/// SASL-specific payload of [`PasswordMessage`](utils::pq_proto::FeMessage::PasswordMessage).
#[derive(Debug)]
pub struct FirstMessage<'a> {
/// Authentication method, e.g. `"SCRAM-SHA-256"`.
@@ -31,7 +31,7 @@ impl<'a> FirstMessage<'a> {
/// A single SASL message.
/// This struct is deliberately decoupled from lower-level
/// [`BeAuthenticationSaslMessage`](zenith_utils::pq_proto::BeAuthenticationSaslMessage).
/// [`BeAuthenticationSaslMessage`](utils::pq_proto::BeAuthenticationSaslMessage).
#[derive(Debug)]
pub(super) enum ServerMessage<T> {
/// We expect to see more steps.

View File

@@ -18,7 +18,7 @@ pub use secret::*;
pub use exchange::Exchange;
pub use secret::ServerSecret;
use hmac::{Hmac, Mac, NewMac};
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
// TODO: add SCRAM-SHA-256-PLUS
@@ -40,7 +40,7 @@ fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N
/// This function essentially is `Hmac(sha256, key, input)`.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut mac = Hmac::<Sha256>::new_varkey(key).expect("bad key size");
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
parts.into_iter().for_each(|s| mac.update(s));
// TODO: maybe newer `hmac` et al already migrated to regular arrays?

View File

@@ -9,7 +9,7 @@ use std::{io, task};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_rustls::server::TlsStream;
use zenith_utils::pq_proto::{BeMessage, FeMessage, FeStartupPacket};
use utils::pq_proto::{BeMessage, FeMessage, FeStartupPacket};
pin_project! {
/// Stream wrapper which implements libpq's protocol.