diff --git a/Cargo.lock b/Cargo.lock index 1cfa70e9c1..84c0350763 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,17 @@ version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "739f4a8db6605981345c5654f3a85b056ce52f37a39d34da03f25bf2151ea16e" +[[package]] +name = "ahash" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + [[package]] name = "aho-corasick" version = "0.7.18" @@ -769,7 +780,7 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" dependencies = [ - "ahash", + "ahash 0.4.7", ] [[package]] @@ -777,6 +788,9 @@ name = "hashbrown" version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash 0.7.6", +] [[package]] name = "hermit-abi" @@ -896,7 +910,7 @@ dependencies = [ "hyper", "rustls 0.20.2", "tokio", - "tokio-rustls", + "tokio-rustls 0.23.2", ] [[package]] @@ -981,7 +995,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "afabcc15e437a6484fc4f12d0fd63068fe457bf93f1c148d3d9649c60b103f32" dependencies = [ "base64 0.12.3", - "pem", + "pem 0.8.3", "ring", "serde", "serde_json", @@ -1352,6 +1366,15 @@ dependencies = [ "regex", ] +[[package]] +name = "pem" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9a3b09a20e374558580a4914d3b7d89bd61b954a5a5e1dcbea98753addb1947" +dependencies = [ + "base64 0.13.0", +] + [[package]] name = "percent-encoding" version = "2.1.0" @@ -1556,17 +1579,25 @@ dependencies = [ "anyhow", "bytes", "clap 3.0.14", + "futures", + "hashbrown 0.11.2", "hex", "hyper", "lazy_static", "md5", + "parking_lot", + "pin-project-lite", "rand", + "rcgen", "reqwest", "rustls 0.19.1", + "scopeguard", "serde", "serde_json", "tokio", "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", + "tokio-postgres-rustls", + "tokio-rustls 0.22.0", "zenith_metrics", "zenith_utils", ] @@ -1620,6 +1651,18 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rcgen" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5911d1403f4143c9d56a702069d593e8d0f3fab880a85e103604d0893ea31ba7" +dependencies = [ + "chrono", + "pem 1.0.2", + "ring", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.2.10" @@ -1703,7 +1746,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "tokio", - "tokio-rustls", + "tokio-rustls 0.23.2", "tokio-util", "url", "wasm-bindgen", @@ -2265,6 +2308,32 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-postgres-rustls" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bd8c37d8c23cb6ecdc32fc171bade4e9c7f1be65f693a17afbaad02091a0a19" +dependencies = [ + "futures", + "ring", + "rustls 0.19.1", + "tokio", + "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", + "tokio-rustls 0.22.0", + "webpki 0.21.4", +] + +[[package]] +name = "tokio-rustls" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" +dependencies = [ + "rustls 0.19.1", + "tokio", + "webpki 0.21.4", +] + [[package]] name = "tokio-rustls" version = "0.23.2" @@ -2730,6 +2799,15 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2d7d3948613f75c98fd9328cfdcc45acc4d360655289d0a7d4ec931392200a3" +[[package]] +name = "yasna" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e262a29d0e61ccf2b6190d7050d4b237535fc76ce4c1210d9caa316f71dffa75" +dependencies = [ + "chrono", +] + [[package]] name = "zenith" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 51bc709fe8..b20e64a06f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,8 @@ members = [ # This is useful for profiling and, to some extent, debug. # Besides, debug info should not affect the performance. debug = true + +# This is only needed for proxy's tests +# TODO: we should probably fork tokio-postgres-rustls instead +[patch.crates-io] +tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" } diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 6e50ac9853..d8d5cbe5bf 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -6,18 +6,28 @@ edition = "2021" [dependencies] anyhow = "1.0" bytes = { version = "1.0.1", features = ['serde'] } -lazy_static = "1.4.0" -md5 = "0.7.0" -rand = "0.8.3" +clap = "3.0" +futures = "0.3.13" +hashbrown = "0.11.2" hex = "0.4.3" hyper = "0.14" +lazy_static = "1.4.0" +md5 = "0.7.0" +parking_lot = "0.11.2" +pin-project-lite = "0.2.7" +rand = "0.8.3" +reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] } +rustls = "0.19.1" +scopeguard = "1.1.0" serde = "1" serde_json = "1" tokio = { version = "1.11", features = ["macros"] } tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" } -clap = "3.0" -rustls = "0.19.1" -reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] } +tokio-rustls = "0.22.0" zenith_utils = { path = "../zenith_utils" } zenith_metrics = { path = "../zenith_metrics" } + +[dev-dependencies] +tokio-postgres-rustls = "0.8.0" +rcgen = "0.8.14" diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs new file mode 100644 index 0000000000..20beb6ac79 --- /dev/null +++ b/proxy/src/auth.rs @@ -0,0 +1,127 @@ +use crate::compute::DatabaseInfo; +use crate::config::ProxyConfig; +use crate::cplane_api::{self, CPlaneApi}; +use crate::stream::PqStream; +use anyhow::{anyhow, bail, Context}; +use std::collections::HashMap; +use tokio::io::{AsyncRead, AsyncWrite}; +use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe}; + +/// Various client credentials which we use for authentication. +#[derive(Debug, PartialEq, Eq)] +pub struct ClientCredentials { + pub user: String, + pub dbname: String, +} + +impl TryFrom> for ClientCredentials { + type Error = anyhow::Error; + + fn try_from(mut value: HashMap) -> Result { + let mut get_param = |key| { + value + .remove(key) + .with_context(|| format!("{} is missing in startup packet", key)) + }; + + let user = get_param("user")?; + let db = get_param("database")?; + + Ok(Self { user, dbname: db }) + } +} + +impl ClientCredentials { + /// Use credentials to authenticate the user. + pub async fn authenticate( + self, + config: &ProxyConfig, + client: &mut PqStream, + ) -> anyhow::Result { + let db_info = if self.user.ends_with("@zenith") { + handle_existing_user(config, client, self).await + } else { + handle_new_user(config, client).await + }; + + db_info.context("failed to authenticate client") + } +} + +fn new_psql_session_id() -> String { + hex::encode(rand::random::<[u8; 8]>()) +} + +async fn handle_existing_user( + config: &ProxyConfig, + client: &mut PqStream, + creds: ClientCredentials, +) -> anyhow::Result { + let psql_session_id = new_psql_session_id(); + let md5_salt = rand::random(); + + client + .write_message(&Be::AuthenticationMD5Password(&md5_salt)) + .await?; + + // Read client's password hash + let msg = match client.read_message().await? { + Fe::PasswordMessage(msg) => msg, + bad => bail!("unexpected message type: {:?}", bad), + }; + + let (_trailing_null, md5_response) = msg + .split_last() + .ok_or_else(|| anyhow!("unexpected password message"))?; + + let cplane = CPlaneApi::new(&config.auth_endpoint); + let db_info = cplane + .authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id) + .await?; + + client + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())?; + + Ok(db_info) +} + +async fn handle_new_user( + config: &ProxyConfig, + client: &mut PqStream, +) -> anyhow::Result { + let psql_session_id = new_psql_session_id(); + let greeting = hello_message(&config.redirect_uri, &psql_session_id); + + let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async { + // Give user a URL to spawn a new database + client + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())? + .write_message(&Be::NoticeResponse(greeting)) + .await?; + + // Wait for web console response + waiter.await?.map_err(|e| anyhow!(e)) + }) + .await?; + + client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?; + + Ok(db_info) +} + +fn hello_message(redirect_uri: &str, session_id: &str) -> String { + format!( + concat![ + "☀️ Welcome to Zenith!\n", + "To proceed with database creation, open the following link:\n\n", + " {redirect_uri}{session_id}\n\n", + "It needs to be done once and we will send you '.pgpass' file,\n", + "which will allow you to access or create ", + "databases without opening your web browser." + ], + redirect_uri = redirect_uri, + session_id = session_id, + ) +} diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs new file mode 100644 index 0000000000..62f195c3d2 --- /dev/null +++ b/proxy/src/cancellation.rs @@ -0,0 +1,91 @@ +use anyhow::{anyhow, Context}; +use hashbrown::HashMap; +use lazy_static::lazy_static; +use parking_lot::Mutex; +use std::net::SocketAddr; +use tokio::net::TcpStream; +use tokio_postgres::{CancelToken, NoTls}; +use zenith_utils::pq_proto::CancelKeyData; + +lazy_static! { + /// Enables serving CancelRequests. + static ref CANCEL_MAP: Mutex>> = Default::default(); +} + +/// This should've been a [`std::future::Future`], but +/// it's impossible to name a type of an unboxed future +/// (we'd need something like `#![feature(type_alias_impl_trait)]`). +#[derive(Clone)] +pub struct CancelClosure { + socket_addr: SocketAddr, + cancel_token: CancelToken, +} + +impl CancelClosure { + pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self { + Self { + socket_addr, + cancel_token, + } + } + + /// Cancels the query running on user's compute node. + pub async fn try_cancel_query(self) -> anyhow::Result<()> { + let socket = TcpStream::connect(self.socket_addr).await?; + self.cancel_token.cancel_query_raw(socket, NoTls).await?; + + Ok(()) + } +} + +/// Cancel a running query for the corresponding connection. +pub async fn cancel_session(key: CancelKeyData) -> anyhow::Result<()> { + let cancel_closure = CANCEL_MAP + .lock() + .get(&key) + .and_then(|x| x.clone()) + .with_context(|| format!("unknown session: {:?}", key))?; + + cancel_closure.try_cancel_query().await +} + +/// Helper for registering query cancellation tokens. +pub struct Session(CancelKeyData); + +impl Session { + /// Store the cancel token for the given session. + pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData { + CANCEL_MAP.lock().insert(self.0, Some(cancel_closure)); + self.0 + } +} + +/// Run async action within an ephemeral session identified by [`CancelKeyData`]. +pub async fn with_session(f: F) -> anyhow::Result +where + F: FnOnce(Session) -> R, + R: std::future::Future>, +{ + // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't + // expose it and we don't want to do another roundtrip to query + // for it. The client will be able to notice that this is not the + // actual backend_pid, but backend_pid is not used for anything + // so it doesn't matter. + let key = rand::random(); + + // Random key collisions are unlikely to happen here, but they're still possible, + // which is why we have to take care not to rewrite an existing key. + CANCEL_MAP + .lock() + .try_insert(key, None) + .map_err(|_| anyhow!("session already exists: {:?}", key))?; + + // This will guarantee that the session gets dropped + // as soon as the future is finished. + scopeguard::defer! { + CANCEL_MAP.lock().remove(&key); + } + + let session = Session(key); + f(session).await +} diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs new file mode 100644 index 0000000000..7c294bd488 --- /dev/null +++ b/proxy/src/compute.rs @@ -0,0 +1,42 @@ +use anyhow::Context; +use serde::{Deserialize, Serialize}; +use std::net::{SocketAddr, ToSocketAddrs}; + +/// Compute node connection params. +#[derive(Serialize, Deserialize, Debug, Default)] +pub struct DatabaseInfo { + pub host: String, + pub port: u16, + pub dbname: String, + pub user: String, + pub password: Option, +} + +impl DatabaseInfo { + pub fn socket_addr(&self) -> anyhow::Result { + let host_port = format!("{}:{}", self.host, self.port); + host_port + .to_socket_addrs() + .with_context(|| format!("cannot resolve {} to SocketAddr", host_port))? + .next() + .context("cannot resolve at least one SocketAddr") + } +} + +impl From for tokio_postgres::Config { + fn from(db_info: DatabaseInfo) -> Self { + let mut config = tokio_postgres::Config::new(); + + config + .host(&db_info.host) + .port(db_info.port) + .dbname(&db_info.dbname) + .user(&db_info.user); + + if let Some(password) = db_info.password { + config.password(password); + } + + config + } +} diff --git a/proxy/src/state.rs b/proxy/src/config.rs similarity index 77% rename from proxy/src/state.rs rename to proxy/src/config.rs index 04726a0756..a39980321b 100644 --- a/proxy/src/state.rs +++ b/proxy/src/config.rs @@ -1,10 +1,9 @@ -use crate::cplane_api::DatabaseInfo; use anyhow::{anyhow, ensure, Context}; use rustls::{internal::pemfile, NoClientAuth, ProtocolVersion, ServerConfig}; use std::net::SocketAddr; use std::sync::Arc; -pub type SslConfig = Arc; +pub type TlsConfig = Arc; pub struct ProxyConfig { /// main entrypoint for users to connect to @@ -24,26 +23,10 @@ pub struct ProxyConfig { /// control plane address where we would check auth. pub auth_endpoint: String, - pub ssl_config: Option, + pub tls_config: Option, } -pub type ProxyWaiters = crate::waiters::Waiters>; - -pub struct ProxyState { - pub conf: ProxyConfig, - pub waiters: ProxyWaiters, -} - -impl ProxyState { - pub fn new(conf: ProxyConfig) -> Self { - Self { - conf, - waiters: ProxyWaiters::default(), - } - } -} - -pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result { +pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result { let key = { let key_bytes = std::fs::read(key_path).context("SSL key file")?; let mut keys = pemfile::pkcs8_private_keys(&mut &key_bytes[..]) diff --git a/proxy/src/cplane_api.rs b/proxy/src/cplane_api.rs index aeb34e8b3b..187809717f 100644 --- a/proxy/src/cplane_api.rs +++ b/proxy/src/cplane_api.rs @@ -1,106 +1,87 @@ -use anyhow::{anyhow, bail, Context}; +use crate::auth::ClientCredentials; +use crate::compute::DatabaseInfo; +use crate::waiters::{Waiter, Waiters}; +use anyhow::{anyhow, bail}; +use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; -use std::net::{SocketAddr, ToSocketAddrs}; -use crate::state::ProxyWaiters; - -#[derive(Serialize, Deserialize, Debug, Default)] -pub struct DatabaseInfo { - pub host: String, - pub port: u16, - pub dbname: String, - pub user: String, - pub password: Option, +lazy_static! { + static ref CPLANE_WAITERS: Waiters> = Default::default(); } -#[derive(Serialize, Deserialize, Debug)] -#[serde(untagged)] -enum ProxyAuthResponse { - Ready { conn_info: DatabaseInfo }, - Error { error: String }, - NotReady { ready: bool }, // TODO: get rid of `ready` +/// Give caller an opportunity to wait for cplane's reply. +pub async fn with_waiter(psql_session_id: impl Into, f: F) -> anyhow::Result +where + F: FnOnce(Waiter<'static, Result>) -> R, + R: std::future::Future>, +{ + let waiter = CPLANE_WAITERS.register(psql_session_id.into())?; + f(waiter).await } -impl DatabaseInfo { - pub fn socket_addr(&self) -> anyhow::Result { - let host_port = format!("{}:{}", self.host, self.port); - host_port - .to_socket_addrs() - .with_context(|| format!("cannot resolve {} to SocketAddr", host_port))? - .next() - .context("cannot resolve at least one SocketAddr") - } -} - -impl From for tokio_postgres::Config { - fn from(db_info: DatabaseInfo) -> Self { - let mut config = tokio_postgres::Config::new(); - - config - .host(&db_info.host) - .port(db_info.port) - .dbname(&db_info.dbname) - .user(&db_info.user); - - if let Some(password) = db_info.password { - config.password(password); - } - - config - } +pub fn notify(psql_session_id: &str, msg: Result) -> anyhow::Result<()> { + CPLANE_WAITERS.notify(psql_session_id, msg) } +/// Zenith console API wrapper. pub struct CPlaneApi<'a> { auth_endpoint: &'a str, - waiters: &'a ProxyWaiters, } impl<'a> CPlaneApi<'a> { - pub fn new(auth_endpoint: &'a str, waiters: &'a ProxyWaiters) -> Self { - Self { - auth_endpoint, - waiters, - } + pub fn new(auth_endpoint: &'a str) -> Self { + Self { auth_endpoint } } } impl CPlaneApi<'_> { - pub fn authenticate_proxy_request( + pub async fn authenticate_proxy_request( &self, - user: &str, - database: &str, + creds: ClientCredentials, md5_response: &[u8], salt: &[u8; 4], psql_session_id: &str, ) -> anyhow::Result { let mut url = reqwest::Url::parse(self.auth_endpoint)?; url.query_pairs_mut() - .append_pair("login", user) - .append_pair("database", database) + .append_pair("login", &creds.user) + .append_pair("database", &creds.dbname) .append_pair("md5response", std::str::from_utf8(md5_response)?) .append_pair("salt", &hex::encode(salt)) .append_pair("psql_session_id", psql_session_id); - let waiter = self.waiters.register(psql_session_id.to_owned()); + with_waiter(psql_session_id, |waiter| async { + println!("cplane request: {}", url); + // TODO: leverage `reqwest::Client` to reuse connections + let resp = reqwest::get(url).await?; + if !resp.status().is_success() { + bail!("Auth failed: {}", resp.status()) + } - println!("cplane request: {}", url); - let resp = reqwest::blocking::get(url)?; - if !resp.status().is_success() { - bail!("Auth failed: {}", resp.status()) - } + let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?; + println!("got auth info: #{:?}", auth_info); - let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text()?.as_str())?; - println!("got auth info: #{:?}", auth_info); - - use ProxyAuthResponse::*; - match auth_info { - Ready { conn_info } => Ok(conn_info), - Error { error } => bail!(error), - NotReady { .. } => waiter.wait()?.map_err(|e| anyhow!(e)), - } + use ProxyAuthResponse::*; + match auth_info { + Ready { conn_info } => Ok(conn_info), + Error { error } => bail!(error), + NotReady { .. } => waiter.await?.map_err(|e| anyhow!(e)), + } + }) + .await } } +// NOTE: the order of constructors is important. +// https://serde.rs/enum-representations.html#untagged +#[derive(Serialize, Deserialize, Debug)] +#[serde(untagged)] +enum ProxyAuthResponse { + Ready { conn_info: DatabaseInfo }, + Error { error: String }, + NotReady { ready: bool }, // TODO: get rid of `ready` +} + #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 20863286ce..0b693d88dd 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -1,15 +1,30 @@ +use anyhow::anyhow; use hyper::{Body, Request, Response, StatusCode}; -use zenith_utils::http::RouterBuilder; - +use std::net::TcpListener; use zenith_utils::http::endpoint; use zenith_utils::http::error::ApiError; use zenith_utils::http::json::json_response; +use zenith_utils::http::{RouterBuilder, RouterService}; async fn status_handler(_: Request) -> Result, ApiError> { Ok(json_response(StatusCode::OK, "")?) } -pub fn make_router() -> RouterBuilder { +fn make_router() -> RouterBuilder { let router = endpoint::make_router(); router.get("/v1/status", status_handler) } + +pub async fn thread_main(http_listener: TcpListener) -> anyhow::Result<()> { + scopeguard::defer! { + println!("http has shut down"); + } + + let service = || RouterService::new(make_router().build()?); + + hyper::Server::from_tcp(http_listener)? + .serve(service().map_err(|e| anyhow!(e))?) + .await?; + + Ok(()) +} diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 8870bb8fec..a72c3b0b1e 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -5,21 +5,34 @@ /// (control plane API in our case) and can create new databases and accounts /// in somewhat transparent manner (again via communication with control plane API). /// -use anyhow::bail; +use anyhow::{bail, Context}; use clap::{App, Arg}; -use state::{ProxyConfig, ProxyState}; -use std::thread; -use zenith_utils::http::endpoint; -use zenith_utils::{tcp_listener, GIT_VERSION}; +use config::ProxyConfig; +use futures::FutureExt; +use std::future::Future; +use tokio::{net::TcpListener, task::JoinError}; +use zenith_utils::GIT_VERSION; +mod auth; +mod cancellation; +mod compute; +mod config; mod cplane_api; mod http; mod mgmt; mod proxy; -mod state; +mod stream; mod waiters; -fn main() -> anyhow::Result<()> { +/// Flattens Result> into Result. +async fn flatten_err( + f: impl Future, JoinError>>, +) -> anyhow::Result<()> { + f.map(|r| r.context("join error").and_then(|x| x)).await +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { zenith_metrics::set_common_metrics_prefix("zenith_proxy"); let arg_matches = App::new("Zenith proxy/router") .version(GIT_VERSION) @@ -79,63 +92,42 @@ fn main() -> anyhow::Result<()> { ) .get_matches(); - let ssl_config = match ( + let tls_config = match ( arg_matches.value_of("ssl-key"), arg_matches.value_of("ssl-cert"), ) { - (Some(key_path), Some(cert_path)) => { - Some(crate::state::configure_ssl(key_path, cert_path)?) - } + (Some(key_path), Some(cert_path)) => Some(config::configure_ssl(key_path, cert_path)?), (None, None) => None, _ => bail!("either both or neither ssl-key and ssl-cert must be specified"), }; - let config = ProxyConfig { + let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig { proxy_address: arg_matches.value_of("proxy").unwrap().parse()?, mgmt_address: arg_matches.value_of("mgmt").unwrap().parse()?, http_address: arg_matches.value_of("http").unwrap().parse()?, redirect_uri: arg_matches.value_of("uri").unwrap().parse()?, auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?, - ssl_config, - }; - let state: &ProxyState = Box::leak(Box::new(ProxyState::new(config))); + tls_config, + })); println!("Version: {}", GIT_VERSION); // Check that we can bind to address before further initialization - println!("Starting http on {}", state.conf.http_address); - let http_listener = tcp_listener::bind(state.conf.http_address)?; + println!("Starting http on {}", config.http_address); + let http_listener = TcpListener::bind(config.http_address).await?.into_std()?; - println!("Starting proxy on {}", state.conf.proxy_address); - let pageserver_listener = tcp_listener::bind(state.conf.proxy_address)?; + println!("Starting mgmt on {}", config.mgmt_address); + let mgmt_listener = TcpListener::bind(config.mgmt_address).await?.into_std()?; - println!("Starting mgmt on {}", state.conf.mgmt_address); - let mgmt_listener = tcp_listener::bind(state.conf.mgmt_address)?; + println!("Starting proxy on {}", config.proxy_address); + let proxy_listener = TcpListener::bind(config.proxy_address).await?; - let threads = [ - thread::Builder::new() - .name("Http thread".into()) - .spawn(move || { - let router = http::make_router(); - endpoint::serve_thread_main( - router, - http_listener, - std::future::pending(), // never shut down - ) - })?, - // Spawn a thread to listen for connections. It will spawn further threads - // for each connection. - thread::Builder::new() - .name("Listener thread".into()) - .spawn(move || proxy::thread_main(state, pageserver_listener))?, - thread::Builder::new() - .name("Mgmt thread".into()) - .spawn(move || mgmt::thread_main(state, mgmt_listener))?, - ]; + let http = tokio::spawn(http::thread_main(http_listener)); + let proxy = tokio::spawn(proxy::thread_main(config, proxy_listener)); + let mgmt = tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)); - for t in threads { - t.join().unwrap()?; - } + let tasks = [flatten_err(http), flatten_err(proxy), flatten_err(mgmt)]; + let _: Vec<()> = futures::future::try_join_all(tasks).await?; Ok(()) } diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index 9d8dc5130f..55b49b441f 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -1,44 +1,49 @@ +use crate::{compute::DatabaseInfo, cplane_api}; +use anyhow::Context; +use serde::Deserialize; use std::{ net::{TcpListener, TcpStream}, thread, }; - -use serde::Deserialize; use zenith_utils::{ postgres_backend::{self, AuthType, PostgresBackend}, pq_proto::{BeMessage, SINGLE_COL_ROWDESC}, }; -use crate::{cplane_api::DatabaseInfo, ProxyState}; - /// /// Main proxy listener loop. /// /// Listens for connections, and launches a new handler thread for each. /// -pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow::Result<()> { +pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> { + scopeguard::defer! { + println!("mgmt has shut down"); + } + + listener + .set_nonblocking(false) + .context("failed to set listener to blocking")?; loop { - let (socket, peer_addr) = listener.accept()?; + let (socket, peer_addr) = listener.accept().context("failed to accept a new client")?; println!("accepted connection from {}", peer_addr); - socket.set_nodelay(true).unwrap(); + socket + .set_nodelay(true) + .context("failed to set client socket option")?; thread::spawn(move || { - if let Err(err) = handle_connection(state, socket) { + if let Err(err) = handle_connection(socket) { println!("error: {}", err); } }); } } -fn handle_connection(state: &ProxyState, socket: TcpStream) -> anyhow::Result<()> { - let mut conn_handler = MgmtHandler { state }; +fn handle_connection(socket: TcpStream) -> anyhow::Result<()> { let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?; - pgbackend.run(&mut conn_handler) + pgbackend.run(&mut MgmtHandler) } -struct MgmtHandler<'a> { - state: &'a ProxyState, -} +struct MgmtHandler; /// Serialized examples: // { @@ -74,13 +79,13 @@ enum PsqlSessionResult { Failure(String), } -impl postgres_backend::Handler for MgmtHandler<'_> { +impl postgres_backend::Handler for MgmtHandler { fn process_query( &mut self, pgb: &mut PostgresBackend, query_string: &str, ) -> anyhow::Result<()> { - let res = try_process_query(self, pgb, query_string); + let res = try_process_query(pgb, query_string); // intercept and log error message if res.is_err() { println!("Mgmt query failed: #{:?}", res); @@ -89,11 +94,7 @@ impl postgres_backend::Handler for MgmtHandler<'_> { } } -fn try_process_query( - mgmt: &mut MgmtHandler, - pgb: &mut PostgresBackend, - query_string: &str, -) -> anyhow::Result<()> { +fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::Result<()> { println!("Got mgmt query: '{}'", query_string); let resp: PsqlSessionResponse = serde_json::from_str(query_string)?; @@ -104,7 +105,7 @@ fn try_process_query( Failure(message) => Err(message), }; - match mgmt.state.waiters.notify(&resp.session_id, msg) { + match cplane_api::notify(&resp.session_id, msg) { Ok(()) => { pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 85f577a6c2..1bf48f89cc 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,389 +1,322 @@ -use crate::cplane_api::{CPlaneApi, DatabaseInfo}; -use crate::ProxyState; -use anyhow::{anyhow, bail, Context}; +use crate::auth; +use crate::cancellation::{self, CancelClosure}; +use crate::compute::DatabaseInfo; +use crate::config::{ProxyConfig, TlsConfig}; +use crate::stream::{MetricsStream, PqStream, Stream}; +use anyhow::{bail, Context}; use lazy_static::lazy_static; -use rand::prelude::StdRng; -use rand::{Rng, SeedableRng}; -use std::cell::Cell; -use std::collections::HashMap; -use std::net::{SocketAddr, TcpStream}; -use std::sync::Mutex; -use std::{io, thread}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; use tokio_postgres::NoTls; use zenith_metrics::{new_common_metric_name, register_int_counter, IntCounter}; -use zenith_utils::postgres_backend::{self, PostgresBackend, ProtoState, Stream}; -use zenith_utils::pq_proto::{BeMessage as Be, FeMessage as Fe, *}; -use zenith_utils::sock_split::{ReadStream, WriteStream}; - -struct CancelClosure { - socket_addr: SocketAddr, - cancel_token: tokio_postgres::CancelToken, -} - -impl CancelClosure { - async fn try_cancel_query(&self) { - if let Ok(socket) = tokio::net::TcpStream::connect(self.socket_addr).await { - // NOTE ignoring the result because: - // 1. This is a best effort attempt, the database doesn't have to listen - // 2. Being opaque about errors here helps avoid leaking info to unauthenticated user - let _ = self.cancel_token.cancel_query_raw(socket, NoTls).await; - } - } -} +use zenith_utils::pq_proto::{BeMessage as Be, *}; lazy_static! { - // Enables serving CancelRequests - static ref CANCEL_MAP: Mutex> = Mutex::new(HashMap::new()); - - // Metrics static ref NUM_CONNECTIONS_ACCEPTED_COUNTER: IntCounter = register_int_counter!( new_common_metric_name("num_connections_accepted"), "Number of TCP client connections accepted." - ).unwrap(); + ) + .unwrap(); static ref NUM_CONNECTIONS_CLOSED_COUNTER: IntCounter = register_int_counter!( new_common_metric_name("num_connections_closed"), "Number of TCP client connections closed." - ).unwrap(); - static ref NUM_CONNECTIONS_FAILED_COUNTER: IntCounter = register_int_counter!( - new_common_metric_name("num_connections_failed"), - "Number of TCP client connections that closed due to error." - ).unwrap(); + ) + .unwrap(); static ref NUM_BYTES_PROXIED_COUNTER: IntCounter = register_int_counter!( new_common_metric_name("num_bytes_proxied"), "Number of bytes sent/received between any client and backend." - ).unwrap(); -} - -thread_local! { - // Used to clean up the CANCEL_MAP. Might not be necessary if we use tokio thread pool in main loop. - static THREAD_CANCEL_KEY_DATA: Cell> = Cell::new(None); -} - -/// -/// Main proxy listener loop. -/// -/// Listens for connections, and launches a new handler thread for each. -/// -pub fn thread_main( - state: &'static ProxyState, - listener: std::net::TcpListener, -) -> anyhow::Result<()> { - loop { - let (socket, peer_addr) = listener.accept()?; - println!("accepted connection from {}", peer_addr); - NUM_CONNECTIONS_ACCEPTED_COUNTER.inc(); - socket.set_nodelay(true).unwrap(); - - // TODO Use a threadpool instead. Maybe use tokio's threadpool by - // spawning a future into its runtime. Tokio's JoinError should - // allow us to handle cleanup properly even if the future panics. - thread::Builder::new() - .name("Proxy thread".into()) - .spawn(move || { - if let Err(err) = proxy_conn_main(state, socket) { - NUM_CONNECTIONS_FAILED_COUNTER.inc(); - println!("error: {}", err); - } - - // Clean up CANCEL_MAP. - NUM_CONNECTIONS_CLOSED_COUNTER.inc(); - THREAD_CANCEL_KEY_DATA.with(|cell| { - if let Some(cancel_key_data) = cell.get() { - CANCEL_MAP.lock().unwrap().remove(&cancel_key_data); - }; - }); - })?; - } -} - -// TODO: clean up fields -struct ProxyConnection { - state: &'static ProxyState, - psql_session_id: String, - pgb: PostgresBackend, -} - -pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> { - let conn = ProxyConnection { - state, - psql_session_id: hex::encode(rand::random::<[u8; 8]>()), - pgb: PostgresBackend::new( - socket, - postgres_backend::AuthType::MD5, - state.conf.ssl_config.clone(), - false, - )?, - }; - - let (client, server) = match conn.handle_client()? { - Some(x) => x, - None => return Ok(()), - }; - - let server = zenith_utils::sock_split::BidiStream::from_tcp(server); - - let client = match client { - Stream::Bidirectional(bidi_stream) => bidi_stream, - _ => panic!("invalid stream type"), - }; - - proxy(client.split(), server.split()) -} - -impl ProxyConnection { - /// Returns Ok(None) when connection was successfully closed. - fn handle_client(mut self) -> anyhow::Result> { - let mut authenticate = || { - let (username, dbname) = match self.handle_startup()? { - Some(x) => x, - None => return Ok(None), - }; - - // Both scenarios here should end up producing database credentials - if username.ends_with("@zenith") { - self.handle_existing_user(&username, &dbname).map(Some) - } else { - self.handle_new_user().map(Some) - } - }; - - let conn = match authenticate() { - Ok(Some(db_info)) => connect_to_db(db_info), - Ok(None) => return Ok(None), - Err(e) => { - // Report the error to the client - self.pgb.write_message(&Be::ErrorResponse(&e.to_string()))?; - bail!("failed to handle client: {:?}", e); - } - }; - - // We'll get rid of this once migration to async is complete - let (pg_version, db_stream) = { - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; - - let (pg_version, stream, cancel_key_data) = runtime.block_on(conn)?; - self.pgb - .write_message(&BeMessage::BackendKeyData(cancel_key_data))?; - let stream = stream.into_std()?; - stream.set_nonblocking(false)?; - - (pg_version, stream) - }; - - // Let the client send new requests - self.pgb - .write_message_noflush(&BeMessage::ParameterStatus( - BeParameterStatusMessage::ServerVersion(&pg_version), - ))? - .write_message(&Be::ReadyForQuery)?; - - Ok(Some((self.pgb.into_stream(), db_stream))) - } - - /// Returns Ok(None) when connection was successfully closed. - fn handle_startup(&mut self) -> anyhow::Result> { - let have_tls = self.pgb.tls_config.is_some(); - let mut encrypted = false; - - loop { - let msg = match self.pgb.read_message()? { - Some(Fe::StartupPacket(msg)) => msg, - None => bail!("connection is lost"), - bad => bail!("unexpected message type: {:?}", bad), - }; - println!("got message: {:?}", msg); - - match msg { - FeStartupPacket::GssEncRequest => { - self.pgb.write_message(&Be::EncryptionResponse(false))?; - } - FeStartupPacket::SslRequest => { - self.pgb.write_message(&Be::EncryptionResponse(have_tls))?; - if have_tls { - self.pgb.start_tls()?; - encrypted = true; - } - } - FeStartupPacket::StartupMessage { mut params, .. } => { - if have_tls && !encrypted { - bail!("must connect with TLS"); - } - - let mut get_param = |key| { - params - .remove(key) - .with_context(|| format!("{} is missing in startup packet", key)) - }; - - return Ok(Some((get_param("user")?, get_param("database")?))); - } - FeStartupPacket::CancelRequest(cancel_key_data) => { - if let Some(cancel_closure) = CANCEL_MAP.lock().unwrap().get(&cancel_key_data) { - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - runtime.block_on(cancel_closure.try_cancel_query()); - } - return Ok(None); - } - } - } - } - - fn handle_existing_user(&mut self, user: &str, db: &str) -> anyhow::Result { - let md5_salt = rand::random::<[u8; 4]>(); - - // Ask password - self.pgb - .write_message(&Be::AuthenticationMD5Password(&md5_salt))?; - self.pgb.state = ProtoState::Authentication; // XXX - - // Check password - let msg = match self.pgb.read_message()? { - Some(Fe::PasswordMessage(msg)) => msg, - None => bail!("connection is lost"), - bad => bail!("unexpected message type: {:?}", bad), - }; - println!("got message: {:?}", msg); - - let (_trailing_null, md5_response) = msg - .split_last() - .ok_or_else(|| anyhow!("unexpected password message"))?; - - let cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters); - let db_info = cplane.authenticate_proxy_request( - user, - db, - md5_response, - &md5_salt, - &self.psql_session_id, - )?; - - self.pgb - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())?; - - Ok(db_info) - } - - fn handle_new_user(&mut self) -> anyhow::Result { - let greeting = hello_message(&self.state.conf.redirect_uri, &self.psql_session_id); - - // First, register this session - let waiter = self.state.waiters.register(self.psql_session_id.clone()); - - // Give user a URL to spawn a new database - self.pgb - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())? - .write_message(&Be::NoticeResponse(greeting))?; - - // Wait for web console response - let db_info = waiter.wait()?.map_err(|e| anyhow!(e))?; - - self.pgb - .write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?; - - Ok(db_info) - } -} - -fn hello_message(redirect_uri: &str, session_id: &str) -> String { - format!( - concat![ - "☀️ Welcome to Zenith!\n", - "To proceed with database creation, open the following link:\n\n", - " {redirect_uri}{session_id}\n\n", - "It needs to be done once and we will send you '.pgpass' file,\n", - "which will allow you to access or create ", - "databases without opening your web browser." - ], - redirect_uri = redirect_uri, - session_id = session_id, ) + .unwrap(); } -/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message -async fn connect_to_db( - db_info: DatabaseInfo, -) -> anyhow::Result<(String, tokio::net::TcpStream, CancelKeyData)> { - // Make raw connection. When connect_raw finishes we've received ReadyForQuery. - let socket_addr = db_info.socket_addr()?; - let mut socket = tokio::net::TcpStream::connect(socket_addr).await?; - let config = tokio_postgres::Config::from(db_info); - // NOTE We effectively ignore some ParameterStatus and NoticeResponse - // messages here. Not sure if that could break something. - let (client, conn) = config.connect_raw(&mut socket, NoTls).await?; - - // Save info for potentially cancelling the query later - let mut rng = StdRng::from_entropy(); - let cancel_key_data = CancelKeyData { - // HACK We'd rather get the real backend_pid but tokio_postgres doesn't - // expose it and we don't want to do another roundtrip to query - // for it. The client will be able to notice that this is not the - // actual backend_pid, but backend_pid is not used for anything - // so it doesn't matter. - backend_pid: rng.gen(), - cancel_key: rng.gen(), - }; - let cancel_closure = CancelClosure { - socket_addr, - cancel_token: client.cancel_token(), - }; - CANCEL_MAP - .lock() - .unwrap() - .insert(cancel_key_data, cancel_closure); - THREAD_CANCEL_KEY_DATA.with(|cell| { - let prev_value = cell.replace(Some(cancel_key_data)); - assert!( - prev_value.is_none(), - "THREAD_CANCEL_KEY_DATA was already set" - ); - }); - - let version = conn.parameter("server_version").unwrap(); - Ok((version.into(), socket, cancel_key_data)) +async fn log_error(future: F) -> F::Output +where + F: std::future::Future>, +{ + future.await.map_err(|err| { + println!("error: {}", err); + err + }) } -/// Concurrently proxy both directions of the client and server connections -fn proxy( - (client_read, client_write): (ReadStream, WriteStream), - (server_read, server_write): (ReadStream, WriteStream), +pub async fn thread_main( + config: &'static ProxyConfig, + listener: tokio::net::TcpListener, ) -> anyhow::Result<()> { - fn do_proxy(mut reader: impl io::Read, mut writer: WriteStream) -> io::Result { - /// FlushWriter will make sure that every message is sent as soon as possible - struct FlushWriter(W); - - impl io::Write for FlushWriter { - fn write(&mut self, buf: &[u8]) -> io::Result { - // `std::io::copy` is guaranteed to exit if we return an error, - // so we can afford to lose `res` in case `flush` fails - let res = self.0.write(buf); - if let Ok(count) = res { - NUM_BYTES_PROXIED_COUNTER.inc_by(count as u64); - self.flush()?; - } - res - } - - fn flush(&mut self) -> io::Result<()> { - self.0.flush() - } - } - - let res = std::io::copy(&mut reader, &mut FlushWriter(&mut writer)); - writer.shutdown(std::net::Shutdown::Both)?; - res + scopeguard::defer! { + println!("proxy has shut down"); } - let client_to_server_jh = thread::spawn(move || do_proxy(client_read, server_write)); + loop { + let (socket, peer_addr) = listener.accept().await?; + println!("accepted connection from {}", peer_addr); - do_proxy(server_read, client_write)?; - client_to_server_jh.join().unwrap()?; + tokio::spawn(log_error(async { + socket + .set_nodelay(true) + .context("failed to set socket option")?; + + handle_client(config, socket).await + })); + } +} + +async fn handle_client( + config: &ProxyConfig, + stream: impl AsyncRead + AsyncWrite + Unpin, +) -> anyhow::Result<()> { + // The `closed` counter will increase when this future is destroyed. + NUM_CONNECTIONS_ACCEPTED_COUNTER.inc(); + scopeguard::defer! { + NUM_CONNECTIONS_CLOSED_COUNTER.inc(); + } + + let tls = config.tls_config.clone(); + if let Some((stream, creds)) = handshake(stream, tls).await? { + cancellation::with_session(|session| async { + connect_client_to_db(config, stream, creds, session).await + }) + .await?; + } Ok(()) } + +/// Handle a connection from one client. +/// For better testing experience, `stream` can be +/// any object satisfying the traits. +async fn handshake( + stream: S, + mut tls: Option, +) -> anyhow::Result>, auth::ClientCredentials)>> { + // Client may try upgrading to each protocol only once + let (mut tried_ssl, mut tried_gss) = (false, false); + + let mut stream = PqStream::new(Stream::from_raw(stream)); + loop { + let msg = stream.read_startup_packet().await?; + println!("got message: {:?}", msg); + + use FeStartupPacket::*; + match msg { + SslRequest => match stream.get_ref() { + Stream::Raw { .. } if !tried_ssl => { + tried_ssl = true; + + // We can't perform TLS handshake without a config + let enc = tls.is_some(); + stream.write_message(&Be::EncryptionResponse(enc)).await?; + + if let Some(tls) = tls.take() { + // Upgrade raw stream into a secure TLS-backed stream. + // NOTE: We've consumed `tls`; this fact will be used later. + stream = PqStream::new(stream.into_inner().upgrade(tls).await?); + } + } + _ => bail!("protocol violation"), + }, + GssEncRequest => match stream.get_ref() { + Stream::Raw { .. } if !tried_gss => { + tried_gss = true; + + // Currently, we don't support GSSAPI + stream.write_message(&Be::EncryptionResponse(false)).await?; + } + _ => bail!("protocol violation"), + }, + StartupMessage { params, .. } => { + // Check that the config has been consumed during upgrade + // OR we didn't provide it at all (for dev purposes). + if tls.is_some() { + let msg = "connection is insecure (try using `sslmode=require`)"; + stream.write_message(&Be::ErrorResponse(msg)).await?; + bail!(msg); + } + + break Ok(Some((stream, params.try_into()?))); + } + CancelRequest(cancel_key_data) => { + cancellation::cancel_session(cancel_key_data).await?; + + break Ok(None); + } + } + } +} + +async fn connect_client_to_db( + config: &ProxyConfig, + mut client: PqStream, + creds: auth::ClientCredentials, + session: cancellation::Session, +) -> anyhow::Result<()> { + let db_info = creds.authenticate(config, &mut client).await?; + let (db, version, cancel_closure) = connect_to_db(db_info).await?; + let cancel_key_data = session.enable_cancellation(cancel_closure); + + client + .write_message_noflush(&BeMessage::ParameterStatus( + BeParameterStatusMessage::ServerVersion(&version), + ))? + .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? + .write_message(&BeMessage::ReadyForQuery) + .await?; + + // This function will be called for writes to either direction. + fn inc_proxied(cnt: usize) { + // Consider inventing something more sophisticated + // if this ever becomes a bottleneck (cacheline bouncing). + NUM_BYTES_PROXIED_COUNTER.inc_by(cnt as u64); + } + + let mut db = MetricsStream::new(db, inc_proxied); + let mut client = MetricsStream::new(client.into_inner(), inc_proxied); + let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?; + + Ok(()) +} + +/// Connect to a corresponding compute node. +async fn connect_to_db( + db_info: DatabaseInfo, +) -> anyhow::Result<(TcpStream, String, CancelClosure)> { + // TODO: establish a secure connection to the DB + let socket_addr = db_info.socket_addr()?; + let mut socket = TcpStream::connect(socket_addr).await?; + + let (client, conn) = tokio_postgres::Config::from(db_info) + .connect_raw(&mut socket, NoTls) + .await?; + + let version = conn + .parameter("server_version") + .context("failed to fetch postgres server version")? + .into(); + + let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token()); + + Ok((socket, version, cancel_closure)) +} + +#[cfg(test)] +mod tests { + use super::*; + + use tokio::io::DuplexStream; + use tokio_postgres::config::SslMode; + use tokio_postgres::tls::MakeTlsConnect; + use tokio_postgres_rustls::MakeRustlsConnect; + + async fn dummy_proxy( + client: impl AsyncRead + AsyncWrite + Unpin, + tls: Option, + ) -> anyhow::Result<()> { + // TODO: add some infra + tests for credentials + let (mut stream, _creds) = handshake(client, tls).await?.context("no stream")?; + + stream + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())? + .write_message(&BeMessage::ReadyForQuery) + .await?; + + Ok(()) + } + + fn generate_certs( + hostname: &str, + ) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> { + let ca = rcgen::Certificate::from_params({ + let mut params = rcgen::CertificateParams::default(); + params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + params + })?; + + let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?; + Ok(( + rustls::Certificate(ca.serialize_der()?), + rustls::Certificate(cert.serialize_der_with_signer(&ca)?), + rustls::PrivateKey(cert.serialize_private_key_der()), + )) + } + + #[tokio::test] + async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { + let (client, server) = tokio::io::duplex(1024); + + let server_config = { + let (_ca, cert, key) = generate_certs("localhost")?; + + let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(vec![cert], key)?; + config + }; + + let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into()))); + + tokio_postgres::Config::new() + .user("john_doe") + .dbname("earth") + .ssl_mode(SslMode::Disable) + .connect_raw(server, NoTls) + .await + .err() // -> Option + .context("client shouldn't be able to connect")?; + + proxy + .await? + .err() // -> Option + .context("server shouldn't accept client")?; + + Ok(()) + } + + #[tokio::test] + async fn handshake_tls() -> anyhow::Result<()> { + let (client, server) = tokio::io::duplex(1024); + + let (ca, cert, key) = generate_certs("localhost")?; + + let server_config = { + let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(vec![cert], key)?; + config + }; + + let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into()))); + + let client_config = { + let mut config = rustls::ClientConfig::new(); + config.root_store.add(&ca)?; + config + }; + + let mut mk = MakeRustlsConnect::new(client_config); + let tls = MakeTlsConnect::::make_tls_connect(&mut mk, "localhost")?; + + let (_client, _conn) = tokio_postgres::Config::new() + .user("john_doe") + .dbname("earth") + .ssl_mode(SslMode::Require) + .connect_raw(server, tls) + .await?; + + proxy.await? + } + + #[tokio::test] + async fn handshake_raw() -> anyhow::Result<()> { + let (client, server) = tokio::io::duplex(1024); + + let proxy = tokio::spawn(dummy_proxy(client, None)); + + let (_client, _conn) = tokio_postgres::Config::new() + .user("john_doe") + .dbname("earth") + .ssl_mode(SslMode::Prefer) + .connect_raw(server, NoTls) + .await?; + + proxy.await? + } +} diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs new file mode 100644 index 0000000000..8fd5bef388 --- /dev/null +++ b/proxy/src/stream.rs @@ -0,0 +1,230 @@ +use anyhow::Context; +use bytes::BytesMut; +use pin_project_lite::pin_project; +use rustls::ServerConfig; +use std::pin::Pin; +use std::sync::Arc; +use std::{io, task}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio_rustls::server::TlsStream; +use zenith_utils::pq_proto::{BeMessage, FeMessage, FeStartupPacket}; + +pin_project! { + /// Stream wrapper which implements libpq's protocol. + /// NOTE: This object deliberately doesn't implement [`AsyncRead`] + /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying + /// to pass random malformed bytes through the connection). + pub struct PqStream { + #[pin] + stream: S, + buffer: BytesMut, + } +} + +impl PqStream { + /// Construct a new libpq protocol wrapper. + pub fn new(stream: S) -> Self { + Self { + stream, + buffer: Default::default(), + } + } + + /// Extract the underlying stream. + pub fn into_inner(self) -> S { + self.stream + } + + /// Get a reference to the underlying stream. + pub fn get_ref(&self) -> &S { + &self.stream + } +} + +impl PqStream { + /// Receive [`FeStartupPacket`], which is a first packet sent by a client. + pub async fn read_startup_packet(&mut self) -> anyhow::Result { + match FeStartupPacket::read_fut(&mut self.stream).await? { + Some(FeMessage::StartupPacket(packet)) => Ok(packet), + None => anyhow::bail!("connection is lost"), + other => anyhow::bail!("bad message type: {:?}", other), + } + } + + pub async fn read_message(&mut self) -> anyhow::Result { + FeMessage::read_fut(&mut self.stream) + .await? + .context("connection is lost") + } +} + +impl PqStream { + /// Write the message into an internal buffer, but don't flush the underlying stream. + pub fn write_message_noflush<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> { + BeMessage::write(&mut self.buffer, message)?; + Ok(self) + } + + /// Write the message into an internal buffer and flush it. + pub async fn write_message<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> { + self.write_message_noflush(message)?; + self.flush().await?; + Ok(self) + } + + /// Flush the output buffer into the underlying stream. + pub async fn flush(&mut self) -> io::Result<&mut Self> { + self.stream.write_all(&self.buffer).await?; + self.buffer.clear(); + self.stream.flush().await?; + Ok(self) + } +} + +pin_project! { + /// Wrapper for upgrading raw streams into secure streams. + /// NOTE: it should be possible to decompose this object as necessary. + #[project = StreamProj] + pub enum Stream { + /// We always begin with a raw stream, + /// which may then be upgraded into a secure stream. + Raw { #[pin] raw: S }, + /// We box [`TlsStream`] since it can be quite large. + Tls { #[pin] tls: Box> }, + } +} + +impl Stream { + /// Construct a new instance from a raw stream. + pub fn from_raw(raw: S) -> Self { + Self::Raw { raw } + } +} + +impl Stream { + /// If possible, upgrade raw stream into a secure TLS-based stream. + pub async fn upgrade(self, cfg: Arc) -> anyhow::Result { + match self { + Stream::Raw { raw } => { + let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?); + Ok(Stream::Tls { tls }) + } + Stream::Tls { .. } => anyhow::bail!("can't upgrade TLS stream"), + } + } +} + +impl AsyncRead for Stream { + fn poll_read( + self: Pin<&mut Self>, + context: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> task::Poll> { + use StreamProj::*; + match self.project() { + Raw { raw } => raw.poll_read(context, buf), + Tls { tls } => tls.poll_read(context, buf), + } + } +} + +impl AsyncWrite for Stream { + fn poll_write( + self: Pin<&mut Self>, + context: &mut task::Context<'_>, + buf: &[u8], + ) -> task::Poll> { + use StreamProj::*; + match self.project() { + Raw { raw } => raw.poll_write(context, buf), + Tls { tls } => tls.poll_write(context, buf), + } + } + + fn poll_flush( + self: Pin<&mut Self>, + context: &mut task::Context<'_>, + ) -> task::Poll> { + use StreamProj::*; + match self.project() { + Raw { raw } => raw.poll_flush(context), + Tls { tls } => tls.poll_flush(context), + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + context: &mut task::Context<'_>, + ) -> task::Poll> { + use StreamProj::*; + match self.project() { + Raw { raw } => raw.poll_shutdown(context), + Tls { tls } => tls.poll_shutdown(context), + } + } +} + +pin_project! { + /// This stream tracks all writes and calls user provided + /// callback when the underlying stream is flushed. + pub struct MetricsStream { + #[pin] + stream: S, + write_count: usize, + inc_write_count: W, + } +} + +impl MetricsStream { + pub fn new(stream: S, inc_write_count: W) -> Self { + Self { + stream, + write_count: 0, + inc_write_count, + } + } +} + +impl AsyncRead for MetricsStream { + fn poll_read( + self: Pin<&mut Self>, + context: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> task::Poll> { + self.project().stream.poll_read(context, buf) + } +} + +impl AsyncWrite for MetricsStream { + fn poll_write( + self: Pin<&mut Self>, + context: &mut task::Context<'_>, + buf: &[u8], + ) -> task::Poll> { + let this = self.project(); + this.stream.poll_write(context, buf).map_ok(|cnt| { + // Increment the write count. + *this.write_count += cnt; + cnt + }) + } + + fn poll_flush( + self: Pin<&mut Self>, + context: &mut task::Context<'_>, + ) -> task::Poll> { + let this = self.project(); + this.stream.poll_flush(context).map_ok(|()| { + // Call the user provided callback and reset the write count. + (this.inc_write_count)(*this.write_count); + *this.write_count = 0; + }) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + context: &mut task::Context<'_>, + ) -> task::Poll> { + self.project().stream.poll_shutdown(context) + } +} diff --git a/proxy/src/waiters.rs b/proxy/src/waiters.rs index 7baa0b102f..9fda3ed94f 100644 --- a/proxy/src/waiters.rs +++ b/proxy/src/waiters.rs @@ -1,8 +1,12 @@ -use anyhow::Context; -use std::collections::HashMap; -use std::sync::{mpsc, Mutex}; +use anyhow::{anyhow, Context}; +use hashbrown::HashMap; +use parking_lot::Mutex; +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task; +use tokio::sync::oneshot; -pub struct Waiters(pub(self) Mutex>>); +pub struct Waiters(pub(self) Mutex>>); impl Default for Waiters { fn default() -> Self { @@ -11,48 +15,86 @@ impl Default for Waiters { } impl Waiters { - pub fn register(&self, key: String) -> Waiter { - let (tx, rx) = mpsc::channel(); + pub fn register(&self, key: String) -> anyhow::Result> { + let (tx, rx) = oneshot::channel(); - // TODO: use `try_insert` (unstable) - let prev = self.0.lock().unwrap().insert(key.clone(), tx); - assert!(matches!(prev, None)); // assert_matches! is nightly-only + self.0 + .lock() + .try_insert(key.clone(), tx) + .map_err(|_| anyhow!("waiter already registered"))?; - Waiter { + Ok(Waiter { receiver: rx, - registry: self, - key, - } + guard: DropKey { + registry: self, + key, + }, + }) } pub fn notify(&self, key: &str, value: T) -> anyhow::Result<()> where - T: Send + Sync + 'static, + T: Send + Sync, { let tx = self .0 .lock() - .unwrap() .remove(key) .with_context(|| format!("key {} not found", key))?; - tx.send(value).context("channel hangup") + + tx.send(value).map_err(|_| anyhow!("waiter channel hangup")) } } -pub struct Waiter<'a, T> { - receiver: mpsc::Receiver, - registry: &'a Waiters, +struct DropKey<'a, T> { key: String, + registry: &'a Waiters, } -impl Waiter<'_, T> { - pub fn wait(self) -> anyhow::Result { - self.receiver.recv().context("channel hangup") - } -} - -impl Drop for Waiter<'_, T> { +impl<'a, T> Drop for DropKey<'a, T> { fn drop(&mut self) { - self.registry.0.lock().unwrap().remove(&self.key); + self.registry.0.lock().remove(&self.key); + } +} + +pin_project! { + pub struct Waiter<'a, T> { + #[pin] + receiver: oneshot::Receiver, + guard: DropKey<'a, T>, + } +} + +impl std::future::Future for Waiter<'_, T> { + type Output = anyhow::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { + self.project() + .receiver + .poll(cx) + .map_err(|_| anyhow!("channel hangup")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[tokio::test] + async fn test_waiter() -> anyhow::Result<()> { + let waiters = Arc::new(Waiters::default()); + + let key = "Key"; + let waiter = waiters.register(key.to_owned())?; + + let waiters = Arc::clone(&waiters); + let notifier = tokio::spawn(async move { + waiters.notify(key, Default::default())?; + Ok(()) + }); + + let () = waiter.await?; + notifier.await? } } diff --git a/zenith_utils/src/http/mod.rs b/zenith_utils/src/http/mod.rs index ef842fd2ff..0bb53ef51d 100644 --- a/zenith_utils/src/http/mod.rs +++ b/zenith_utils/src/http/mod.rs @@ -5,4 +5,4 @@ pub mod request; /// Current fast way to apply simple http routing in various Zenith binaries. /// Re-exported for sake of uniform approach, that could be later replaced with better alternatives, if needed. -pub use routerify::{ext::RequestExt, RouterBuilder}; +pub use routerify::{ext::RequestExt, RouterBuilder, RouterService}; diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index 89be25cb54..355b38fc95 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -57,6 +57,16 @@ pub struct CancelKeyData { pub cancel_key: i32, } +use rand::distributions::{Distribution, Standard}; +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> CancelKeyData { + CancelKeyData { + backend_pid: rng.gen(), + cancel_key: rng.gen(), + } + } +} + #[derive(Debug)] pub struct FeQueryMessage { pub body: Bytes,