From debd134b15083ebd0587b760232724e0a644af31 Mon Sep 17 00:00:00 2001 From: Arthur Petukhovsky Date: Fri, 6 Jan 2023 19:34:18 +0400 Subject: [PATCH] Implement wss support in proxy (#3247) This is a hacky implementation of WebSocket server, embedded into our postgres proxy. The server is used to allow https://github.com/neondatabase/serverless to connect to our postgres from browser and serverless javascript functions. How it will work (general schema): - browser opens a websocket connection to `wss://ep-abc-xyz-123.xx-central-1.aws.neon.tech/` - proxy accepts this connection and terminates TLS (https) - inside encrypted tunnel (HTTPS), browser initiates plain (non-encrypted) postgres connection - proxy performs auth as in usual plain pg connection and forwards connection to the compute Related issue: #3225 --- Cargo.lock | 79 ++++++++++ proxy/Cargo.toml | 4 + proxy/src/auth/backend.rs | 32 ++++- proxy/src/auth/credentials.rs | 23 ++- proxy/src/auth/flow.rs | 23 +++ proxy/src/http.rs | 1 + proxy/src/http/websocket.rs | 263 ++++++++++++++++++++++++++++++++++ proxy/src/main.rs | 22 ++- proxy/src/proxy.rs | 43 +++++- 9 files changed, 476 insertions(+), 14 deletions(-) create mode 100644 proxy/src/http/websocket.rs diff --git a/Cargo.lock b/Cargo.lock index fbf018e1c0..284a111ba7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1700,6 +1700,19 @@ dependencies = [ "tokio-io-timeout", ] +[[package]] +name = "hyper-tungstenite" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d62004bcd4f6f85d9e2aa4206f1466ee67031f5ededcb6c6e62d48f9306ad879" +dependencies = [ + "hyper", + "pin-project", + "tokio", + "tokio-tungstenite", + "tungstenite", +] + [[package]] name = "iana-time-zone" version = "0.1.53" @@ -2658,6 +2671,7 @@ dependencies = [ "hex", "hmac", "hyper", + "hyper-tungstenite", "itertools", "md5", "metrics", @@ -2667,6 +2681,7 @@ dependencies = [ "pq_proto", "rand", "rcgen", + "regex", "reqwest", "routerify", "rstest", @@ -2678,6 +2693,7 @@ dependencies = [ "sha2", "socket2", "thiserror", + "tls-listener", "tokio", "tokio-postgres", "tokio-postgres-rustls", @@ -2687,6 +2703,7 @@ dependencies = [ "url", "utils", "uuid", + "webpki-roots", "workspace_hack", "x509-parser", ] @@ -3324,6 +3341,17 @@ dependencies = [ "syn", ] +[[package]] +name = "sha-1" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha1" version = "0.10.5" @@ -3687,6 +3715,20 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" +[[package]] +name = "tls-listener" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9d4ff21187d434ac7709bfc7441ca88f63681247e5ad99f0f08c8c91ddc103d" +dependencies = [ + "futures-util", + "hyper", + "pin-project-lite", + "thiserror", + "tokio", + "tokio-rustls", +] + [[package]] name = "tokio" version = "1.21.1" @@ -3801,6 +3843,18 @@ dependencies = [ "xattr", ] +[[package]] +name = "tokio-tungstenite" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.4" @@ -4027,6 +4081,25 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +[[package]] +name = "tungstenite" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0" +dependencies = [ + "base64 0.13.1", + "byteorder", + "bytes", + "http", + "httparse", + "log", + "rand", + "sha-1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.16.0" @@ -4115,6 +4188,12 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utils" version = "0.1.0" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 0bf47c7b88..cbc067093e 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -17,12 +17,14 @@ hashbrown = "0.12" hex = "0.4.3" hmac = "0.12.1" hyper = "0.14" +hyper-tungstenite = "0.8.1" itertools = "0.10.3" md5 = "0.7.0" once_cell = "1.13.0" parking_lot = "0.12" pin-project-lite = "0.2.7" rand = "0.8.3" +regex = "1.4.5" reqwest = { version = "0.11", default-features = false, features = [ "json", "rustls-tls" ] } routerify = "3" rustls = "0.20.0" @@ -36,10 +38,12 @@ thiserror = "1.0.30" tokio = { version = "1.17", features = ["macros"] } tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" } tokio-rustls = "0.23.0" +tls-listener = { version = "0.5.1", features = ["rustls", "hyper-h1"] } tracing = "0.1.36" tracing-subscriber = { version = "0.3", features = ["env-filter"] } url = "2.2.2" uuid = { version = "1.2", features = ["v4", "serde"] } +webpki-roots = "0.22.5" x509-parser = "0.14" metrics = { path = "../libs/metrics" } diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 4adf0ed940..e6a179a040 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -149,7 +149,7 @@ impl BackendType<'_, ClientCredentials<'_>> { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the project name. // We now expect to see a very specific payload in the place of password. - let fetch_magic_payload = async { + let fetch_magic_payload = |client| async { warn!("project name not specified, resorting to the password hack auth flow"); let payload = AuthFlow::new(client) .begin(auth::PasswordHack) @@ -161,10 +161,26 @@ impl BackendType<'_, ClientCredentials<'_>> { auth::Result::Ok(payload) }; + // If we want to use cleartext password flow, we can read the password + // from the client and pretend that it's a magic payload (PasswordHack hack). + let fetch_plaintext_password = |client| async { + info!("using cleartext password flow"); + let payload = AuthFlow::new(client) + .begin(auth::CleartextPassword) + .await? + .authenticate() + .await?; + + auth::Result::Ok(auth::password_hack::PasswordHackPayload { + project: String::new(), + password: payload, + }) + }; + // TODO: find a proper way to merge those very similar blocks. let (mut node, payload) = match self { Console(endpoint, creds) if creds.project.is_none() => { - let payload = fetch_magic_payload.await?; + let payload = fetch_magic_payload(client).await?; let mut creds = creds.as_ref(); creds.project = Some(payload.project.as_str().into()); @@ -174,8 +190,18 @@ impl BackendType<'_, ClientCredentials<'_>> { (node, payload) } + Console(endpoint, creds) if creds.use_cleartext_password_flow => { + // This is a hack to allow cleartext password in secure connections (wss). + let payload = fetch_plaintext_password(client).await?; + let creds = creds.as_ref(); + let node = console::Api::new(endpoint, extra, &creds) + .wake_compute() + .await?; + + (node, payload) + } Postgres(endpoint, creds) if creds.project.is_none() => { - let payload = fetch_magic_payload.await?; + let payload = fetch_magic_payload(client).await?; let mut creds = creds.as_ref(); creds.project = Some(payload.project.as_str().into()); diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 0a3b84bb52..3b71bef9aa 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -34,6 +34,9 @@ pub struct ClientCredentials<'a> { pub user: &'a str, pub dbname: &'a str, pub project: Option>, + /// If `True`, we'll use the old cleartext password flow. This is used for + /// websocket connections, which want to minimize the number of round trips. + pub use_cleartext_password_flow: bool, } impl ClientCredentials<'_> { @@ -50,6 +53,7 @@ impl<'a> ClientCredentials<'a> { user: self.user, dbname: self.dbname, project: self.project().map(Cow::Borrowed), + use_cleartext_password_flow: self.use_cleartext_password_flow, } } } @@ -59,6 +63,7 @@ impl<'a> ClientCredentials<'a> { params: &'a StartupMessageParams, sni: Option<&str>, common_name: Option<&str>, + use_cleartext_password_flow: bool, ) -> Result { use ClientCredsParseError::*; @@ -108,6 +113,7 @@ impl<'a> ClientCredentials<'a> { user = user, dbname = dbname, project = project.as_deref(), + use_cleartext_password_flow = use_cleartext_password_flow, "credentials" ); @@ -115,6 +121,7 @@ impl<'a> ClientCredentials<'a> { user, dbname, project, + use_cleartext_password_flow, }) } } @@ -141,7 +148,7 @@ mod tests { let options = StartupMessageParams::new([("user", "john_doe")]); // TODO: check that `creds.dbname` is None. - let creds = ClientCredentials::parse(&options, None, None)?; + let creds = ClientCredentials::parse(&options, None, None, false)?; assert_eq!(creds.user, "john_doe"); Ok(()) @@ -151,7 +158,7 @@ mod tests { fn parse_missing_project() -> anyhow::Result<()> { let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]); - let creds = ClientCredentials::parse(&options, None, None)?; + let creds = ClientCredentials::parse(&options, None, None, false)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project, None); @@ -166,7 +173,7 @@ mod tests { let sni = Some("foo.localhost"); let common_name = Some("localhost"); - let creds = ClientCredentials::parse(&options, sni, common_name)?; + let creds = ClientCredentials::parse(&options, sni, common_name, false)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("foo")); @@ -182,7 +189,7 @@ mod tests { ("options", "-ckey=1 project=bar -c geqo=off"), ]); - let creds = ClientCredentials::parse(&options, None, None)?; + let creds = ClientCredentials::parse(&options, None, None, false)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("bar")); @@ -201,7 +208,7 @@ mod tests { let sni = Some("baz.localhost"); let common_name = Some("localhost"); - let creds = ClientCredentials::parse(&options, sni, common_name)?; + let creds = ClientCredentials::parse(&options, sni, common_name, false)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("baz")); @@ -220,7 +227,8 @@ mod tests { let sni = Some("second.localhost"); let common_name = Some("localhost"); - let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"); + let err = + ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail"); match err { InconsistentProjectNames { domain, option } => { assert_eq!(option, "first"); @@ -237,7 +245,8 @@ mod tests { let sni = Some("project.localhost"); let common_name = Some("example.com"); - let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"); + let err = + ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail"); match err { InconsistentSni { sni, cn } => { assert_eq!(sni, "project.localhost"); diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index d9ee50894d..4b982c0c5e 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -37,6 +37,17 @@ impl AuthMethod for PasswordHack { } } +/// Use clear-text password auth called `password` in docs +/// +pub struct CleartextPassword; + +impl AuthMethod for CleartextPassword { + #[inline(always)] + fn first_message(&self) -> BeMessage<'_> { + Be::AuthenticationCleartextPassword + } +} + /// This wrapper for [`PqStream`] performs client authentication. #[must_use] pub struct AuthFlow<'a, Stream, State> { @@ -86,6 +97,18 @@ impl AuthFlow<'_, S, PasswordHack> { } } +impl AuthFlow<'_, S, CleartextPassword> { + /// Perform user authentication. Raise an error in case authentication failed. + pub async fn authenticate(self) -> super::Result> { + let msg = self.stream.read_password_message().await?; + let password = msg + .strip_suffix(&[0]) + .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?; + + Ok(password.to_vec()) + } +} + /// Stream wrapper for handling [SCRAM](crate::scram) auth. impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 096a33d73d..e847edc8bd 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -1,4 +1,5 @@ pub mod server; +pub mod websocket; use crate::url::ApiUrl; diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs new file mode 100644 index 0000000000..33c2752307 --- /dev/null +++ b/proxy/src/http/websocket.rs @@ -0,0 +1,263 @@ +use bytes::{Buf, Bytes}; +use futures::{Sink, Stream, StreamExt}; +use hyper::server::accept::{self}; +use hyper::server::conn::AddrIncoming; +use hyper::upgrade::Upgraded; +use hyper::{Body, Request, Response, StatusCode}; +use hyper_tungstenite::{tungstenite, WebSocketStream}; +use hyper_tungstenite::{tungstenite::Message, HyperWebsocket}; +use pin_project_lite::pin_project; +use tokio::net::TcpListener; + +use std::convert::Infallible; +use std::future::ready; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tls_listener::TlsListener; + +use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; + +use tracing::{error, info, info_span, warn, Instrument}; +use utils::http::{error::ApiError, json::json_response}; + +use crate::cancellation::CancelMap; +use crate::config::ProxyConfig; +use crate::proxy::handle_ws_client; + +pin_project! { + /// This is a wrapper around a WebSocketStream that implements AsyncRead and AsyncWrite. + pub struct WebSocketRW { + #[pin] + stream: WebSocketStream, + chunk: Option, + } +} + +// FIXME: explain why this is safe or try to remove `unsafe impl`. +unsafe impl Sync for WebSocketRW {} + +impl WebSocketRW { + pub fn new(stream: WebSocketStream) -> Self { + Self { + stream, + chunk: None, + } + } + + fn has_chunk(&self) -> bool { + if let Some(ref chunk) = self.chunk { + chunk.remaining() > 0 + } else { + false + } + } +} + +fn ws_err_into(e: tungstenite::Error) -> io::Error { + io::Error::new(io::ErrorKind::Other, e.to_string()) +} + +impl AsyncWrite for WebSocketRW { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut this = self.project(); + match this.stream.as_mut().poll_ready(cx) { + Poll::Ready(Ok(())) => { + if let Err(e) = this + .stream + .as_mut() + .start_send(Message::Binary(buf.to_vec())) + { + Poll::Ready(Err(ws_err_into(e))) + } else { + Poll::Ready(Ok(buf.len())) + } + } + Poll::Ready(Err(e)) => Poll::Ready(Err(ws_err_into(e))), + Poll::Pending => { + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_flush(cx).map_err(ws_err_into) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_close(cx).map_err(ws_err_into) + } +} + +impl AsyncRead for WebSocketRW { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + + let inner_buf = match self.as_mut().poll_fill_buf(cx) { + Poll::Ready(Ok(buf)) => buf, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; + let len = std::cmp::min(inner_buf.len(), buf.remaining()); + buf.put_slice(&inner_buf[..len]); + + self.consume(len); + Poll::Ready(Ok(())) + } +} + +impl AsyncBufRead for WebSocketRW { + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.as_mut().has_chunk() { + let buf = self.project().chunk.as_ref().unwrap().chunk(); + return Poll::Ready(Ok(buf)); + } else { + match self.as_mut().project().stream.poll_next(cx) { + Poll::Ready(Some(Ok(message))) => match message { + Message::Text(_) => {} + Message::Binary(chunk) => { + *self.as_mut().project().chunk = Some(Bytes::from(chunk)); + } + Message::Ping(_) => { + // No need to send a reply: tungstenite takes care of this for you. + } + Message::Pong(_) => {} + Message::Close(_) => { + // No need to send a reply: tungstenite takes care of this for you. + return Poll::Ready(Ok(&[])); + } + Message::Frame(_) => { + unreachable!(); + } + }, + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(ws_err_into(err))), + Poll::Ready(None) => return Poll::Ready(Ok(&[])), + Poll::Pending => return Poll::Pending, + } + } + } + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + if amt > 0 { + self.project() + .chunk + .as_mut() + .expect("No chunk present") + .advance(amt); + } + } +} + +async fn serve_websocket( + websocket: HyperWebsocket, + config: &ProxyConfig, + cancel_map: &CancelMap, + session_id: uuid::Uuid, + hostname: Option, +) -> anyhow::Result<()> { + let websocket = websocket.await?; + handle_ws_client( + config, + cancel_map, + session_id, + WebSocketRW::new(websocket), + hostname, + ) + .await?; + Ok(()) +} + +async fn ws_handler( + mut request: Request, + config: &'static ProxyConfig, + cancel_map: Arc, + session_id: uuid::Uuid, +) -> Result, ApiError> { + let host = request + .headers() + .get("host") + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.split(':').next()) + .map(|s| s.to_string()); + + // Check if the request is a websocket upgrade request. + if hyper_tungstenite::is_upgrade_request(&request) { + let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None) + .map_err(|e| ApiError::BadRequest(e.into()))?; + + tokio::spawn(async move { + if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host).await + { + error!("error in websocket connection: {:?}", e); + } + }); + + // Return the response so the spawned future can continue. + Ok(response) + } else { + json_response(StatusCode::OK, "Connect with a websocket client") + } +} + +pub async fn task_main( + ws_listener: TcpListener, + config: &'static ProxyConfig, +) -> anyhow::Result<()> { + scopeguard::defer! { + info!("websocket server has shut down"); + } + + let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config()); + let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config { + Some(config) => config.into(), + None => { + warn!("TLS config is missing, WebSocket Secure server will not be started"); + return Ok(()); + } + }; + + let addr_incoming = AddrIncoming::from_listener(ws_listener)?; + + let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { + if let Err(err) = conn { + error!("failed to accept TLS connection for websockets: {:?}", err); + ready(false) + } else { + ready(true) + } + }); + + let make_svc = hyper::service::make_service_fn(|_stream| async move { + Ok::<_, Infallible>(hyper::service::service_fn( + move |req: Request| async move { + let cancel_map = Arc::new(CancelMap::default()); + let session_id = uuid::Uuid::new_v4(); + ws_handler(req, config, cancel_map, session_id) + .instrument(info_span!( + "ws-client", + session = format_args!("{session_id}") + )) + .await + }, + )) + }); + + hyper::Server::builder(accept::from_stream(tls_listener)) + .serve(make_svc) + .await?; + + Ok(()) +} diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 89ea9142a9..aa6766c102 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -110,12 +110,23 @@ async fn main() -> anyhow::Result<()> { info!("Starting proxy on {proxy_address}"); let proxy_listener = TcpListener::bind(proxy_address).await?; - let tasks = [ + let mut tasks = vec![ tokio::spawn(http::server::task_main(http_listener)), tokio::spawn(proxy::task_main(config, proxy_listener)), tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)), - ] - .map(flatten_err); + ]; + + if let Some(wss_address) = arg_matches.get_one::("wss") { + let wss_address: SocketAddr = wss_address.parse()?; + info!("Starting wss on {}", wss_address); + let wss_listener = TcpListener::bind(wss_address).await?; + tasks.push(tokio::spawn(http::websocket::task_main( + wss_listener, + config, + ))); + } + + let tasks = tasks.into_iter().map(flatten_err); set_build_info_metric(GIT_VERSION); // This will block until all tasks have completed. @@ -155,6 +166,11 @@ fn cli() -> clap::Command { .help("listen for incoming http connections (metrics, etc) on ip:port") .default_value("127.0.0.1:7001"), ) + .arg( + Arg::new("wss") + .long("wss") + .help("listen for incoming wss connections on ip:port"), + ) .arg( Arg::new("uri") .short('u') diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 382f7cd918..63573d49c0 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -82,6 +82,47 @@ pub async fn task_main( } } +pub async fn handle_ws_client( + config: &ProxyConfig, + cancel_map: &CancelMap, + session_id: uuid::Uuid, + stream: impl AsyncRead + AsyncWrite + Unpin + Send, + hostname: Option, +) -> anyhow::Result<()> { + // The `closed` counter will increase when this future is destroyed. + NUM_CONNECTIONS_ACCEPTED_COUNTER.inc(); + scopeguard::defer! { + NUM_CONNECTIONS_CLOSED_COUNTER.inc(); + } + + let tls = config.tls_config.as_ref(); + let hostname = hostname.as_deref(); + + // TLS is None here, because the connection is already encrypted. + let do_handshake = handshake(stream, None, cancel_map).instrument(info_span!("handshake")); + let (mut stream, params) = match do_handshake.await? { + Some(x) => x, + None => return Ok(()), // it's a cancellation request + }; + + // Extract credentials which we're going to use for auth. + let creds = { + let common_name = tls.and_then(|tls| tls.common_name.as_deref()); + let result = config + .auth_backend + .as_ref() + .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name, true)) + .transpose(); + + async { result }.or_else(|e| stream.throw_error(e)).await? + }; + + let client = Client::new(stream, creds, ¶ms, session_id); + cancel_map + .with_session(|session| client.connect_to_db(session)) + .await +} + async fn handle_client( config: &ProxyConfig, cancel_map: &CancelMap, @@ -108,7 +149,7 @@ async fn handle_client( let result = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name)) + .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name, false)) .transpose(); async { result }.or_else(|e| stream.throw_error(e)).await?