mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-09 14:32:57 +00:00
feat(proxy): require TLS to compute if prompted by cplane (#10717)
https://github.com/neondatabase/cloud/issues/23008 For TLS between proxy and compute, we are using an internally provisioned CA to sign the compute certificates. This change ensures that proxy will load them from a supplied env var pointing to the correct file - this file and env var will be configured later, using a kubernetes secret. Control plane responds with a `server_name` field if and only if the compute uses TLS. This server name is the name we use to validate the certificate. Control plane still sends us the IP to connect to as well (to support overlay IP). To support this change, I'd had to split `host` and `host_addr` into separate fields. Using `host_addr` and bypassing `lookup_addr` if possible (which is what happens in production). `host` then is only used for the TLS connection. There's no blocker to merging this. The code paths will not be triggered until the new control plane is deployed and the `enableTLS` compute flag is enabled on a project.
This commit is contained in:
@@ -34,8 +34,13 @@ where
|
||||
.make_tls_connect(hostname)
|
||||
.map_err(|e| Error::tls(e.into()))?;
|
||||
|
||||
let socket =
|
||||
connect_socket::connect_socket(&config.host, config.port, config.connect_timeout).await?;
|
||||
let socket = connect_socket::connect_socket(
|
||||
config.host_addr,
|
||||
&config.host,
|
||||
config.port,
|
||||
config.connect_timeout,
|
||||
)
|
||||
.await?;
|
||||
|
||||
cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
@@ -137,6 +138,7 @@ impl InnerClient {
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct SocketConfig {
|
||||
pub host_addr: Option<IpAddr>,
|
||||
pub host: Host,
|
||||
pub port: u16,
|
||||
pub connect_timeout: Option<Duration>,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
//! Connection configuration.
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, str};
|
||||
|
||||
@@ -65,6 +66,7 @@ pub enum AuthKeys {
|
||||
/// Connection configuration.
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct Config {
|
||||
pub(crate) host_addr: Option<IpAddr>,
|
||||
pub(crate) host: Host,
|
||||
pub(crate) port: u16,
|
||||
|
||||
@@ -83,6 +85,7 @@ impl Config {
|
||||
/// Creates a new configuration.
|
||||
pub fn new(host: String, port: u16) -> Config {
|
||||
Config {
|
||||
host_addr: None,
|
||||
host: Host::Tcp(host),
|
||||
port,
|
||||
password: None,
|
||||
@@ -163,6 +166,15 @@ impl Config {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_host_addr(&mut self, addr: IpAddr) -> &mut Config {
|
||||
self.host_addr = Some(addr);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn get_host_addr(&self) -> Option<IpAddr> {
|
||||
self.host_addr
|
||||
}
|
||||
|
||||
/// Sets the SSL configuration.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::net::IpAddr;
|
||||
|
||||
use postgres_protocol2::message::backend::Message;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -25,13 +27,14 @@ where
|
||||
.make_tls_connect(hostname)
|
||||
.map_err(|e| Error::tls(e.into()))?;
|
||||
|
||||
match connect_once(&config.host, config.port, tls, config).await {
|
||||
match connect_once(config.host_addr, &config.host, config.port, tls, config).await {
|
||||
Ok((client, connection)) => Ok((client, connection)),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_once<T>(
|
||||
host_addr: Option<IpAddr>,
|
||||
host: &Host,
|
||||
port: u16,
|
||||
tls: T,
|
||||
@@ -40,7 +43,7 @@ async fn connect_once<T>(
|
||||
where
|
||||
T: TlsConnect<TcpStream>,
|
||||
{
|
||||
let socket = connect_socket(host, port, config.connect_timeout).await?;
|
||||
let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
|
||||
let RawConnection {
|
||||
stream,
|
||||
parameters,
|
||||
@@ -50,6 +53,7 @@ where
|
||||
} = connect_raw(socket, tls, config).await?;
|
||||
|
||||
let socket_config = SocketConfig {
|
||||
host_addr,
|
||||
host: host.clone(),
|
||||
port,
|
||||
connect_timeout: config.connect_timeout,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::net::{self, TcpStream};
|
||||
@@ -9,15 +10,20 @@ use crate::Error;
|
||||
use crate::config::Host;
|
||||
|
||||
pub(crate) async fn connect_socket(
|
||||
host_addr: Option<IpAddr>,
|
||||
host: &Host,
|
||||
port: u16,
|
||||
connect_timeout: Option<Duration>,
|
||||
) -> Result<TcpStream, Error> {
|
||||
match host {
|
||||
Host::Tcp(host) => {
|
||||
let addrs = net::lookup_host((&**host, port))
|
||||
.await
|
||||
.map_err(Error::connect)?;
|
||||
let addrs = match host_addr {
|
||||
Some(addr) => vec![SocketAddr::new(addr, port)],
|
||||
None => net::lookup_host((&**host, port))
|
||||
.await
|
||||
.map_err(Error::connect)?
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let mut last_err = None;
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ impl LocalBackend {
|
||||
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),
|
||||
project_id: ProjectIdTag::get_interner().get_or_intern("local"),
|
||||
branch_id: BranchIdTag::get_interner().get_or_intern("local"),
|
||||
compute_id: "local".into(),
|
||||
cold_start_info: ColdStartInfo::WarmCached,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::fmt::Debug;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
@@ -10,7 +11,7 @@ use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
@@ -180,21 +181,19 @@ impl ConnCfg {
|
||||
use postgres_client::config::Host;
|
||||
|
||||
// wrap TcpStream::connect with timeout
|
||||
let connect_with_timeout = |host, port| {
|
||||
tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
|
||||
move |res| match res {
|
||||
Ok(tcpstream_connect_res) => tcpstream_connect_res,
|
||||
Err(_) => Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
format!("exceeded connection timeout {timeout:?}"),
|
||||
)),
|
||||
},
|
||||
)
|
||||
let connect_with_timeout = |addrs| {
|
||||
tokio::time::timeout(timeout, TcpStream::connect(addrs)).map(move |res| match res {
|
||||
Ok(tcpstream_connect_res) => tcpstream_connect_res,
|
||||
Err(_) => Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
format!("exceeded connection timeout {timeout:?}"),
|
||||
)),
|
||||
})
|
||||
};
|
||||
|
||||
let connect_once = |host, port| {
|
||||
debug!("trying to connect to compute node at {host}:{port}");
|
||||
connect_with_timeout(host, port).and_then(|stream| async {
|
||||
let connect_once = |addrs| {
|
||||
debug!("trying to connect to compute node at {addrs:?}");
|
||||
connect_with_timeout(addrs).and_then(|stream| async {
|
||||
let socket_addr = stream.peer_addr()?;
|
||||
let socket = socket2::SockRef::from(&stream);
|
||||
// Disable Nagle's algorithm to not introduce latency between
|
||||
@@ -216,7 +215,12 @@ impl ConnCfg {
|
||||
Host::Tcp(host) => host.as_str(),
|
||||
};
|
||||
|
||||
match connect_once(host, port).await {
|
||||
let addrs = match self.0.get_host_addr() {
|
||||
Some(addr) => vec![SocketAddr::new(addr, port)],
|
||||
None => lookup_host((host, port)).await?.collect(),
|
||||
};
|
||||
|
||||
match connect_once(&*addrs).await {
|
||||
Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
|
||||
Err(err) => {
|
||||
warn!("couldn't connect to compute node at {host}:{port}: {err}");
|
||||
@@ -277,6 +281,7 @@ impl ConnCfg {
|
||||
} = connection;
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(process_id));
|
||||
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
|
||||
let stream = stream.into_inner();
|
||||
|
||||
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
//! Production console backend.
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -274,11 +276,27 @@ impl NeonControlPlaneClient {
|
||||
Some(x) => x,
|
||||
};
|
||||
|
||||
let host_addr = IpAddr::from_str(host).ok();
|
||||
|
||||
let ssl_mode = match &body.server_name {
|
||||
Some(_) => SslMode::Require,
|
||||
None => SslMode::Disable,
|
||||
};
|
||||
let host_name = match body.server_name {
|
||||
Some(host) => host,
|
||||
None => host.to_owned(),
|
||||
};
|
||||
|
||||
// Don't set anything but host and port! This config will be cached.
|
||||
// We'll set username and such later using the startup message.
|
||||
// TODO: add more type safety (in progress).
|
||||
let mut config = compute::ConnCfg::new(host.to_owned(), port);
|
||||
config.ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
|
||||
let mut config = compute::ConnCfg::new(host_name, port);
|
||||
|
||||
if let Some(addr) = host_addr {
|
||||
config.set_host_addr(addr);
|
||||
}
|
||||
|
||||
config.ssl_mode(ssl_mode);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
//! Mock console backend which relies on a user-provided postgres instance.
|
||||
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -167,10 +168,22 @@ impl MockControlPlane {
|
||||
}
|
||||
|
||||
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||
let mut config = compute::ConnCfg::new(
|
||||
self.endpoint.host_str().unwrap_or("localhost").to_owned(),
|
||||
self.endpoint.port().unwrap_or(5432),
|
||||
);
|
||||
let port = self.endpoint.port().unwrap_or(5432);
|
||||
let mut config = match self.endpoint.host_str() {
|
||||
None => {
|
||||
let mut config = compute::ConnCfg::new("localhost".to_string(), port);
|
||||
config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST));
|
||||
config
|
||||
}
|
||||
Some(host) => {
|
||||
let mut config = compute::ConnCfg::new(host.to_string(), port);
|
||||
if let Ok(addr) = IpAddr::from_str(host) {
|
||||
config.set_host_addr(addr);
|
||||
}
|
||||
config
|
||||
}
|
||||
};
|
||||
|
||||
config.ssl_mode(postgres_client::config::SslMode::Disable);
|
||||
|
||||
let node = NodeInfo {
|
||||
@@ -179,6 +192,7 @@ impl MockControlPlane {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
branch_id: (&BranchId::from("branch")).into(),
|
||||
compute_id: "compute".into(),
|
||||
cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::fmt::{self, Display};
|
||||
|
||||
use measured::FixedCardinalityLabel;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::SmolStr;
|
||||
|
||||
use crate::auth::IpPattern;
|
||||
use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
|
||||
@@ -239,6 +240,7 @@ pub(crate) struct GetEndpointAccessControl {
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct WakeCompute {
|
||||
pub(crate) address: Box<str>,
|
||||
pub(crate) server_name: Option<String>,
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
@@ -312,6 +314,9 @@ pub(crate) struct MetricsAuxInfo {
|
||||
pub(crate) endpoint_id: EndpointIdInt,
|
||||
pub(crate) project_id: ProjectIdInt,
|
||||
pub(crate) branch_id: BranchIdInt,
|
||||
// note: we don't use interned strings for compute IDs.
|
||||
// they churn too quickly and we have no way to clean up interned strings.
|
||||
pub(crate) compute_id: SmolStr,
|
||||
#[serde(default)]
|
||||
pub(crate) cold_start_info: ColdStartInfo,
|
||||
}
|
||||
@@ -378,6 +383,7 @@ mod tests {
|
||||
"endpoint_id": "endpoint",
|
||||
"project_id": "project",
|
||||
"branch_id": "branch",
|
||||
"compute_id": "compute",
|
||||
"cold_start_info": "unknown",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -81,7 +81,10 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
type ConnectError = compute::ConnectionError;
|
||||
type Error = compute::ConnectionError;
|
||||
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
#[tracing::instrument(skip_all, fields(
|
||||
pid = tracing::field::Empty,
|
||||
compute_id = tracing::field::Empty
|
||||
))]
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
|
||||
@@ -555,6 +555,7 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
branch_id: (&BranchId::from("branch")).into(),
|
||||
compute_id: "compute".into(),
|
||||
cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::io;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -6,11 +7,15 @@ use async_trait::async_trait;
|
||||
use ed25519_dalek::SigningKey;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use jose_jwk::jose_b64;
|
||||
use postgres_client::config::SslMode;
|
||||
use rand::rngs::OsRng;
|
||||
use rustls::pki_types::{DnsName, ServerName};
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tracing::field::display;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::AsyncRW;
|
||||
use super::conn_pool::poll_client;
|
||||
use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
|
||||
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
|
||||
@@ -190,7 +195,11 @@ impl PoolingBackend {
|
||||
// Wake up the destination if needed. Code here is a bit involved because
|
||||
// we reuse the code from the usual proxy and we need to prepare few structures
|
||||
// that this code expects.
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
#[tracing::instrument(skip_all, fields(
|
||||
pid = tracing::field::Empty,
|
||||
compute_id = tracing::field::Empty,
|
||||
conn_id = tracing::field::Empty,
|
||||
))]
|
||||
pub(crate) async fn connect_to_compute(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
@@ -229,7 +238,10 @@ impl PoolingBackend {
|
||||
}
|
||||
|
||||
// Wake up the destination if needed
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
#[tracing::instrument(skip_all, fields(
|
||||
compute_id = tracing::field::Empty,
|
||||
conn_id = tracing::field::Empty,
|
||||
))]
|
||||
pub(crate) async fn connect_to_local_proxy(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
@@ -276,7 +288,10 @@ impl PoolingBackend {
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if called with a non-local_proxy backend.
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
#[tracing::instrument(skip_all, fields(
|
||||
pid = tracing::field::Empty,
|
||||
conn_id = tracing::field::Empty,
|
||||
))]
|
||||
pub(crate) async fn connect_to_local_postgres(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
@@ -552,6 +567,10 @@ impl ConnectMechanism for TokioMechanism {
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
|
||||
tracing::Span::current().record(
|
||||
"compute_id",
|
||||
tracing::field::display(&node_info.aux.compute_id),
|
||||
);
|
||||
Ok(poll_client(
|
||||
self.pool.clone(),
|
||||
ctx,
|
||||
@@ -587,16 +606,28 @@ impl ConnectMechanism for HyperMechanism {
|
||||
node_info: &CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host_addr = node_info.config.get_host_addr();
|
||||
let host = node_info.config.get_host();
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
|
||||
let tls = if node_info.config.get_ssl_mode() == SslMode::Disable {
|
||||
None
|
||||
} else {
|
||||
Some(&config.tls)
|
||||
};
|
||||
|
||||
let port = node_info.config.get_port();
|
||||
let res = connect_http2(&host, port, config.timeout).await;
|
||||
let res = connect_http2(host_addr, &host, port, config.timeout, tls).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
tracing::Span::current().record(
|
||||
"compute_id",
|
||||
tracing::field::display(&node_info.aux.compute_id),
|
||||
);
|
||||
|
||||
Ok(poll_http2_client(
|
||||
self.pool.clone(),
|
||||
ctx,
|
||||
@@ -612,18 +643,22 @@ impl ConnectMechanism for HyperMechanism {
|
||||
}
|
||||
|
||||
async fn connect_http2(
|
||||
host_addr: Option<IpAddr>,
|
||||
host: &str,
|
||||
port: u16,
|
||||
timeout: Duration,
|
||||
tls: Option<&Arc<rustls::ClientConfig>>,
|
||||
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
|
||||
// assumption: host is an ip address so this should not actually perform any requests.
|
||||
// todo: add that assumption as a guarantee in the control-plane API.
|
||||
let mut addrs = lookup_host((host, port))
|
||||
.await
|
||||
.map_err(LocalProxyConnError::Io)?;
|
||||
|
||||
let addrs = match host_addr {
|
||||
Some(addr) => vec![SocketAddr::new(addr, port)],
|
||||
None => lookup_host((host, port))
|
||||
.await
|
||||
.map_err(LocalProxyConnError::Io)?
|
||||
.collect(),
|
||||
};
|
||||
let mut last_err = None;
|
||||
|
||||
let mut addrs = addrs.into_iter();
|
||||
let stream = loop {
|
||||
let Some(addr) = addrs.next() else {
|
||||
return Err(last_err.unwrap_or_else(|| {
|
||||
@@ -651,6 +686,20 @@ async fn connect_http2(
|
||||
}
|
||||
};
|
||||
|
||||
let stream = if let Some(tls) = tls {
|
||||
let host = DnsName::try_from(host)
|
||||
.map_err(io::Error::other)
|
||||
.map_err(LocalProxyConnError::Io)?
|
||||
.to_owned();
|
||||
let stream = TlsConnector::from(tls.clone())
|
||||
.connect(ServerName::DnsName(host), stream)
|
||||
.await
|
||||
.map_err(LocalProxyConnError::Io)?;
|
||||
Box::pin(stream) as AsyncRW
|
||||
} else {
|
||||
Box::pin(stream) as AsyncRW
|
||||
};
|
||||
|
||||
let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.timer(TokioTimer::new())
|
||||
.keep_alive_interval(Duration::from_secs(20))
|
||||
|
||||
@@ -221,6 +221,7 @@ mod tests {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
branch_id: (&BranchId::from("branch")).into(),
|
||||
compute_id: "compute".into(),
|
||||
cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
|
||||
},
|
||||
conn_id: uuid::Uuid::new_v4(),
|
||||
|
||||
@@ -6,9 +6,9 @@ use hyper::client::conn::http2;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use parking_lot::RwLock;
|
||||
use smol_str::ToSmolStr;
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::{Instrument, debug, error, info, info_span};
|
||||
|
||||
use super::AsyncRW;
|
||||
use super::backend::HttpConnError;
|
||||
use super::conn_pool_lib::{
|
||||
ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, ConnPoolEntry,
|
||||
@@ -22,8 +22,7 @@ use crate::types::EndpointCacheKey;
|
||||
use crate::usage_metrics::{Ids, MetricCounter, TrafficDirection, USAGE_METRICS};
|
||||
|
||||
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
|
||||
pub(crate) type Connect =
|
||||
http2::Connection<TokioIo<TcpStream>, hyper::body::Incoming, TokioExecutor>;
|
||||
pub(crate) type Connect = http2::Connection<TokioIo<AsyncRW>, hyper::body::Incoming, TokioExecutor>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ClientDataHttp();
|
||||
|
||||
@@ -1,17 +1,49 @@
|
||||
use std::env;
|
||||
use std::io::Cursor;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use anyhow::{Context, bail};
|
||||
use rustls::crypto::ring;
|
||||
|
||||
pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
|
||||
/// We use an internal certificate authority when establishing a TLS connection with compute.
|
||||
fn load_internal_certs(store: &mut rustls::RootCertStore) -> anyhow::Result<()> {
|
||||
let Some(ca_file) = env::var_os("NEON_INTERNAL_CA_FILE") else {
|
||||
return Ok(());
|
||||
};
|
||||
let ca_file = PathBuf::from(ca_file);
|
||||
|
||||
let ca = std::fs::read(&ca_file)
|
||||
.with_context(|| format!("could not read CA from {}", ca_file.display()))?;
|
||||
|
||||
for cert in rustls_pemfile::certs(&mut Cursor::new(&*ca)) {
|
||||
store
|
||||
.add(cert.context("could not parse internal CA certificate")?)
|
||||
.context("could not parse internal CA certificate")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// For console redirect proxy, we need to establish a connection to compute via pg-sni-router.
|
||||
/// pg-sni-router needs TLS and uses a Let's Encrypt signed certificate, so we
|
||||
/// load certificates from our native store.
|
||||
fn load_native_certs(store: &mut rustls::RootCertStore) -> anyhow::Result<()> {
|
||||
let der_certs = rustls_native_certs::load_native_certs();
|
||||
|
||||
if !der_certs.errors.is_empty() {
|
||||
bail!("could not parse certificates: {:?}", der_certs.errors);
|
||||
}
|
||||
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add_parsable_certificates(der_certs.certs);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_compute_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
load_native_certs(&mut store)?;
|
||||
load_internal_certs(&mut store)?;
|
||||
Ok(Arc::new(store))
|
||||
}
|
||||
|
||||
@@ -22,7 +54,7 @@ pub fn compute_client_config_with_root_certs() -> anyhow::Result<rustls::ClientC
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.with_root_certificates(load_certs()?)
|
||||
.with_root_certificates(load_compute_certs()?)
|
||||
.with_no_client_auth(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -3601,6 +3601,7 @@ class NeonProxy(PgProtocol):
|
||||
"project_id": "test_project_id",
|
||||
"endpoint_id": "test_endpoint_id",
|
||||
"branch_id": "test_branch_id",
|
||||
"compute_id": "test_compute_id",
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -3826,6 +3827,7 @@ def static_auth_broker(
|
||||
{
|
||||
"address": local_proxy_addr,
|
||||
"aux": {
|
||||
"compute_id": "compute-foo-bar-1234-5678",
|
||||
"endpoint_id": "ep-foo-bar-1234",
|
||||
"branch_id": "br-foo-bar",
|
||||
"project_id": "foo-bar",
|
||||
|
||||
Reference in New Issue
Block a user