From ec8dcc223167aad145cc8b70cc3ac6801f0ed79c Mon Sep 17 00:00:00 2001
From: Conrad Ludgate
Date: Mon, 29 Jan 2024 17:38:03 +0000
Subject: [PATCH] flatten proxy flow (#6447)

## Problem

This takes my ideas from https://github.com/neondatabase/neon/pull/6283, but
makes somewhat less radical changes, split into smaller commits.

The proxy flow was quite deeply nested, which made adding more interesting
error handling quite tricky.

## Summary of changes

I recommend reviewing commit by commit.

1. move handshake logic into a separate file
2. move passthrough logic into a separate file
3. no longer accept a closure in CancelMap session logic
4. Remove connect_to_db, copy logic into handle_client
5. flatten auth_and_wake_compute in authenticate
6. record info for link auth

---
 proxy/src/auth/backend.rs             |  26 +-
 proxy/src/auth/backend/link.rs        |   6 +
 proxy/src/auth/credentials.rs         |   8 +-
 proxy/src/bin/pg_sni_router.rs        |   2 +-
 proxy/src/cancellation.rs             |  93 +++----
 proxy/src/context.rs                  |  12 +-
 proxy/src/proxy.rs                    | 351 ++++++--------------------
 proxy/src/proxy/handshake.rs          |  96 +++++++
 proxy/src/proxy/passthrough.rs        |  57 +++++
 proxy/src/serverless.rs               |   2 +-
 proxy/src/serverless/sql_over_http.rs |   2 +-
 proxy/src/serverless/websocket.rs     |   2 +-
 12 files changed, 297 insertions(+), 360 deletions(-)
 create mode 100644 proxy/src/proxy/handshake.rs
 create mode 100644 proxy/src/proxy/passthrough.rs

diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs
index b1634906c9..4b8ebae86f 100644
--- a/proxy/src/auth/backend.rs
+++ b/proxy/src/auth/backend.rs
@@ -190,7 +190,10 @@ async fn auth_quirks(
         Err(info) => {
             let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
                 .await?;
-            ctx.set_endpoint_id(Some(res.info.endpoint.clone()));
+
+            ctx.set_endpoint_id(res.info.endpoint.clone());
+            tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));
+
             (res.info, Some(res.keys))
         }
         Ok(info) => (info, None),
@@ -271,19 +274,12 @@ async fn authenticate_with_secret(
     classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
 }
 
-/// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache)
-/// only if authentication was successfuly.
-async fn auth_and_wake_compute(
+/// Wake a compute (or retrieve an existing compute session from cache).
+async fn wake_compute(
     ctx: &mut RequestMonitoring,
     api: &impl console::Api,
-    user_info: ComputeUserInfoMaybeEndpoint,
-    client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
-    allow_cleartext: bool,
-    config: &'static AuthenticationConfig,
+    compute_credentials: ComputeCredentials,
 ) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
-    let compute_credentials =
-        auth_quirks(ctx, api, user_info, client, allow_cleartext, config).await?;
-
     let mut num_retries = 0;
     let mut node = loop {
         let wake_res = api.wake_compute(ctx, &compute_credentials.info).await;
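The hunk above is the heart of change 5: `auth_quirks` now returns `ComputeCredentials`, and `wake_compute` only retries the wake itself (the retry policy is truncated out of this hunk). A minimal sketch of that retry shape, where `WakeError`, `retryable`, `MAX_RETRIES`, and the backoff constants are hypothetical stand-ins rather than the proxy's real types:

```rust
use std::time::Duration;

// Hypothetical error type; `retryable` stands in for the elided retry policy.
pub struct WakeError {
    pub retryable: bool,
}

pub async fn wake_with_retries(
    mut wake: impl FnMut() -> Result<String, WakeError>,
) -> Result<String, WakeError> {
    const MAX_RETRIES: u32 = 3;
    let mut num_retries = 0;
    loop {
        match wake() {
            Ok(node) => break Ok(node),
            Err(e) if e.retryable && num_retries < MAX_RETRIES => {
                num_retries += 1;
                // Back off exponentially before waking the compute again.
                tokio::time::sleep(Duration::from_millis(100 * 2u64.pow(num_retries))).await;
            }
            Err(e) => break Err(e),
        }
    }
}
```

Splitting the wake out of `auth_and_wake_compute` gives each function a single failure domain, which is what makes the flattened error handling below possible.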
@@ -358,16 +354,16 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
                     "performing authentication using the console"
                 );
 
-                let (cache_info, user_info) =
-                    auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
-                        .await?;
+                let compute_credentials =
+                    auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
+                let (cache_info, user_info) = wake_compute(ctx, &*api, compute_credentials).await?;
 
                 (cache_info, BackendType::Console(api, user_info))
             }
             // NOTE: this auth backend doesn't use client credentials.
             Link(url) => {
                 info!("performing link authentication");
 
-                let node_info = link::authenticate(&url, client).await?;
+                let node_info = link::authenticate(ctx, &url, client).await?;
 
                 (
                     CachedNodeInfo::new_uncached(node_info),
diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs
index a7ddd257b3..d8ae362c03 100644
--- a/proxy/src/auth/backend/link.rs
+++ b/proxy/src/auth/backend/link.rs
@@ -1,6 +1,7 @@
 use crate::{
     auth, compute,
     console::{self, provider::NodeInfo},
+    context::RequestMonitoring,
     error::UserFacingError,
     stream::PqStream,
     waiters,
@@ -54,6 +55,7 @@ pub fn new_psql_session_id() -> String {
 }
 
 pub(super) async fn authenticate(
+    ctx: &mut RequestMonitoring,
     link_uri: &reqwest::Url,
     client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
 ) -> auth::Result<NodeInfo> {
@@ -94,6 +96,10 @@ pub(super) async fn authenticate(
         .dbname(&db_info.dbname)
         .user(&db_info.user);
 
+    ctx.set_user(db_info.user.into());
+    ctx.set_project(db_info.aux.clone());
+    tracing::Span::current().record("ep", &tracing::field::display(&db_info.aux.endpoint_id));
+
     // Backwards compatibility. pg_sni_proxy uses "--" in domain names
     // while direct connections do not. Once we migrate to pg_sni_proxy
     // everywhere, we can remove this.
diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs
index 5bf7667a1f..875baaec47 100644
--- a/proxy/src/auth/credentials.rs
+++ b/proxy/src/auth/credentials.rs
@@ -126,7 +126,11 @@ impl ComputeUserInfoMaybeEndpoint {
             }),
         }
         .transpose()?;
-        ctx.set_endpoint_id(endpoint.clone());
+
+        if let Some(ep) = &endpoint {
+            ctx.set_endpoint_id(ep.clone());
+            tracing::Span::current().record("ep", &tracing::field::display(ep));
+        }
 
         info!(%user, project = endpoint.as_deref(), "credentials");
 
         if sni.is_some() {
@@ -150,7 +154,7 @@ impl ComputeUserInfoMaybeEndpoint {
 
         Ok(Self {
             user,
-            endpoint_id: endpoint.map(EndpointId::from),
+            endpoint_id: endpoint,
             options,
         })
     }
diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs
index 1edbc1e7e7..471be7af25 100644
--- a/proxy/src/bin/pg_sni_router.rs
+++ b/proxy/src/bin/pg_sni_router.rs
@@ -272,5 +272,5 @@ async fn handle_client(
     let client = tokio::net::TcpStream::connect(destination).await?;
 
     let metrics_aux: MetricsAuxInfo = Default::default();
-    proxy::proxy::proxy_pass(ctx, tls_stream, client, metrics_aux).await
+    proxy::proxy::passthrough::proxy_pass(ctx, tls_stream, client, metrics_aux).await
 }
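The cancellation.rs diff that follows is change 3: instead of running a closure inside `with_session` and relying on `scopeguard::defer!`, callers now get an owned `Session` whose `Drop` impl unregisters the key. A standalone sketch of that RAII pattern, with simplified stand-ins (`Mutex<HashMap>` for the real `DashMap`, `u64` for `CancelKeyData`):

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Simplified stand-in for CancelMap.
#[derive(Default)]
pub struct Registry(Mutex<HashMap<u64, ()>>);

// Simplified stand-in for Session.
pub struct SessionGuard {
    key: u64,
    registry: Arc<Registry>,
}

impl Registry {
    pub fn get_session(self: Arc<Self>) -> SessionGuard {
        let key = {
            let mut map = self.0.lock().unwrap();
            // Collisions on a random key are unlikely but possible, so never
            // overwrite an existing entry.
            loop {
                let key = rand::random::<u64>();
                if !map.contains_key(&key) {
                    map.insert(key, ());
                    break key;
                }
            }
        };
        SessionGuard { key, registry: self }
    }
}

impl Drop for SessionGuard {
    fn drop(&mut self) {
        // The entry disappears as soon as the guard goes out of scope,
        // replacing the old closure + scopeguard::defer! dance.
        self.registry.0.lock().unwrap().remove(&self.key);
    }
}
```

The guard still cleans up if the connection task errors out or is cancelled, which is exactly what the closure-based API provided, but without forcing the whole connection flow into a nested closure.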
diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs
index a5eb3544b4..d4ee657144 100644
--- a/proxy/src/cancellation.rs
+++ b/proxy/src/cancellation.rs
@@ -1,7 +1,7 @@
-use anyhow::{bail, Context};
+use anyhow::Context;
 use dashmap::DashMap;
 use pq_proto::CancelKeyData;
-use std::net::SocketAddr;
+use std::{net::SocketAddr, sync::Arc};
 use tokio::net::TcpStream;
 use tokio_postgres::{CancelToken, NoTls};
 use tracing::info;
@@ -25,39 +25,31 @@ impl CancelMap {
     }
 
     /// Run async action within an ephemeral session identified by [`CancelKeyData`].
-    pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
-    where
-        F: FnOnce(Session<'a>) -> R,
-        R: std::future::Future<Output = anyhow::Result<V>>,
-    {
+    pub fn get_session(self: Arc<Self>) -> Session {
         // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
         // expose it and we don't want to do another roundtrip to query
        // for it. The client will be able to notice that this is not the
         // actual backend_pid, but backend_pid is not used for anything
         // so it doesn't matter.
-        let key = rand::random();
+        let key = loop {
+            let key = rand::random();
 
-        // Random key collisions are unlikely to happen here, but they're still possible,
-        // which is why we have to take care not to rewrite an existing key.
-        match self.0.entry(key) {
-            dashmap::mapref::entry::Entry::Occupied(_) => {
-                bail!("query cancellation key already exists: {key}")
+            // Random key collisions are unlikely to happen here, but they're still possible,
+            // which is why we have to take care not to rewrite an existing key.
+            match self.0.entry(key) {
+                dashmap::mapref::entry::Entry::Occupied(_) => continue,
+                dashmap::mapref::entry::Entry::Vacant(e) => {
+                    e.insert(None);
+                }
             }
-            dashmap::mapref::entry::Entry::Vacant(e) => {
-                e.insert(None);
-            }
-        }
-
-        // This will guarantee that the session gets dropped
-        // as soon as the future is finished.
-        scopeguard::defer! {
-            self.0.remove(&key);
-            info!("dropped query cancellation key {key}");
-        }
+            break key;
+        };
 
         info!("registered new query cancellation key {key}");
-        let session = Session::new(key, self);
-        f(session).await
+        Session {
+            key,
+            cancel_map: self,
+        }
     }
 
     #[cfg(test)]
@@ -98,23 +90,17 @@ impl CancelClosure {
 }
 
 /// Helper for registering query cancellation tokens.
-pub struct Session<'a> {
+pub struct Session {
     /// The user-facing key identifying this session.
     key: CancelKeyData,
     /// The [`CancelMap`] this session belongs to.
-    cancel_map: &'a CancelMap,
+    cancel_map: Arc<CancelMap>,
 }
 
-impl<'a> Session<'a> {
-    fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
-        Self { key, cancel_map }
-    }
-}
-
-impl Session<'_> {
+impl Session {
     /// Store the cancel token for the given session.
     /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
-    pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
+    pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
         info!("enabling query cancellation for this session");
         self.cancel_map.0.insert(self.key, Some(cancel_closure));
 
@@ -122,37 +108,26 @@ impl Session<'_> {
     }
 }
 
+impl Drop for Session {
+    fn drop(&mut self) {
+        self.cancel_map.0.remove(&self.key);
+        info!("dropped query cancellation key {}", &self.key);
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
-    use once_cell::sync::Lazy;
 
     #[tokio::test]
     async fn check_session_drop() -> anyhow::Result<()> {
-        static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
-
-        let (tx, rx) = tokio::sync::oneshot::channel();
-        let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
-            assert!(CANCEL_MAP.contains(&session));
-
-            tx.send(()).expect("failed to send");
-            futures::future::pending::<()>().await; // sleep forever
-
-            Ok(())
-        }));
-
-        // Wait until the task has been spawned.
-        rx.await.context("failed to hear from the task")?;
-
-        // Drop the session's entry by cancelling the task.
-        task.abort();
-        let error = task.await.expect_err("task should have failed");
-        if !error.is_cancelled() {
-            anyhow::bail!(error);
-        }
+        let cancel_map: Arc<CancelMap> = Default::default();
+        let session = cancel_map.clone().get_session();
+        assert!(cancel_map.contains(&session));
+        drop(session);
 
         // Check that the session has been dropped.
-        assert!(CANCEL_MAP.is_empty());
+        assert!(cancel_map.is_empty());
 
         Ok(())
     }
diff --git a/proxy/src/context.rs b/proxy/src/context.rs
index ed2ed5e367..e2b0294cd3 100644
--- a/proxy/src/context.rs
+++ b/proxy/src/context.rs
@@ -89,13 +89,11 @@ impl RequestMonitoring {
         self.project = Some(x.project_id);
     }
 
-    pub fn set_endpoint_id(&mut self, endpoint_id: Option<EndpointId>) {
-        self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone());
-        if let Some(ep) = &self.endpoint_id {
-            crate::metrics::CONNECTING_ENDPOINTS
-                .with_label_values(&[self.protocol])
-                .measure(&ep);
-        }
+    pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
+        crate::metrics::CONNECTING_ENDPOINTS
+            .with_label_values(&[self.protocol])
+            .measure(&endpoint_id);
+        self.endpoint_id = Some(endpoint_id);
     }
 
     pub fn set_application(&mut self, app: Option) {
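With `set_endpoint_id` now taking a plain `EndpointId`, the endpoint is also mirrored into the current tracing span via `Span::record` (see the `record("ep", ...)` calls in backend.rs, credentials.rs, and link.rs). For that to work, the field must be declared up front with `tracing::field::Empty`, which is exactly what the `session_span` in the next hunk does. A self-contained sketch of the pattern; the endpoint value here is invented:

```rust
use tracing::{field, info_span, Span};

fn annotate_span_later() {
    // Declare `ep` as an empty field so it can be recorded once known;
    // fields not declared at span creation are silently ignored by `record`.
    let span = info_span!("handle_client", ep = field::Empty);
    let _guard = span.enter();

    // ... later, after credentials are parsed (value is made up):
    let endpoint = "ep-hypothetical-123456";
    Span::current().record("ep", &field::display(&endpoint));
}
```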
diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs
index 087cc7f7a9..4aa1f3590d 100644
--- a/proxy/src/proxy.rs
+++ b/proxy/src/proxy.rs
@@ -2,37 +2,34 @@ mod tests;
 
 pub mod connect_compute;
+pub mod handshake;
+pub mod passthrough;
 pub mod retry;
 
 use crate::{
     auth,
     cancellation::{self, CancelMap},
     compute,
-    config::{AuthenticationConfig, ProxyConfig, TlsConfig},
-    console::messages::MetricsAuxInfo,
+    config::{ProxyConfig, TlsConfig},
     context::RequestMonitoring,
-    metrics::{
-        NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
-        NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE,
-    },
+    metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
     protocol2::WithClientIp,
+    proxy::{handshake::handshake, passthrough::proxy_pass},
     rate_limiter::EndpointRateLimiter,
     stream::{PqStream, Stream},
-    usage_metrics::{Ids, USAGE_METRICS},
     EndpointCacheKey,
 };
 use anyhow::{bail, Context};
 use futures::TryFutureExt;
 use itertools::Itertools;
 use once_cell::sync::OnceCell;
-use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
+use pq_proto::{BeMessage as Be, StartupMessageParams};
 use regex::Regex;
 use smol_str::{format_smolstr, SmolStr};
 use std::sync::Arc;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
 use tokio_util::sync::CancellationToken;
 use tracing::{error, info, info_span, Instrument};
-use utils::measured_stream::MeasuredStream;
 
 use self::connect_compute::{connect_to_compute, TcpMechanism};
 
@@ -80,6 +77,13 @@ pub async fn task_main(
         let cancel_map = Arc::clone(&cancel_map);
         let endpoint_rate_limiter = endpoint_rate_limiter.clone();
 
+        let session_span = info_span!(
+            "handle_client",
+            ?session_id,
+            peer_addr = tracing::field::Empty,
+            ep = tracing::field::Empty,
+        );
+
         connections.spawn(
             async move {
                 info!("accepted postgres client connection");
 
@@ -103,22 +107,18 @@ pub async fn task_main(
                 handle_client(
                     config,
                     &mut ctx,
-                    &cancel_map,
+                    cancel_map,
                     socket,
                     ClientMode::Tcp,
                     endpoint_rate_limiter,
                 )
                 .await
             }
-            .instrument(info_span!(
-                "handle_client",
-                ?session_id,
-                peer_addr = tracing::field::Empty
-            ))
             .unwrap_or_else(move |e| {
                 // Acknowledge that the task has finished with an error.
-                error!(?session_id, "per-client task finished with an error: {e:#}");
-            }),
+                error!("per-client task finished with an error: {e:#}");
+            })
+            .instrument(session_span),
         );
     }
 
@@ -171,7 +171,7 @@ impl ClientMode {
 pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
     config: &'static ProxyConfig,
     ctx: &mut RequestMonitoring,
-    cancel_map: &CancelMap,
+    cancel_map: Arc<CancelMap>,
     stream: S,
     mode: ClientMode,
     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
 ) -> anyhow::Result<()> {
@@ -192,138 +192,88 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
     let tls = config.tls_config.as_ref();
 
     let pause = ctx.latency_timer.pause();
-    let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
+    let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
     let (mut stream, params) = match do_handshake.await? {
         Some(x) => x,
         None => return Ok(()), // it's a cancellation request
     };
     drop(pause);
 
+    let hostname = mode.hostname(stream.get_ref());
+
+    let common_names = tls.map(|tls| &tls.common_names);
+
     // Extract credentials which we're going to use for auth.
-    let user_info = {
-        let hostname = mode.hostname(stream.get_ref());
+    let result = config
+        .auth_backend
+        .as_ref()
+        .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
+        .transpose();
 
-        let common_names = tls.map(|tls| &tls.common_names);
-        let result = config
-            .auth_backend
-            .as_ref()
-            .map(|_| {
-                auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names)
-            })
-            .transpose();
+    let user_info = match result {
+        Ok(user_info) => user_info,
+        Err(e) => stream.throw_error(e).await?,
+    };
 
-        match result {
-            Ok(user_info) => user_info,
-            Err(e) => stream.throw_error(e).await?,
+    // check rate limit
+    if let Some(ep) = user_info.get_endpoint() {
+        if !endpoint_rate_limiter.check(ep) {
+            return stream
+                .throw_error(auth::AuthError::too_many_connections())
+                .await;
+        }
+    }
+
+    let user = user_info.get_user().to_owned();
+    let (mut node_info, user_info) = match user_info
+        .authenticate(
+            ctx,
+            &mut stream,
+            mode.allow_cleartext(),
+            &config.authentication_config,
+        )
+        .await
+    {
+        Ok(auth_result) => auth_result,
+        Err(e) => {
+            let db = params.get("database");
+            let app = params.get("application_name");
+            let params_span = tracing::info_span!("", ?user, ?db, ?app);
+
+            return stream.throw_error(e).instrument(params_span).await;
         }
     };
 
-    ctx.set_endpoint_id(user_info.get_endpoint());
+    node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);
 
-    let client = Client::new(
-        stream,
-        user_info,
-        &params,
-        mode.allow_self_signed_compute(config),
-        endpoint_rate_limiter,
-    );
-    cancel_map
-        .with_session(|session| {
-            client.connect_to_db(ctx, session, mode, &config.authentication_config)
-        })
-        .await
-}
+    let aux = node_info.aux.clone();
+    let mut node = connect_to_compute(
+        ctx,
+        &TcpMechanism { params: &params },
+        node_info,
+        &user_info,
+    )
+    .or_else(|e| stream.throw_error(e))
+    .await?;
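The rate-limit check above now runs inline in `handle_client`, before any authentication work. `EndpointRateLimiter` itself (imported from `rate_limiter` in the crate) is not part of this patch; below is a hypothetical fixed-window sketch just to illustrate the `check` contract, where `false` means the client is bounced with `too_many_connections`:

```rust
use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};

// Hypothetical stand-in for the crate's EndpointRateLimiter.
pub struct SimpleLimiter {
    max_per_window: u32,
    window: Duration,
    state: Mutex<HashMap<String, (Instant, u32)>>,
}

impl SimpleLimiter {
    pub fn check(&self, endpoint: &str) -> bool {
        let now = Instant::now();
        let mut state = self.state.lock().unwrap();
        let (start, count) = state.entry(endpoint.to_owned()).or_insert((now, 0));
        if now.duration_since(*start) > self.window {
            // New window: reset the counter.
            *start = now;
            *count = 0;
        }
        *count += 1;
        *count <= self.max_per_window
    }
}
```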
-/// Establish a (most probably, secure) connection with the client.
-/// For better testing experience, `stream` can be any object satisfying the traits.
-/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
-/// we also take an extra care of propagating only the select handshake errors to client.
-#[tracing::instrument(skip_all)]
-async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
-    stream: S,
-    mut tls: Option<&TlsConfig>,
-    cancel_map: &CancelMap,
-) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
-    // Client may try upgrading to each protocol only once
-    let (mut tried_ssl, mut tried_gss) = (false, false);
+    let session = cancel_map.get_session();
+    prepare_client_connection(&node, &session, &mut stream).await?;
 
-    let mut stream = PqStream::new(Stream::from_raw(stream));
-    loop {
-        let msg = stream.read_startup_packet().await?;
-        info!("received {msg:?}");
+    // Before proxy passing, forward to compute whatever data is left in the
+    // PqStream input buffer. Normally there is none, but our serverless npm
+    // driver in pipeline mode sends startup, password and first query
+    // immediately after opening the connection.
+    let (stream, read_buf) = stream.into_inner();
+    node.stream.write_all(&read_buf).await?;
 
-        use FeStartupPacket::*;
-        match msg {
-            SslRequest => match stream.get_ref() {
-                Stream::Raw { .. } if !tried_ssl => {
-                    tried_ssl = true;
-
-                    // We can't perform TLS handshake without a config
-                    let enc = tls.is_some();
-                    stream.write_message(&Be::EncryptionResponse(enc)).await?;
-                    if let Some(tls) = tls.take() {
-                        // Upgrade raw stream into a secure TLS-backed stream.
-                        // NOTE: We've consumed `tls`; this fact will be used later.
-
-                        let (raw, read_buf) = stream.into_inner();
-                        // TODO: Normally, client doesn't send any data before
-                        // server says TLS handshake is ok and read_buf is empy.
-                        // However, you could imagine pipelining of postgres
-                        // SSLRequest + TLS ClientHello in one hunk similar to
-                        // pipelining in our node js driver. We should probably
-                        // support that by chaining read_buf with the stream.
-                        if !read_buf.is_empty() {
-                            bail!("data is sent before server replied with EncryptionResponse");
-                        }
-                        let tls_stream = raw.upgrade(tls.to_server_config()).await?;
-
-                        let (_, tls_server_end_point) = tls
-                            .cert_resolver
-                            .resolve(tls_stream.get_ref().1.server_name())
-                            .context("missing certificate")?;
-
-                        stream = PqStream::new(Stream::Tls {
-                            tls: Box::new(tls_stream),
-                            tls_server_end_point,
-                        });
-                    }
-                }
-                _ => bail!(ERR_PROTO_VIOLATION),
-            },
-            GssEncRequest => match stream.get_ref() {
-                Stream::Raw { .. } if !tried_gss => {
-                    tried_gss = true;
-
-                    // Currently, we don't support GSSAPI
-                    stream.write_message(&Be::EncryptionResponse(false)).await?;
-                }
-                _ => bail!(ERR_PROTO_VIOLATION),
-            },
-            StartupMessage { params, .. } => {
-                // Check that the config has been consumed during upgrade
-                // OR we didn't provide it at all (for dev purposes).
-                if tls.is_some() {
-                    stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
-                }
-
-                info!(session_type = "normal", "successful handshake");
-                break Ok(Some((stream, params)));
-            }
-            CancelRequest(cancel_key_data) => {
-                cancel_map.cancel_session(cancel_key_data).await?;
-
-                info!(session_type = "cancellation", "successful handshake");
-                break Ok(None);
-            }
-        }
-    }
+    proxy_pass(ctx, stream, node.stream, aux).await
 }
 
 /// Finish client connection initialization: confirm auth success, send params, etc.
 #[tracing::instrument(skip_all)]
 async fn prepare_client_connection(
     node: &compute::PostgresConnection,
-    session: cancellation::Session<'_>,
+    session: &cancellation::Session,
     stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
 ) -> anyhow::Result<()> {
     // Register compute's query cancellation token and produce a new, unique one.
@@ -349,151 +299,6 @@ async fn prepare_client_connection(
     Ok(())
 }
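For context on the `SslRequest` arm of the handshake just removed (and re-added in handshake.rs below): on the wire this is an 8-byte packet, a length of 8 followed by the magic code 80877103 (1234 << 16 | 5679), and the server answers with a single 'S' or 'N' byte before any TLS data flows. A minimal client-side sketch:

```rust
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;

/// Ask a Postgres server to upgrade this connection to TLS.
/// Returns true if the server answered 'S' (proceed with the TLS handshake).
async fn request_tls(stream: &mut TcpStream) -> std::io::Result<bool> {
    let mut packet = [0u8; 8];
    packet[..4].copy_from_slice(&8i32.to_be_bytes()); // total length: 8 bytes
    packet[4..].copy_from_slice(&80877103i32.to_be_bytes()); // SSLRequest magic code
    stream.write_all(&packet).await?;

    let mut resp = [0u8; 1];
    stream.read_exact(&mut resp).await?;
    Ok(resp[0] == b'S') // 'S' => start TLS, 'N' => server refuses TLS
}
```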
-/// Forward bytes in both directions (client <-> compute).
-#[tracing::instrument(skip_all)]
-pub async fn proxy_pass(
-    ctx: &mut RequestMonitoring,
-    client: impl AsyncRead + AsyncWrite + Unpin,
-    compute: impl AsyncRead + AsyncWrite + Unpin,
-    aux: MetricsAuxInfo,
-) -> anyhow::Result<()> {
-    ctx.set_success();
-    ctx.log();
-
-    let usage = USAGE_METRICS.register(Ids {
-        endpoint_id: aux.endpoint_id.clone(),
-        branch_id: aux.branch_id.clone(),
-    });
-
-    let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
-    let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
-    let mut client = MeasuredStream::new(
-        client,
-        |_| {},
-        |cnt| {
-            // Number of bytes we sent to the client (outbound).
-            m_sent.inc_by(cnt as u64);
-            m_sent2.inc_by(cnt as u64);
-            usage.record_egress(cnt as u64);
-        },
-    );
-
-    let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
-    let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
-    let mut compute = MeasuredStream::new(
-        compute,
-        |_| {},
-        |cnt| {
-            // Number of bytes the client sent to the compute node (inbound).
-            m_recv.inc_by(cnt as u64);
-            m_recv2.inc_by(cnt as u64);
-        },
-    );
-
-    // Starting from here we only proxy the client's traffic.
-    info!("performing the proxy pass...");
-    let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
-
-    Ok(())
-}
-
-/// Thin connection context.
-struct Client<'a, S> {
-    /// The underlying libpq protocol stream.
-    stream: PqStream<Stream<S>>,
-    /// Client credentials that we care about.
-    user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
-    /// KV-dictionary with PostgreSQL connection params.
-    params: &'a StartupMessageParams,
-    /// Allow self-signed certificates (for testing).
-    allow_self_signed_compute: bool,
-    /// Rate limiter for endpoints
-    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-}
-
-impl<'a, S> Client<'a, S> {
-    /// Construct a new connection context.
-    fn new(
-        stream: PqStream<Stream<S>>,
-        user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
-        params: &'a StartupMessageParams,
-        allow_self_signed_compute: bool,
-        endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-    ) -> Self {
-        Self {
-            stream,
-            user_info,
-            params,
-            allow_self_signed_compute,
-            endpoint_rate_limiter,
-        }
-    }
-}
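`proxy_pass` (re-added in passthrough.rs below) wraps both ends in `MeasuredStream` so counters and usage metrics tick per chunk, for the whole lifetime of a session. If totals-at-close were enough, the tuple returned by `tokio::io::copy_bidirectional` would already suffice, as this sketch shows; the live per-chunk callbacks are the reason for the extra wrapping:

```rust
use tokio::io::{AsyncRead, AsyncWrite};

async fn forward(
    mut client: impl AsyncRead + AsyncWrite + Unpin,
    mut compute: impl AsyncRead + AsyncWrite + Unpin,
) -> std::io::Result<()> {
    // copy_bidirectional returns (bytes client->compute, bytes compute->client),
    // but only once the connection closes.
    let (rx, tx) = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
    println!("tx: {tx} bytes to client, rx: {rx} bytes to compute");
    Ok(())
}
```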
-impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
-    /// Let the client authenticate and connect to the designated compute node.
-    // Instrumentation logs endpoint name everywhere. Doesn't work for link
-    // auth; strictly speaking we don't know endpoint name in its case.
-    #[tracing::instrument(name = "", fields(ep = %self.user_info.get_endpoint().unwrap_or_default()), skip_all)]
-    async fn connect_to_db(
-        self,
-        ctx: &mut RequestMonitoring,
-        session: cancellation::Session<'_>,
-        mode: ClientMode,
-        config: &'static AuthenticationConfig,
-    ) -> anyhow::Result<()> {
-        let Self {
-            mut stream,
-            user_info,
-            params,
-            allow_self_signed_compute,
-            endpoint_rate_limiter,
-        } = self;
-
-        // check rate limit
-        if let Some(ep) = user_info.get_endpoint() {
-            if !endpoint_rate_limiter.check(ep) {
-                return stream
-                    .throw_error(auth::AuthError::too_many_connections())
-                    .await;
-            }
-        }
-
-        let user = user_info.get_user().to_owned();
-        let auth_result = match user_info
-            .authenticate(ctx, &mut stream, mode.allow_cleartext(), config)
-            .await
-        {
-            Ok(auth_result) => auth_result,
-            Err(e) => {
-                let db = params.get("database");
-                let app = params.get("application_name");
-                let params_span = tracing::info_span!("", ?user, ?db, ?app);
-
-                return stream.throw_error(e).instrument(params_span).await;
-            }
-        };
-
-        let (mut node_info, user_info) = auth_result;
-
-        node_info.allow_self_signed_compute = allow_self_signed_compute;
-
-        let aux = node_info.aux.clone();
-        let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &user_info)
-            .or_else(|e| stream.throw_error(e))
-            .await?;
-
-        prepare_client_connection(&node, session, &mut stream).await?;
-        // Before proxy passing, forward to compute whatever data is left in the
-        // PqStream input buffer. Normally there is none, but our serverless npm
-        // driver in pipeline mode sends startup, password and first query
-        // immediately after opening the connection.
-        let (stream, read_buf) = stream.into_inner();
-        node.stream.write_all(&read_buf).await?;
-        proxy_pass(ctx, stream, node.stream, aux).await
-    }
-}
-
 #[derive(Debug, Clone, PartialEq, Eq, Default)]
 pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
 
diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs
new file mode 100644
index 0000000000..1ad8da20d7
--- /dev/null
+++ b/proxy/src/proxy/handshake.rs
@@ -0,0 +1,96 @@
+use anyhow::{bail, Context};
+use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
+use tokio::io::{AsyncRead, AsyncWrite};
+use tracing::info;
+
+use crate::{
+    cancellation::CancelMap,
+    config::TlsConfig,
+    proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
+    stream::{PqStream, Stream},
+};
+
+/// Establish a (most probably, secure) connection with the client.
+/// For better testing experience, `stream` can be any object satisfying the traits.
+/// It's easier to work with an owned `stream` here, as we need to upgrade it to TLS;
+/// we also take extra care to propagate only select handshake errors to the client.
+#[tracing::instrument(skip_all)]
+pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
+    stream: S,
+    mut tls: Option<&TlsConfig>,
+    cancel_map: &CancelMap,
+) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
+    // Client may try upgrading to each protocol only once
+    let (mut tried_ssl, mut tried_gss) = (false, false);
+
+    let mut stream = PqStream::new(Stream::from_raw(stream));
+    loop {
+        let msg = stream.read_startup_packet().await?;
+        info!("received {msg:?}");
+
+        use FeStartupPacket::*;
+        match msg {
+            SslRequest => match stream.get_ref() {
+                Stream::Raw { .. } if !tried_ssl => {
+                    tried_ssl = true;
+
+                    // We can't perform TLS handshake without a config
+                    let enc = tls.is_some();
+                    stream.write_message(&Be::EncryptionResponse(enc)).await?;
+                    if let Some(tls) = tls.take() {
+                        // Upgrade raw stream into a secure TLS-backed stream.
+                        // NOTE: We've consumed `tls`; this fact will be used later.
+
+                        let (raw, read_buf) = stream.into_inner();
+                        // TODO: Normally, the client doesn't send any data before
+                        // the server says the TLS handshake is ok and read_buf is empty.
+                        // However, you could imagine pipelining of postgres
+                        // SSLRequest + TLS ClientHello in one hunk similar to
+                        // pipelining in our node js driver. We should probably
+                        // support that by chaining read_buf with the stream.
+                        if !read_buf.is_empty() {
+                            bail!("data is sent before server replied with EncryptionResponse");
+                        }
+                        let tls_stream = raw.upgrade(tls.to_server_config()).await?;
+
+                        let (_, tls_server_end_point) = tls
+                            .cert_resolver
+                            .resolve(tls_stream.get_ref().1.server_name())
+                            .context("missing certificate")?;
+
+                        stream = PqStream::new(Stream::Tls {
+                            tls: Box::new(tls_stream),
+                            tls_server_end_point,
+                        });
+                    }
+                }
+                _ => bail!(ERR_PROTO_VIOLATION),
+            },
+            GssEncRequest => match stream.get_ref() {
+                Stream::Raw { .. } if !tried_gss => {
+                    tried_gss = true;
+
+                    // Currently, we don't support GSSAPI
+                    stream.write_message(&Be::EncryptionResponse(false)).await?;
+                }
+                _ => bail!(ERR_PROTO_VIOLATION),
+            },
+            StartupMessage { params, .. } => {
+                // Check that the config has been consumed during upgrade
+                // OR we didn't provide it at all (for dev purposes).
+                if tls.is_some() {
+                    stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
+                }
+
+                info!(session_type = "normal", "successful handshake");
+                break Ok(Some((stream, params)));
+            }
+            CancelRequest(cancel_key_data) => {
+                cancel_map.cancel_session(cancel_key_data).await?;
+
+                info!(session_type = "cancellation", "successful handshake");
+                break Ok(None);
+            }
+        }
+    }
+}
diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs
new file mode 100644
index 0000000000..d6f097d72d
--- /dev/null
+++ b/proxy/src/proxy/passthrough.rs
@@ -0,0 +1,57 @@
+use crate::{
+    console::messages::MetricsAuxInfo,
+    context::RequestMonitoring,
+    metrics::{NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER},
+    usage_metrics::{Ids, USAGE_METRICS},
+};
+use tokio::io::{AsyncRead, AsyncWrite};
+use tracing::info;
+use utils::measured_stream::MeasuredStream;
+
+/// Forward bytes in both directions (client <-> compute).
+#[tracing::instrument(skip_all)]
+pub async fn proxy_pass(
+    ctx: &mut RequestMonitoring,
+    client: impl AsyncRead + AsyncWrite + Unpin,
+    compute: impl AsyncRead + AsyncWrite + Unpin,
+    aux: MetricsAuxInfo,
+) -> anyhow::Result<()> {
+    ctx.set_success();
+    ctx.log();
+
+    let usage = USAGE_METRICS.register(Ids {
+        endpoint_id: aux.endpoint_id.clone(),
+        branch_id: aux.branch_id.clone(),
+    });
+
+    let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
+    let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
+    let mut client = MeasuredStream::new(
+        client,
+        |_| {},
+        |cnt| {
+            // Number of bytes we sent to the client (outbound).
+            m_sent.inc_by(cnt as u64);
+            m_sent2.inc_by(cnt as u64);
+            usage.record_egress(cnt as u64);
+        },
+    );
+
+    let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
+    let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
+    let mut compute = MeasuredStream::new(
+        compute,
+        |_| {},
+        |cnt| {
+            // Number of bytes the client sent to the compute node (inbound).
+            m_recv.inc_by(cnt as u64);
+            m_recv2.inc_by(cnt as u64);
+        },
+    );
+
+    // Starting from here we only proxy the client's traffic.
+    info!("performing the proxy pass...");
+    let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
+
+    Ok(())
+}
diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs
index dfef4ccdfa..a2eb7e62cc 100644
--- a/proxy/src/serverless.rs
+++ b/proxy/src/serverless.rs
@@ -230,7 +230,7 @@ async fn request_handler(
                     config,
                     &mut ctx,
                     websocket,
-                    &cancel_map,
+                    cancel_map,
                     host,
                     endpoint_rate_limiter,
                 )
diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs
index 1e2ddaa2ff..27c2134221 100644
--- a/proxy/src/serverless/sql_over_http.rs
+++ b/proxy/src/serverless/sql_over_http.rs
@@ -189,7 +189,7 @@ fn get_conn_info(
     }
 
     let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
-    ctx.set_endpoint_id(Some(endpoint.clone()));
+    ctx.set_endpoint_id(endpoint.clone());
 
     let pairs = connection_url.query_pairs();
 
diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs
index a6529c920a..f68b35010a 100644
--- a/proxy/src/serverless/websocket.rs
+++ b/proxy/src/serverless/websocket.rs
@@ -133,7 +133,7 @@ pub async fn serve_websocket(
     config: &'static ProxyConfig,
     ctx: &mut RequestMonitoring,
     websocket: HyperWebsocket,
-    cancel_map: &CancelMap,
+    cancel_map: Arc<CancelMap>,
     hostname: Option,
     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
 ) -> anyhow::Result<()> {
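A closing note on the `&CancelMap` to `Arc<CancelMap>` signature changes throughout this patch: each per-connection task now owns a reference-counted handle, so `Session::drop` can unregister its key without borrowing from the accept loop. The general shape, with `SharedMap` as a stand-in:

```rust
use std::sync::Arc;

#[derive(Default)]
struct SharedMap; // stand-in for CancelMap

#[tokio::main]
async fn main() {
    let shared = Arc::new(SharedMap::default());

    for conn_id in 0..3 {
        // Each spawned task owns its own Arc handle, so nothing borrows
        // from the accept loop and the task can live arbitrarily long.
        let shared = Arc::clone(&shared);
        tokio::spawn(async move {
            let _session_data = (&shared, conn_id);
        });
    }
}
```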