diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 8aa1923f61..f0412189cc 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -4,12 +4,13 @@ /// This allows connecting to pods/services running in the same Kubernetes cluster from /// the outside. Similar to an ingress controller for HTTPS. use std::{net::SocketAddr, sync::Arc}; -use tokio::{net::TcpListener, io::AsyncWriteExt}; +use tokio::net::TcpListener; -use anyhow::{bail, ensure, Context}; +use anyhow::{anyhow, bail, ensure, Context}; use clap::{self, Arg}; use futures::TryFutureExt; -use proxy::{cancellation::CancelMap, auth::{AuthFlow, self}, compute::ConnCfg, console::messages::MetricsAuxInfo}; +use proxy::console::messages::MetricsAuxInfo; +use proxy::stream::{PqStream, Stream}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use utils::{project_git_version, sentry_init::init_sentry}; @@ -145,7 +146,6 @@ async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let mut connections = tokio::task::JoinSet::new(); - let cancel_map = Arc::new(CancelMap::default()); loop { tokio::select! { @@ -154,7 +154,6 @@ async fn task_main( info!("accepted postgres client connection from {peer_addr}"); let session_id = uuid::Uuid::new_v4(); - let cancel_map = Arc::clone(&cancel_map); let tls_config = Arc::clone(&tls_config); let dest_suffix = Arc::clone(&dest_suffix); @@ -166,7 +165,7 @@ async fn task_main( .set_nodelay(true) .context("failed to set socket option")?; - handle_client(dest_suffix, dest_port, tls_config, &cancel_map, session_id, socket).await + handle_client(dest_suffix, dest_port, tls_config, session_id, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. @@ -191,54 +190,66 @@ async fn task_main( Ok(()) } +const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; + +async fn ssl_handshake( + raw_stream: S, + tls_config: Arc, +) -> anyhow::Result> { + let mut stream = PqStream::new(Stream::from_raw(raw_stream)); + + let msg = stream.read_startup_packet().await?; + info!("received {msg:?}"); + use pq_proto::FeStartupPacket::*; + + match msg { + SslRequest => { + stream + .write_message(&pq_proto::BeMessage::EncryptionResponse(true)) + .await?; + // Upgrade raw stream into a secure TLS-backed stream. + // NOTE: We've consumed `tls`; this fact will be used later. + + let (raw, read_buf) = stream.into_inner(); + // TODO: Normally, client doesn't send any data before + // server says TLS handshake is ok and read_buf is empy. + // However, you could imagine pipelining of postgres + // SSLRequest + TLS ClientHello in one hunk similar to + // pipelining in our node js driver. We should probably + // support that by chaining read_buf with the stream. + if !read_buf.is_empty() { + bail!("data is sent before server replied with EncryptionResponse"); + } + Ok(raw.upgrade(tls_config).await?) + } + _ => stream.throw_error_str(ERR_INSECURE_CONNECTION).await?, + } +} + #[tracing::instrument(fields(session_id = ?session_id), skip_all)] async fn handle_client( dest_suffix: Arc, dest_port: u16, - tls: Arc, - cancel_map: &CancelMap, + tls_config: Arc, session_id: uuid::Uuid, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let do_handshake = proxy::proxy::handshake(stream, Some(tls), cancel_map); - let (mut stream, params) = match do_handshake.await? { - Some(x) => x, - None => return Ok(()), // it's a cancellation request - }; - - let password = AuthFlow::new(&mut stream) - .begin(auth::CleartextPassword) - .await? - .authenticate() - .await?; - - let mut conn_cfg = ConnCfg::new(); - conn_cfg.set_startup_params(¶ms); - conn_cfg.password(password); + let tls_stream = ssl_handshake(stream, tls_config).await?; // cut off first part of the sni domain - let sni = stream.get_ref().sni_hostname().unwrap(); + let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?; let dest = sni - .split_once('.').context("invalid sni")?.0 + .split_once('.') + .context("invalid sni")? + .0 .replace("--", "."); - let destination = format!("{}.{}", dest, dest_suffix); + let destination = format!("{}.{}:{}", dest, dest_suffix, dest_port); info!("destination: {}:{}", destination, dest_port); - conn_cfg.host(destination.as_str()); - conn_cfg.port(dest_port); + let client = tokio::net::TcpStream::connect(destination).await?; - let mut conn = conn_cfg.connect() - .or_else(|e| stream.throw_error(e)) - .await?; - - cancel_map.with_session(|session| async { - proxy::proxy::prepare_client_connection(&conn, false, session, &mut stream).await?; - let (stream, read_buf) = stream.into_inner(); - conn.stream.write_all(&read_buf).await?; - let metrics_aux: MetricsAuxInfo = Default::default(); - proxy::proxy::proxy_pass(stream, conn.stream, &metrics_aux).await - }) - .await + let metrics_aux: MetricsAuxInfo = Default::default(); + proxy::proxy::proxy_pass(tls_stream, client, &metrics_aux).await }