diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 7e1f6ea940..17660e5b31 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -4,17 +4,17 @@ /// This allows connecting to pods/services running in the same Kubernetes cluster from /// the outside. Similar to an ingress controller for HTTPS. use std::{net::SocketAddr, sync::Arc}; -use tokio::{io::AsyncWriteExt, net::TcpListener}; -use anyhow::{bail, ensure, Context}; +use tokio::net::TcpListener; +// use tokio::net::TcpListener; + +use anyhow::{anyhow, bail, ensure, Context}; use clap::{self, Arg}; use futures::TryFutureExt; -use proxy::{ - auth::{self, AuthFlow}, - cancellation::CancelMap, - compute::ConnCfg, - console::messages::MetricsAuxInfo, -}; +use proxy::console::messages::MetricsAuxInfo; +// use proxy::console::messages::MetricsAuxInfo; +use proxy::stream::{PqStream, Stream}; + use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use utils::{project_git_version, sentry_init::init_sentry}; @@ -140,7 +140,6 @@ async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let mut connections = tokio::task::JoinSet::new(); - let cancel_map = Arc::new(CancelMap::default()); loop { tokio::select! { @@ -149,7 +148,6 @@ async fn task_main( info!("accepted postgres client connection from {peer_addr}"); let session_id = uuid::Uuid::new_v4(); - let cancel_map = Arc::clone(&cancel_map); let tls_config = Arc::clone(&tls_config); let dest_suffix = Arc::clone(&dest_suffix); @@ -161,7 +159,7 @@ async fn task_main( .set_nodelay(true) .context("failed to set socket option")?; - handle_client(dest_suffix, tls_config, &cancel_map, session_id, socket).await + handle_client(dest_suffix, tls_config, session_id, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. @@ -186,59 +184,68 @@ async fn task_main( Ok(()) } +const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; + +async fn ssl_handshake( + raw_stream: S, + tls_config: Arc, +) -> anyhow::Result> { + let mut stream = PqStream::new(Stream::from_raw(raw_stream)); + + let msg = stream.read_startup_packet().await?; + info!("received {msg:?}"); + use pq_proto::FeStartupPacket::*; + + match msg { + SslRequest => { + stream + .write_message(&pq_proto::BeMessage::EncryptionResponse(true)) + .await?; + // Upgrade raw stream into a secure TLS-backed stream. + // NOTE: We've consumed `tls`; this fact will be used later. + + let (raw, read_buf) = stream.into_inner(); + // TODO: Normally, client doesn't send any data before + // server says TLS handshake is ok and read_buf is empy. + // However, you could imagine pipelining of postgres + // SSLRequest + TLS ClientHello in one hunk similar to + // pipelining in our node js driver. We should probably + // support that by chaining read_buf with the stream. + if !read_buf.is_empty() { + bail!("data is sent before server replied with EncryptionResponse"); + } + Ok(raw.upgrade(tls_config).await?) + } + _ => stream.throw_error_str(ERR_INSECURE_CONNECTION).await?, + } +} + #[tracing::instrument(fields(session_id = ?session_id), skip_all)] async fn handle_client( dest_suffix: Arc, - tls: Arc, - cancel_map: &CancelMap, + tls_config: Arc, session_id: uuid::Uuid, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let do_handshake = proxy::proxy::handshake(stream, Some(tls), cancel_map); - let (mut stream, params) = match do_handshake.await? { - Some(x) => x, - None => return Ok(()), // it's a cancellation request - }; - - let password = AuthFlow::new(&mut stream) - .begin(auth::CleartextPassword) - .await? - .authenticate() - .await?; - - let mut conn_cfg = ConnCfg::new(); - conn_cfg.set_startup_params(¶ms); - conn_cfg.password(password); + let tls_stream = ssl_handshake(stream, tls_config).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain` - let sni = stream.get_ref().sni_hostname().unwrap(); + let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?; let dest: Vec<&str> = sni .split_once('.') .context("invalid SNI")? .0 .splitn(3, "--") .collect(); - let destination = format!("{}.{}.{}", dest[0], dest[1], dest_suffix); let port = dest[2].parse::().context("invalid port")?; + let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port); - info!("destination: {:?}:{}", destination, port); - conn_cfg.host(destination.as_str()); - conn_cfg.port(port); + info!("destination: {}", destination); - let mut conn = conn_cfg - .connect() - .or_else(|e| stream.throw_error(e)) - .await?; + let client = tokio::net::TcpStream::connect(destination).await?; - cancel_map - .with_session(|session| async { - proxy::proxy::prepare_client_connection(&conn, false, session, &mut stream).await?; - let (stream, read_buf) = stream.into_inner(); - conn.stream.write_all(&read_buf).await?; - let metrics_aux: MetricsAuxInfo = Default::default(); - proxy::proxy::proxy_pass(stream, conn.stream, &metrics_aux).await - }) - .await + let metrics_aux: MetricsAuxInfo = Default::default(); + proxy::proxy::proxy_pass(tls_stream, client, &metrics_aux).await }