#[cfg(test)] mod tests; pub mod connect_compute; mod copy_bidirectional; pub mod handshake; pub mod passthrough; pub mod retry; pub mod wake_compute; pub use copy_bidirectional::copy_bidirectional_client_compute; use crate::{ auth, cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal}, compute, config::{ProxyConfig, TlsConfig}, context::RequestMonitoring, error::ReportableError, metrics::{Metrics, NumClientConnectionsGuard}, protocol2::WithClientIp, proxy::handshake::{handshake, HandshakeData}, stream::{PqStream, Stream}, EndpointCacheKey, }; use futures::TryFutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; use pq_proto::{BeMessage as Be, StartupMessageParams}; use regex::Regex; use smol_str::{format_smolstr, SmolStr}; use std::sync::Arc; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; use tracing::{error, info, Instrument}; use self::{ connect_compute::{connect_to_compute, TcpMechanism}, passthrough::ProxyPassthrough, }; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; pub async fn run_until_cancelled( f: F, cancellation_token: &CancellationToken, ) -> Option { match futures::future::select( std::pin::pin!(f), std::pin::pin!(cancellation_token.cancelled()), ) .await { futures::future::Either::Left((f, _)) => Some(f), futures::future::Either::Right(((), _)) => None, } } pub async fn task_main( config: &'static ProxyConfig, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("proxy has shut down"); } // When set for the server socket, the keepalive setting // will be inherited by all accepted client sockets. socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await { let (socket, peer_addr) = accept_result?; let conn_gauge = Metrics::get() .proxy .client_connections .guard(crate::metrics::Protocol::Tcp); let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection"); connections.spawn(async move { let mut socket = WithClientIp::new(socket); let mut peer_addr = peer_addr.ip(); match socket.wait_for_addr().await { Ok(Some(addr)) => peer_addr = addr.ip(), Err(e) => { error!("per-client task finished with an error: {e:#}"); return; } Ok(None) if config.require_client_ip => { error!("missing required client IP"); return; } Ok(None) => {} } match socket.inner.set_nodelay(true) { Ok(()) => {}, Err(e) => { error!("per-client task finished with an error: failed to set socket option: {e:#}"); return; }, }; let mut ctx = RequestMonitoring::new( session_id, peer_addr, crate::metrics::Protocol::Tcp, &config.region, ); let span = ctx.span.clone(); let res = handle_client( config, &mut ctx, cancellation_handler, socket, ClientMode::Tcp, conn_gauge, ) .instrument(span.clone()) .await; match res { Err(e) => { // todo: log and push to ctx the error kind ctx.set_error_kind(e.get_error_kind()); ctx.log(); error!(parent: &span, "per-client task finished with an error: {e:#}"); } Ok(None) => { ctx.set_success(); ctx.log(); } Ok(Some(p)) => { ctx.set_success(); ctx.log(); match p.proxy_pass().instrument(span.clone()).await { Ok(()) => {} Err(e) => { error!(parent: &span, "per-client task finished with an error: {e:#}"); } } } } }); } connections.close(); drop(listener); // Drain connections connections.wait().await; Ok(()) } pub enum ClientMode { Tcp, Websockets { hostname: Option }, } /// Abstracts the logic of handling TCP vs WS clients impl ClientMode { pub fn allow_cleartext(&self) -> bool { match self { ClientMode::Tcp => false, ClientMode::Websockets { .. } => true, } } pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { match self { ClientMode::Tcp => config.allow_self_signed_compute, ClientMode::Websockets { .. } => false, } } fn hostname<'a, S>(&'a self, s: &'a Stream) -> Option<&'a str> { match self { ClientMode::Tcp => s.sni_hostname(), ClientMode::Websockets { hostname } => hostname.as_deref(), } } fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> { match self { ClientMode::Tcp => tls, // TLS is None here if using websockets, because the connection is already encrypted. ClientMode::Websockets { .. } => None, } } } #[derive(Debug, Error)] // almost all errors should be reported to the user, but there's a few cases where we cannot // 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation, // we cannot be sure the client even understands our error message // 3. PrepareClient: The client disconnected, so we can't tell them anyway... pub enum ClientRequestError { #[error("{0}")] Cancellation(#[from] cancellation::CancelError), #[error("{0}")] Handshake(#[from] handshake::HandshakeError), #[error("{0}")] HandshakeTimeout(#[from] tokio::time::error::Elapsed), #[error("{0}")] PrepareClient(#[from] std::io::Error), #[error("{0}")] ReportedError(#[from] crate::stream::ReportedError), } impl ReportableError for ClientRequestError { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { ClientRequestError::Cancellation(e) => e.get_error_kind(), ClientRequestError::Handshake(e) => e.get_error_kind(), ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit, ClientRequestError::ReportedError(e) => e.get_error_kind(), ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect, } } } pub async fn handle_client( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, cancellation_handler: Arc, stream: S, mode: ClientMode, conn_gauge: NumClientConnectionsGuard<'static>, ) -> Result>, ClientRequestError> { info!( protocol = %ctx.protocol, "handling interactive connection from client" ); let metrics = &Metrics::get().proxy; let proto = ctx.protocol; // let _client_gauge = metrics.client_connections.guard(proto); let _request_gauge = metrics.connection_requests.guard(proto); let tls = config.tls_config.as_ref(); let record_handshake_error = !ctx.has_private_peer_addr(); let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Client); let do_handshake = handshake(stream, mode.handshake_tls(tls), record_handshake_error); let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { return Ok(cancellation_handler .cancel_session(cancel_key_data, ctx.session_id) .await .map(|()| None)?) } }; drop(pause); let hostname = mode.hostname(stream.get_ref()); let common_names = tls.map(|tls| &tls.common_names); // Extract credentials which we're going to use for auth. let result = config .auth_backend .as_ref() .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)) .transpose(); let user_info = match result { Ok(user_info) => user_info, Err(e) => stream.throw_error(e).await?, }; let user = user_info.get_user().to_owned(); let user_info = match user_info .authenticate( ctx, &mut stream, mode.allow_cleartext(), &config.authentication_config, ) .await { Ok(auth_result) => auth_result, Err(e) => { let db = params.get("database"); let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); return stream.throw_error(e).instrument(params_span).await?; } }; let mut node = connect_to_compute( ctx, &TcpMechanism { params: ¶ms }, &user_info, mode.allow_self_signed_compute(config), ) .or_else(|e| stream.throw_error(e)) .await?; let session = cancellation_handler.get_session(); prepare_client_connection(&node, &session, &mut stream).await?; // Before proxy passing, forward to compute whatever data is left in the // PqStream input buffer. Normally there is none, but our serverless npm // driver in pipeline mode sends startup, password and first query // immediately after opening the connection. let (stream, read_buf) = stream.into_inner(); node.stream.write_all(&read_buf).await?; Ok(Some(ProxyPassthrough { client: stream, aux: node.aux.clone(), compute: node, req: _request_gauge, conn: conn_gauge, cancel: session, })) } /// Finish client connection initialization: confirm auth success, send params, etc. #[tracing::instrument(skip_all)] async fn prepare_client_connection

( node: &compute::PostgresConnection, session: &cancellation::Session

, stream: &mut PqStream, ) -> Result<(), std::io::Error> { // Register compute's query cancellation token and produce a new, unique one. // The new token (cancel_key_data) will be sent to the client. let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone()); // Forward all postgres connection params to the client. // Right now the implementation is very hacky and inefficent (ideally, // we don't need an intermediate hashmap), but at least it should be correct. for (name, value) in &node.params { // TODO: Theoretically, this could result in a big pile of params... stream.write_message_noflush(&Be::ParameterStatus { name: name.as_bytes(), value: value.as_bytes(), })?; } stream .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? .write_message(&Be::ReadyForQuery) .await?; Ok(()) } #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct NeonOptions(Vec<(SmolStr, SmolStr)>); impl NeonOptions { pub fn parse_params(params: &StartupMessageParams) -> Self { params .options_raw() .map(Self::parse_from_iter) .unwrap_or_default() } pub fn parse_options_raw(options: &str) -> Self { Self::parse_from_iter(StartupMessageParams::parse_options_raw(options)) } pub fn is_ephemeral(&self) -> bool { // Currently, neon endpoint options are all reserved for ephemeral endpoints. !self.0.is_empty() } fn parse_from_iter<'a>(options: impl Iterator) -> Self { let mut options = options .filter_map(neon_option) .map(|(k, v)| (k.into(), v.into())) .collect_vec(); options.sort(); Self(options) } pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey { // prefix + format!(" {k}:{v}") // kinda jank because SmolStr is immutable std::iter::once(prefix) .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v])) .collect::() .into() } /// DeepObject format /// `paramName[prop1]=value1¶mName[prop2]=value2&...` pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> { self.0 .iter() .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone())) .collect() } } pub fn neon_option(bytes: &str) -> Option<(&str, &str)> { static RE: OnceCell = OnceCell::new(); let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap()); let cap = re.captures(bytes)?; let (_, [k, v]) = cap.extract(); Some((k, v)) }