mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
Start passthrough earlier
As soon as we have received the SSLRequest packet, and have figured out the hostname to connect to from the SNI, we can start passing through data. We don't need to parse the StartupPacket that the client will send next.
This commit is contained in:
committed by
Stas Kelvich
parent
3813c703c9
commit
53e5d18da5
@@ -4,17 +4,17 @@
|
||||
/// This allows connecting to pods/services running in the same Kubernetes cluster from
|
||||
/// the outside. Similar to an ingress controller for HTTPS.
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::{io::AsyncWriteExt, net::TcpListener};
|
||||
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use tokio::net::TcpListener;
|
||||
// use tokio::net::TcpListener;
|
||||
|
||||
use anyhow::{anyhow, bail, ensure, Context};
|
||||
use clap::{self, Arg};
|
||||
use futures::TryFutureExt;
|
||||
use proxy::{
|
||||
auth::{self, AuthFlow},
|
||||
cancellation::CancelMap,
|
||||
compute::ConnCfg,
|
||||
console::messages::MetricsAuxInfo,
|
||||
};
|
||||
use proxy::console::messages::MetricsAuxInfo;
|
||||
// use proxy::console::messages::MetricsAuxInfo;
|
||||
use proxy::stream::{PqStream, Stream};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use utils::{project_git_version, sentry_init::init_sentry};
|
||||
@@ -140,7 +140,6 @@ async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let mut connections = tokio::task::JoinSet::new();
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -149,7 +148,6 @@ async fn task_main(
|
||||
info!("accepted postgres client connection from {peer_addr}");
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancel_map = Arc::clone(&cancel_map);
|
||||
let tls_config = Arc::clone(&tls_config);
|
||||
let dest_suffix = Arc::clone(&dest_suffix);
|
||||
|
||||
@@ -161,7 +159,7 @@ async fn task_main(
|
||||
.set_nodelay(true)
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
handle_client(dest_suffix, tls_config, &cancel_map, session_id, socket).await
|
||||
handle_client(dest_suffix, tls_config, session_id, socket).await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
@@ -186,59 +184,68 @@ async fn task_main(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<Stream<S>> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
|
||||
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
info!("received {msg:?}");
|
||||
use pq_proto::FeStartupPacket::*;
|
||||
|
||||
match msg {
|
||||
SslRequest => {
|
||||
stream
|
||||
.write_message(&pq_proto::BeMessage::EncryptionResponse(true))
|
||||
.await?;
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let (raw, read_buf) = stream.into_inner();
|
||||
// TODO: Normally, client doesn't send any data before
|
||||
// server says TLS handshake is ok and read_buf is empy.
|
||||
// However, you could imagine pipelining of postgres
|
||||
// SSLRequest + TLS ClientHello in one hunk similar to
|
||||
// pipelining in our node js driver. We should probably
|
||||
// support that by chaining read_buf with the stream.
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
Ok(raw.upgrade(tls_config).await?)
|
||||
}
|
||||
_ => stream.throw_error_str(ERR_INSECURE_CONNECTION).await?,
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(fields(session_id = ?session_id), skip_all)]
|
||||
async fn handle_client(
|
||||
dest_suffix: Arc<String>,
|
||||
tls: Arc<rustls::ServerConfig>,
|
||||
cancel_map: &CancelMap,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
session_id: uuid::Uuid,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
let do_handshake = proxy::proxy::handshake(stream, Some(tls), cancel_map);
|
||||
let (mut stream, params) = match do_handshake.await? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
};
|
||||
|
||||
let password = AuthFlow::new(&mut stream)
|
||||
.begin(auth::CleartextPassword)
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
let mut conn_cfg = ConnCfg::new();
|
||||
conn_cfg.set_startup_params(¶ms);
|
||||
conn_cfg.password(password);
|
||||
let tls_stream = ssl_handshake(stream, tls_config).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
|
||||
let sni = stream.get_ref().sni_hostname().unwrap();
|
||||
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
|
||||
let dest: Vec<&str> = sni
|
||||
.split_once('.')
|
||||
.context("invalid SNI")?
|
||||
.0
|
||||
.splitn(3, "--")
|
||||
.collect();
|
||||
let destination = format!("{}.{}.{}", dest[0], dest[1], dest_suffix);
|
||||
let port = dest[2].parse::<u16>().context("invalid port")?;
|
||||
let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
|
||||
|
||||
info!("destination: {:?}:{}", destination, port);
|
||||
conn_cfg.host(destination.as_str());
|
||||
conn_cfg.port(port);
|
||||
info!("destination: {}", destination);
|
||||
|
||||
let mut conn = conn_cfg
|
||||
.connect()
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
let client = tokio::net::TcpStream::connect(destination).await?;
|
||||
|
||||
cancel_map
|
||||
.with_session(|session| async {
|
||||
proxy::proxy::prepare_client_connection(&conn, false, session, &mut stream).await?;
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
conn.stream.write_all(&read_buf).await?;
|
||||
let metrics_aux: MetricsAuxInfo = Default::default();
|
||||
proxy::proxy::proxy_pass(stream, conn.stream, &metrics_aux).await
|
||||
})
|
||||
.await
|
||||
let metrics_aux: MetricsAuxInfo = Default::default();
|
||||
proxy::proxy::proxy_pass(tls_stream, client, &metrics_aux).await
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user