diff --git a/libs/proxy/tokio-postgres2/src/cancel_query.rs b/libs/proxy/tokio-postgres2/src/cancel_query.rs index b65fb571e6..0bdad0b554 100644 --- a/libs/proxy/tokio-postgres2/src/cancel_query.rs +++ b/libs/proxy/tokio-postgres2/src/cancel_query.rs @@ -34,8 +34,13 @@ where .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - let socket = - connect_socket::connect_socket(&config.host, config.port, config.connect_timeout).await?; + let socket = connect_socket::connect_socket( + config.host_addr, + &config.host, + config.port, + config.connect_timeout, + ) + .await?; cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await } diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 39b1db75da..c70cb598de 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::fmt; +use std::net::IpAddr; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; @@ -137,6 +138,7 @@ impl InnerClient { #[derive(Clone, Serialize, Deserialize)] pub struct SocketConfig { + pub host_addr: Option, pub host: Host, pub port: u16, pub connect_timeout: Option, diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index 4c25491b67..978d348741 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -1,5 +1,6 @@ //! Connection configuration. +use std::net::IpAddr; use std::time::Duration; use std::{fmt, str}; @@ -65,6 +66,7 @@ pub enum AuthKeys { /// Connection configuration. #[derive(Clone, PartialEq, Eq)] pub struct Config { + pub(crate) host_addr: Option, pub(crate) host: Host, pub(crate) port: u16, @@ -83,6 +85,7 @@ impl Config { /// Creates a new configuration. pub fn new(host: String, port: u16) -> Config { Config { + host_addr: None, host: Host::Tcp(host), port, password: None, @@ -163,6 +166,15 @@ impl Config { self } + pub fn set_host_addr(&mut self, addr: IpAddr) -> &mut Config { + self.host_addr = Some(addr); + self + } + + pub fn get_host_addr(&self) -> Option { + self.host_addr + } + /// Sets the SSL configuration. /// /// Defaults to `prefer`. diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index d2bd0dfbcd..7c3a358bba 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -1,3 +1,5 @@ +use std::net::IpAddr; + use postgres_protocol2::message::backend::Message; use tokio::net::TcpStream; use tokio::sync::mpsc; @@ -25,13 +27,14 @@ where .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; - match connect_once(&config.host, config.port, tls, config).await { + match connect_once(config.host_addr, &config.host, config.port, tls, config).await { Ok((client, connection)) => Ok((client, connection)), Err(e) => Err(e), } } async fn connect_once( + host_addr: Option, host: &Host, port: u16, tls: T, @@ -40,7 +43,7 @@ async fn connect_once( where T: TlsConnect, { - let socket = connect_socket(host, port, config.connect_timeout).await?; + let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?; let RawConnection { stream, parameters, @@ -50,6 +53,7 @@ where } = connect_raw(socket, tls, config).await?; let socket_config = SocketConfig { + host_addr, host: host.clone(), port, connect_timeout: config.connect_timeout, diff --git a/libs/proxy/tokio-postgres2/src/connect_socket.rs b/libs/proxy/tokio-postgres2/src/connect_socket.rs index 15411f7ef3..8c7d300451 100644 --- a/libs/proxy/tokio-postgres2/src/connect_socket.rs +++ b/libs/proxy/tokio-postgres2/src/connect_socket.rs @@ -1,5 +1,6 @@ use std::future::Future; use std::io; +use std::net::{IpAddr, SocketAddr}; use std::time::Duration; use tokio::net::{self, TcpStream}; @@ -9,15 +10,20 @@ use crate::Error; use crate::config::Host; pub(crate) async fn connect_socket( + host_addr: Option, host: &Host, port: u16, connect_timeout: Option, ) -> Result { match host { Host::Tcp(host) => { - let addrs = net::lookup_host((&**host, port)) - .await - .map_err(Error::connect)?; + let addrs = match host_addr { + Some(addr) => vec![SocketAddr::new(addr, port)], + None => net::lookup_host((&**host, port)) + .await + .map_err(Error::connect)? + .collect(), + }; let mut last_err = None; diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index 9c3a3772cd..7a6dceb194 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -35,6 +35,7 @@ impl LocalBackend { endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"), project_id: ProjectIdTag::get_interner().get_or_intern("local"), branch_id: BranchIdTag::get_interner().get_or_intern("local"), + compute_id: "local".into(), cold_start_info: ColdStartInfo::WarmCached, }, }, diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 5447a4a4c0..2560187608 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; use std::io; use std::net::SocketAddr; use std::time::Duration; @@ -10,7 +11,7 @@ use postgres_protocol::message::backend::NoticeResponseBody; use pq_proto::StartupMessageParams; use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; -use tokio::net::TcpStream; +use tokio::net::{TcpStream, lookup_host}; use tracing::{debug, error, info, warn}; use crate::auth::backend::ComputeUserInfo; @@ -180,21 +181,19 @@ impl ConnCfg { use postgres_client::config::Host; // wrap TcpStream::connect with timeout - let connect_with_timeout = |host, port| { - tokio::time::timeout(timeout, TcpStream::connect((host, port))).map( - move |res| match res { - Ok(tcpstream_connect_res) => tcpstream_connect_res, - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - format!("exceeded connection timeout {timeout:?}"), - )), - }, - ) + let connect_with_timeout = |addrs| { + tokio::time::timeout(timeout, TcpStream::connect(addrs)).map(move |res| match res { + Ok(tcpstream_connect_res) => tcpstream_connect_res, + Err(_) => Err(io::Error::new( + io::ErrorKind::TimedOut, + format!("exceeded connection timeout {timeout:?}"), + )), + }) }; - let connect_once = |host, port| { - debug!("trying to connect to compute node at {host}:{port}"); - connect_with_timeout(host, port).and_then(|stream| async { + let connect_once = |addrs| { + debug!("trying to connect to compute node at {addrs:?}"); + connect_with_timeout(addrs).and_then(|stream| async { let socket_addr = stream.peer_addr()?; let socket = socket2::SockRef::from(&stream); // Disable Nagle's algorithm to not introduce latency between @@ -216,7 +215,12 @@ impl ConnCfg { Host::Tcp(host) => host.as_str(), }; - match connect_once(host, port).await { + let addrs = match self.0.get_host_addr() { + Some(addr) => vec![SocketAddr::new(addr, port)], + None => lookup_host((host, port)).await?.collect(), + }; + + match connect_once(&*addrs).await { Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)), Err(err) => { warn!("couldn't connect to compute node at {host}:{port}: {err}"); @@ -277,6 +281,7 @@ impl ConnCfg { } = connection; tracing::Span::current().record("pid", tracing::field::display(process_id)); + tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id)); let stream = stream.into_inner(); // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?) diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 977fcf4727..2765aaa462 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -1,5 +1,7 @@ //! Production console backend. +use std::net::IpAddr; +use std::str::FromStr; use std::sync::Arc; use std::time::Duration; @@ -274,11 +276,27 @@ impl NeonControlPlaneClient { Some(x) => x, }; + let host_addr = IpAddr::from_str(host).ok(); + + let ssl_mode = match &body.server_name { + Some(_) => SslMode::Require, + None => SslMode::Disable, + }; + let host_name = match body.server_name { + Some(host) => host, + None => host.to_owned(), + }; + // Don't set anything but host and port! This config will be cached. // We'll set username and such later using the startup message. // TODO: add more type safety (in progress). - let mut config = compute::ConnCfg::new(host.to_owned(), port); - config.ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes. + let mut config = compute::ConnCfg::new(host_name, port); + + if let Some(addr) = host_addr { + config.set_host_addr(addr); + } + + config.ssl_mode(ssl_mode); let node = NodeInfo { config, diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 7da5464aa5..ee722e839e 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -1,5 +1,6 @@ //! Mock console backend which relies on a user-provided postgres instance. +use std::net::{IpAddr, Ipv4Addr}; use std::str::FromStr; use std::sync::Arc; @@ -167,10 +168,22 @@ impl MockControlPlane { } async fn do_wake_compute(&self) -> Result { - let mut config = compute::ConnCfg::new( - self.endpoint.host_str().unwrap_or("localhost").to_owned(), - self.endpoint.port().unwrap_or(5432), - ); + let port = self.endpoint.port().unwrap_or(5432); + let mut config = match self.endpoint.host_str() { + None => { + let mut config = compute::ConnCfg::new("localhost".to_string(), port); + config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST)); + config + } + Some(host) => { + let mut config = compute::ConnCfg::new(host.to_string(), port); + if let Ok(addr) = IpAddr::from_str(host) { + config.set_host_addr(addr); + } + config + } + }; + config.ssl_mode(postgres_client::config::SslMode::Disable); let node = NodeInfo { @@ -179,6 +192,7 @@ impl MockControlPlane { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), branch_id: (&BranchId::from("branch")).into(), + compute_id: "compute".into(), cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm, }, }; diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs index 8d6b2e96f5..ec4554eab5 100644 --- a/proxy/src/control_plane/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -2,6 +2,7 @@ use std::fmt::{self, Display}; use measured::FixedCardinalityLabel; use serde::{Deserialize, Serialize}; +use smol_str::SmolStr; use crate::auth::IpPattern; use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; @@ -239,6 +240,7 @@ pub(crate) struct GetEndpointAccessControl { #[derive(Debug, Deserialize)] pub(crate) struct WakeCompute { pub(crate) address: Box, + pub(crate) server_name: Option, pub(crate) aux: MetricsAuxInfo, } @@ -312,6 +314,9 @@ pub(crate) struct MetricsAuxInfo { pub(crate) endpoint_id: EndpointIdInt, pub(crate) project_id: ProjectIdInt, pub(crate) branch_id: BranchIdInt, + // note: we don't use interned strings for compute IDs. + // they churn too quickly and we have no way to clean up interned strings. + pub(crate) compute_id: SmolStr, #[serde(default)] pub(crate) cold_start_info: ColdStartInfo, } @@ -378,6 +383,7 @@ mod tests { "endpoint_id": "endpoint", "project_id": "project", "branch_id": "branch", + "compute_id": "compute", "cold_start_info": "unknown", }) } diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index b8b39fa121..e013fbbe2e 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -81,7 +81,10 @@ impl ConnectMechanism for TcpMechanism<'_> { type ConnectError = compute::ConnectionError; type Error = compute::ConnectionError; - #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] + #[tracing::instrument(skip_all, fields( + pid = tracing::field::Empty, + compute_id = tracing::field::Empty + ))] async fn connect_once( &self, ctx: &RequestContext, diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 171f539b1e..e0b7539538 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -555,6 +555,7 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), branch_id: (&BranchId::from("branch")).into(), + compute_id: "compute".into(), cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm, }, }; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 72029102e0..b55661cec8 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,4 +1,5 @@ use std::io; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Duration; @@ -6,11 +7,15 @@ use async_trait::async_trait; use ed25519_dalek::SigningKey; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use jose_jwk::jose_b64; +use postgres_client::config::SslMode; use rand::rngs::OsRng; +use rustls::pki_types::{DnsName, ServerName}; use tokio::net::{TcpStream, lookup_host}; +use tokio_rustls::TlsConnector; use tracing::field::display; use tracing::{debug, info}; +use super::AsyncRW; use super::conn_pool::poll_client; use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool}; use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client}; @@ -190,7 +195,11 @@ impl PoolingBackend { // Wake up the destination if needed. Code here is a bit involved because // we reuse the code from the usual proxy and we need to prepare few structures // that this code expects. - #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] + #[tracing::instrument(skip_all, fields( + pid = tracing::field::Empty, + compute_id = tracing::field::Empty, + conn_id = tracing::field::Empty, + ))] pub(crate) async fn connect_to_compute( &self, ctx: &RequestContext, @@ -229,7 +238,10 @@ impl PoolingBackend { } // Wake up the destination if needed - #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] + #[tracing::instrument(skip_all, fields( + compute_id = tracing::field::Empty, + conn_id = tracing::field::Empty, + ))] pub(crate) async fn connect_to_local_proxy( &self, ctx: &RequestContext, @@ -276,7 +288,10 @@ impl PoolingBackend { /// # Panics /// /// Panics if called with a non-local_proxy backend. - #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] + #[tracing::instrument(skip_all, fields( + pid = tracing::field::Empty, + conn_id = tracing::field::Empty, + ))] pub(crate) async fn connect_to_local_postgres( &self, ctx: &RequestContext, @@ -552,6 +567,10 @@ impl ConnectMechanism for TokioMechanism { let (client, connection) = permit.release_result(res)?; tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); + tracing::Span::current().record( + "compute_id", + tracing::field::display(&node_info.aux.compute_id), + ); Ok(poll_client( self.pool.clone(), ctx, @@ -587,16 +606,28 @@ impl ConnectMechanism for HyperMechanism { node_info: &CachedNodeInfo, config: &ComputeConfig, ) -> Result { + let host_addr = node_info.config.get_host_addr(); let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); + let tls = if node_info.config.get_ssl_mode() == SslMode::Disable { + None + } else { + Some(&config.tls) + }; + let port = node_info.config.get_port(); - let res = connect_http2(&host, port, config.timeout).await; + let res = connect_http2(host_addr, &host, port, config.timeout, tls).await; drop(pause); let (client, connection) = permit.release_result(res)?; + tracing::Span::current().record( + "compute_id", + tracing::field::display(&node_info.aux.compute_id), + ); + Ok(poll_http2_client( self.pool.clone(), ctx, @@ -612,18 +643,22 @@ impl ConnectMechanism for HyperMechanism { } async fn connect_http2( + host_addr: Option, host: &str, port: u16, timeout: Duration, + tls: Option<&Arc>, ) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> { - // assumption: host is an ip address so this should not actually perform any requests. - // todo: add that assumption as a guarantee in the control-plane API. - let mut addrs = lookup_host((host, port)) - .await - .map_err(LocalProxyConnError::Io)?; - + let addrs = match host_addr { + Some(addr) => vec![SocketAddr::new(addr, port)], + None => lookup_host((host, port)) + .await + .map_err(LocalProxyConnError::Io)? + .collect(), + }; let mut last_err = None; + let mut addrs = addrs.into_iter(); let stream = loop { let Some(addr) = addrs.next() else { return Err(last_err.unwrap_or_else(|| { @@ -651,6 +686,20 @@ async fn connect_http2( } }; + let stream = if let Some(tls) = tls { + let host = DnsName::try_from(host) + .map_err(io::Error::other) + .map_err(LocalProxyConnError::Io)? + .to_owned(); + let stream = TlsConnector::from(tls.clone()) + .connect(ServerName::DnsName(host), stream) + .await + .map_err(LocalProxyConnError::Io)?; + Box::pin(stream) as AsyncRW + } else { + Box::pin(stream) as AsyncRW + }; + let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) .timer(TokioTimer::new()) .keep_alive_interval(Duration::from_secs(20)) diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 6a9089fc2a..516d474a11 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -221,6 +221,7 @@ mod tests { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), branch_id: (&BranchId::from("branch")).into(), + compute_id: "compute".into(), cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm, }, conn_id: uuid::Uuid::new_v4(), diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index 338a79b4b3..bca2d4c165 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -6,9 +6,9 @@ use hyper::client::conn::http2; use hyper_util::rt::{TokioExecutor, TokioIo}; use parking_lot::RwLock; use smol_str::ToSmolStr; -use tokio::net::TcpStream; use tracing::{Instrument, debug, error, info, info_span}; +use super::AsyncRW; use super::backend::HttpConnError; use super::conn_pool_lib::{ ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, ConnPoolEntry, @@ -22,8 +22,7 @@ use crate::types::EndpointCacheKey; use crate::usage_metrics::{Ids, MetricCounter, TrafficDirection, USAGE_METRICS}; pub(crate) type Send = http2::SendRequest; -pub(crate) type Connect = - http2::Connection, hyper::body::Incoming, TokioExecutor>; +pub(crate) type Connect = http2::Connection, hyper::body::Incoming, TokioExecutor>; #[derive(Clone)] pub(crate) struct ClientDataHttp(); diff --git a/proxy/src/tls/client_config.rs b/proxy/src/tls/client_config.rs index a2d695aae1..ce873e678e 100644 --- a/proxy/src/tls/client_config.rs +++ b/proxy/src/tls/client_config.rs @@ -1,17 +1,49 @@ +use std::env; +use std::io::Cursor; +use std::path::PathBuf; use std::sync::Arc; -use anyhow::bail; +use anyhow::{Context, bail}; use rustls::crypto::ring; -pub(crate) fn load_certs() -> anyhow::Result> { +/// We use an internal certificate authority when establishing a TLS connection with compute. +fn load_internal_certs(store: &mut rustls::RootCertStore) -> anyhow::Result<()> { + let Some(ca_file) = env::var_os("NEON_INTERNAL_CA_FILE") else { + return Ok(()); + }; + let ca_file = PathBuf::from(ca_file); + + let ca = std::fs::read(&ca_file) + .with_context(|| format!("could not read CA from {}", ca_file.display()))?; + + for cert in rustls_pemfile::certs(&mut Cursor::new(&*ca)) { + store + .add(cert.context("could not parse internal CA certificate")?) + .context("could not parse internal CA certificate")?; + } + + Ok(()) +} + +/// For console redirect proxy, we need to establish a connection to compute via pg-sni-router. +/// pg-sni-router needs TLS and uses a Let's Encrypt signed certificate, so we +/// load certificates from our native store. +fn load_native_certs(store: &mut rustls::RootCertStore) -> anyhow::Result<()> { let der_certs = rustls_native_certs::load_native_certs(); if !der_certs.errors.is_empty() { bail!("could not parse certificates: {:?}", der_certs.errors); } - let mut store = rustls::RootCertStore::empty(); store.add_parsable_certificates(der_certs.certs); + + Ok(()) +} + +fn load_compute_certs() -> anyhow::Result> { + let mut store = rustls::RootCertStore::empty(); + load_native_certs(&mut store)?; + load_internal_certs(&mut store)?; Ok(Arc::new(store)) } @@ -22,7 +54,7 @@ pub fn compute_client_config_with_root_certs() -> anyhow::Result