Compare commits

..

1 Commits

Author SHA1 Message Date
Stas Kelvich
5325f1e9d0 Set libpq path for pg_dump in fast_import 2024-12-20 13:49:21 +02:00
35 changed files with 595 additions and 668 deletions

View File

@@ -212,7 +212,7 @@ jobs:
fi
echo "CLIPPY_COMMON_ARGS=${CLIPPY_COMMON_ARGS}" >> $GITHUB_ENV
- name: Run cargo clippy (debug)
run: cargo hack --features default --ignore-unknown-features --feature-powerset clippy $CLIPPY_COMMON_ARGS
run: cargo hack --feature-powerset clippy $CLIPPY_COMMON_ARGS
- name: Check documentation generation
run: cargo doc --workspace --no-deps --document-private-items

View File

@@ -1285,7 +1285,7 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) \
#########################################################################################
#
# Compile the Neon-specific `compute_ctl`, `fast_import`, and `local_proxy` binaries
# Compile and run the Neon-specific `compute_ctl` and `fast_import` binaries
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS compute-tools
@@ -1295,7 +1295,7 @@ ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin compute_ctl --bin fast_import --bin local_proxy
RUN cd compute_tools && mold -run cargo build --locked --profile release-line-debug-size-lto
#########################################################################################
#
@@ -1338,6 +1338,20 @@ RUN set -e \
&& make -j $(nproc) dist_man_MANS= \
&& make install dist_man_MANS=
#########################################################################################
#
# Compile the Neon-specific `local_proxy` binary
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS local_proxy
ARG BUILD_TAG
ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin local_proxy
#########################################################################################
#
# Layers "postgres-exporter" and "sql-exporter"
@@ -1477,7 +1491,7 @@ COPY --from=pgbouncer /usr/local/pgbouncer/bin/pgbouncer /usr/local/bin/
COPY --chmod=0666 --chown=postgres compute/etc/pgbouncer.ini /etc/pgbouncer.ini
# local_proxy and its config
COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
COPY --from=local_proxy --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
RUN mkdir -p /etc/local_proxy && chown postgres:postgres /etc/local_proxy
# Metrics exporter binaries and configuration files

View File

@@ -172,6 +172,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
.args(["-c", &format!("max_worker_processes={nproc}")])
.args(["-c", "effective_io_concurrency=100"])
.env_clear()
.env("LD_LIBRARY_PATH", "/usr/local/lib/")
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
@@ -256,6 +257,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
.arg(&source_connection_string)
// how we run it
.env_clear()
.env("LD_LIBRARY_PATH", "/usr/local/lib/")
.kill_on_drop(true)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
@@ -289,6 +291,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
.arg(&dumpdir)
// how we run it
.env_clear()
.env("LD_LIBRARY_PATH", "/usr/local/lib/")
.kill_on_drop(true)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())

View File

@@ -1,9 +1,11 @@
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, RawConnection};
use postgres_protocol2::message::backend::Message;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
@@ -41,7 +43,7 @@ where
let RawConnection {
stream,
parameters,
delayed_notice: _,
delayed_notice,
process_id,
secret_key,
} = connect_raw(socket, tls, config).await?;
@@ -61,7 +63,13 @@ where
secret_key,
);
let connection = Connection::new(stream, parameters, receiver);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, receiver);
Ok((client, connection))
}

View File

@@ -1,10 +1,11 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, Stream};
use log::trace;
use log::{info, trace};
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::collections::{HashMap, VecDeque};
@@ -32,12 +33,8 @@ pub struct Response {
#[derive(PartialEq, Debug)]
enum State {
Active,
Closing,
}
enum WriteReady {
Terminating,
WaitingOnRead,
Closing,
}
/// A connection to a PostgreSQL database.
@@ -54,7 +51,8 @@ pub struct Connection<S, T> {
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_responses: Option<BackendMessage>,
pending_request: Option<RequestMessages>,
pending_responses: VecDeque<BackendMessage>,
responses: VecDeque<Response>,
state: State,
}
@@ -66,6 +64,7 @@ where
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S, T> {
@@ -73,7 +72,8 @@ where
stream,
parameters,
receiver,
pending_responses: None,
pending_request: None,
pending_responses,
responses: VecDeque::new(),
state: State::Active,
}
@@ -83,7 +83,7 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_responses.take() {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
@@ -93,22 +93,26 @@ where
.map(|o| o.map(|r| r.map_err(Error::io)))
}
/// Read and process messages from the connection to postgres.
/// client <- postgres
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(None);
}
loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Ready(None) => return Err(Error::closed()),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Poll::Pending;
return Ok(None);
}
};
let (mut messages, request_complete) = match message {
BackendMessage::Async(Message::NoticeResponse(_)) => {
continue;
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Ok(Some(AsyncMessage::Notice(error)));
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
@@ -116,7 +120,7 @@ where
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
};
return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
return Ok(Some(AsyncMessage::Notification(notification)));
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
@@ -135,10 +139,8 @@ where
let mut response = match self.responses.pop_front() {
Some(response) => response,
None => match messages.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(error)) => {
return Poll::Ready(Err(Error::db(error)))
}
_ => return Poll::Ready(Err(Error::unexpected_message())),
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
_ => return Err(Error::unexpected_message()),
},
};
@@ -157,19 +159,23 @@ where
}
Poll::Pending => {
self.responses.push_front(response);
self.pending_responses = Some(BackendMessage::Normal {
self.pending_responses.push_back(BackendMessage::Normal {
messages,
request_complete,
});
trace!("poll_read: waiting on sender");
return Poll::Pending;
return Ok(None);
}
}
}
}
/// Fetch the next client request and enqueue the response sender.
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
if let Some(messages) = self.pending_request.take() {
trace!("retrying pending request");
return Poll::Ready(Some(messages));
}
if self.receiver.is_closed() {
return Poll::Ready(None);
}
@@ -187,80 +193,74 @@ where
}
}
/// Process client requests and write them to the postgres connection, flushing if necessary.
/// client -> postgres
fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<WriteReady, Error>> {
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(false);
}
if Pin::new(&mut self.stream)
.poll_ready(cx)
.map_err(Error::io)?
.is_pending()
{
trace!("poll_write: waiting on socket");
// poll_ready is self-flushing.
return Poll::Pending;
return Ok(false);
}
match self.poll_request(cx) {
// send the message to postgres
Poll::Ready(Some(RequestMessages::Single(request))) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
}
// No more messages from the client, and no more responses to wait for.
// Send a terminate message to postgres
Poll::Ready(None) if self.responses.is_empty() => {
let request = match self.poll_request(cx) {
Poll::Ready(Some(request)) => request,
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating;
let mut request = BytesMut::new();
frontend::terminate(&mut request);
let request = FrontendMessage::Raw(request.freeze());
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
trace!("poll_write: sent eof, closing");
trace!("poll_write: done");
return Poll::Ready(Ok(WriteReady::Terminating));
RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
}
// No more messages from the client, but there are still some responses to wait for.
Poll::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len()
);
ready!(self.poll_flush(cx))?;
return Poll::Ready(Ok(WriteReady::WaitingOnRead));
return Ok(true);
}
// Still waiting for a message from the client.
Poll::Pending => {
trace!("poll_write: waiting on request");
ready!(self.poll_flush(cx))?;
return Poll::Pending;
return Ok(true);
}
};
match request {
RequestMessages::Single(request) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
}
}
}
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
match Pin::new(&mut self.stream)
.poll_flush(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => {
trace!("poll_flush: flushed");
Poll::Ready(Ok(()))
}
Poll::Pending => {
trace!("poll_flush: waiting on socket");
Poll::Pending
}
Poll::Ready(()) => trace!("poll_flush: flushed"),
Poll::Pending => trace!("poll_flush: waiting on socket"),
}
Ok(())
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.state != State::Closing {
return Poll::Pending;
}
match Pin::new(&mut self.stream)
.poll_close(cx)
.map_err(Error::io)?
@@ -289,30 +289,18 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
if self.state != State::Closing {
// if the state is still active, try read from and write to postgres.
let message = self.poll_read(cx)?;
let closing = self.poll_write(cx)?;
if let Poll::Ready(WriteReady::Terminating) = closing {
self.state = State::Closing;
}
if let Poll::Ready(message) = message {
return Poll::Ready(Some(Ok(message)));
}
// poll_read returned Pending.
// poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
// if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
if self.state != State::Closing {
return Poll::Pending;
}
let message = self.poll_read(cx)?;
let want_flush = self.poll_write(cx)?;
if want_flush {
self.poll_flush(cx)?;
}
match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
match message {
Some(message) => Poll::Ready(Some(Ok(message))),
None => match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
},
}
}
}
@@ -325,7 +313,11 @@ where
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
while ready!(self.poll_message(cx)?).is_some() {}
while let Some(message) = ready!(self.poll_message(cx)?) {
if let AsyncMessage::Notice(notice) = message {
info!("{}: {}", notice.severity(), notice.message());
}
}
Poll::Ready(Ok(()))
}
}

View File

@@ -6,6 +6,7 @@ pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
@@ -99,6 +100,10 @@ impl Notification {
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AsyncMessage {
/// A notice.
///
/// Notices use the same format as errors, but aren't "errors" per-se.
Notice(DbError),
/// A notification.
///
/// Connections can subscribe to notifications with the `LISTEN` command.

View File

@@ -541,7 +541,6 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
else
{
LWLockRelease(lfc_lock);
return found;
}

View File

@@ -10,6 +10,7 @@ use tracing::info;
use super::backend::ComputeCredentialKeys;
use super::{AuthError, PasswordHackPayload};
use crate::config::TlsServerEndPoint;
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
@@ -17,7 +18,6 @@ use crate::sasl;
use crate::scram::threadpool::ThreadPool;
use crate::scram::{self};
use crate::stream::{PqStream, Stream};
use crate::tls::TlsServerEndPoint;
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {

View File

@@ -13,9 +13,7 @@ use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
use proxy::auth::{self};
use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
};
use proxy::config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig};
use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use proxy::http::health_server::AppMetrics;
@@ -27,7 +25,6 @@ use proxy::rate_limiter::{
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::{self, GlobalConnPoolOptions};
use proxy::tls::client_config::compute_client_config_with_root_certs;
use proxy::types::RoleName;
use proxy::url::ApiUrl;
@@ -212,7 +209,6 @@ async fn main() -> anyhow::Result<()> {
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
&config.connect_to_compute,
Arc::new(DashMap::new()),
None,
proxy::metrics::CancellationSource::Local,
@@ -272,12 +268,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let compute_config = ComputeConfig {
retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
tls: Arc::new(compute_client_config_with_root_certs()?),
timeout: Duration::from_secs(2),
};
Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
metric_collection: None,
@@ -299,7 +289,9 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,
connect_to_compute_retry_config: RetryConfig::parse(
RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES,
)?,
})))
}

View File

@@ -10,12 +10,12 @@ use clap::Arg;
use futures::future::Either;
use futures::TryFutureExt;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestContext;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use proxy::stream::{PqStream, Stream};
use proxy::tls::TlsServerEndPoint;
use rustls::crypto::ring;
use rustls::pki_types::PrivateKeyDer;
use tokio::io::{AsyncRead, AsyncWrite};

View File

@@ -1,7 +1,6 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use anyhow::bail;
use futures::future::Either;
@@ -9,7 +8,7 @@ use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use proxy::cancellation::{CancelMap, CancellationHandler};
use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
};
use proxy::context::parquet::ParquetUploadArgs;
@@ -24,7 +23,6 @@ use proxy::redis::{elasticache, notifications};
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::tls::client_config::compute_client_config_with_root_certs;
use proxy::{auth, control_plane, http, serverless, usage_metrics};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
@@ -399,7 +397,6 @@ async fn main() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<Mutex<RedisPublisherClient>>>,
>::new(
&config.connect_to_compute,
cancel_map.clone(),
redis_publisher,
proxy::metrics::CancellationSource::FromClient,
@@ -495,7 +492,6 @@ async fn main() -> anyhow::Result<()> {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
config,
client,
cache.clone(),
cancel_map.clone(),
@@ -504,7 +500,6 @@ async fn main() -> anyhow::Result<()> {
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
config,
client,
cache.clone(),
cancel_map.clone(),
@@ -637,12 +632,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
console_redirect_confirmation_timeout: args.webauth_confirmation_timeout,
};
let compute_config = ComputeConfig {
retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?,
tls: Arc::new(compute_client_config_with_root_certs()?),
timeout: Duration::from_secs(2),
};
let config = ProxyConfig {
tls_config,
metric_collection,
@@ -653,7 +642,9 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
connect_to_compute_retry_config: config::RetryConfig::parse(
&args.connect_to_compute_retry,
)?,
};
let config = Box::leak(Box::new(config));

View File

@@ -3,9 +3,11 @@ use std::sync::Arc;
use dashmap::DashMap;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use once_cell::sync::OnceCell;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::CancelToken;
use pq_proto::CancelKeyData;
use rustls::crypto::ring;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
@@ -13,15 +15,15 @@ use tracing::{debug, info};
use uuid::Uuid;
use crate::auth::{check_peer_addr_is_in_list, IpPattern};
use crate::config::ComputeConfig;
use crate::compute::load_certs;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
use crate::postgres_rustls::MakeRustlsConnect;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
};
use crate::tls::postgres_rustls::MakeRustlsConnect;
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
@@ -33,7 +35,6 @@ type IpSubnetKey = IpNet;
///
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler<P> {
compute_config: &'static ComputeConfig,
map: CancelMap,
client: P,
/// This field used for the monitoring purposes.
@@ -182,7 +183,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
"cancelling query per user's request using key {key}, hostname {}, address: {}",
cancel_closure.hostname, cancel_closure.socket_addr
);
cancel_closure.try_cancel_query(self.compute_config).await
cancel_closure.try_cancel_query().await
}
#[cfg(test)]
@@ -197,13 +198,8 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}
impl CancellationHandler<()> {
pub fn new(
compute_config: &'static ComputeConfig,
map: CancelMap,
from: CancellationSource,
) -> Self {
pub fn new(map: CancelMap, from: CancellationSource) -> Self {
Self {
compute_config,
map,
client: (),
from,
@@ -218,14 +214,8 @@ impl CancellationHandler<()> {
}
impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub fn new(
compute_config: &'static ComputeConfig,
map: CancelMap,
client: Option<Arc<Mutex<P>>>,
from: CancellationSource,
) -> Self {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
Self {
compute_config,
map,
client,
from,
@@ -239,6 +229,8 @@ impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
}
}
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
@@ -265,14 +257,27 @@ impl CancelClosure {
}
}
/// Cancels the query running on user's compute node.
pub(crate) async fn try_cancel_query(
self,
compute_config: &ComputeConfig,
) -> Result<(), CancelError> {
pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let mut mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let root_store = TLS_ROOTS
.get_or_try_init(load_certs)
.map_err(|_e| {
CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
"TLS root store initialization failed".to_string(),
))
})?
.clone();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
&self.hostname,
@@ -324,30 +329,11 @@ impl<P> Drop for Session<P> {
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
use std::time::Duration;
use super::*;
use crate::config::RetryConfig;
use crate::tls::client_config::compute_client_config_with_certs;
fn config() -> ComputeConfig {
let retry = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
ComputeConfig {
retry,
tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
timeout: Duration::from_secs(2),
}
}
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
Box::leak(Box::new(config())),
CancelMap::default(),
CancellationSource::FromRedis,
));
@@ -363,11 +349,8 @@ mod tests {
#[tokio::test]
async fn cancel_session_noop_regression() {
let handler = CancellationHandler::<()>::new(
Box::leak(Box::new(config())),
CancelMap::default(),
CancellationSource::Local,
);
let handler =
CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
handler
.cancel_session(
CancelKeyData {

View File

@@ -1,13 +1,16 @@
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use pq_proto::StartupMessageParams;
use rustls::crypto::ring;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::TcpStream;
@@ -15,15 +18,14 @@ use tracing::{debug, error, info, warn};
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::postgres_rustls::MakeRustlsConnect;
use crate::proxy::neon_option;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::types::Host;
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -38,6 +40,9 @@ pub(crate) enum ConnectionError {
#[error("{COULD_NOT_CONNECT}: {0}")]
CouldNotConnect(#[from] io::Error),
#[error("Couldn't load native TLS certificates: {0:?}")]
TlsCertificateError(Vec<rustls_native_certs::Error>),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] InvalidDnsNameError),
@@ -84,6 +89,7 @@ impl ReportableError for ConnectionError {
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
@@ -193,15 +199,11 @@ impl ConnCfg {
let connect_once = |host, port| {
debug!("trying to connect to compute node at {host}:{port}");
connect_with_timeout(host, port).and_then(|stream| async {
let socket_addr = stream.peer_addr()?;
let socket = socket2::SockRef::from(&stream);
// Disable Nagle's algorithm to not introduce latency between
// client and compute.
socket.set_nodelay(true)?;
connect_with_timeout(host, port).and_then(|socket| async {
let socket_addr = socket.peer_addr()?;
// This prevents load balancer from severing the connection.
socket.set_keepalive(true)?;
Ok((socket_addr, stream))
socket2::SockRef::from(&socket).set_keepalive(true)?;
Ok((socket_addr, socket))
})
};
@@ -249,13 +251,25 @@ impl ConnCfg {
&self,
ctx: &RequestContext,
aux: MetricsAuxInfo,
config: &ComputeConfig,
timeout: Duration,
) -> Result<PostgresConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
drop(pause);
let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let root_store = TLS_ROOTS
.get_or_try_init(load_certs)
.map_err(ConnectionError::TlsCertificateError)?
.clone();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
host,
@@ -327,6 +341,19 @@ fn filtered_options(options: &str) -> Option<String> {
Some(options)
}
pub(crate) fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
let der_certs = rustls_native_certs::load_native_certs();
if !der_certs.errors.is_empty() {
return Err(der_certs.errors);
}
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(der_certs.certs);
Ok(Arc::new(store))
}
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,10 +1,17 @@
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{bail, ensure, Context, Ok};
use clap::ValueEnum;
use itertools::Itertools;
use remote_storage::RemoteStorageConfig;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use sha2::{Digest, Sha256};
use tracing::{error, info};
use x509_parser::oid_registry;
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::AuthRateLimiter;
@@ -13,7 +20,6 @@ use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::GlobalConnPoolOptions;
pub use crate::tls::server_config::{configure_tls, TlsConfig};
use crate::types::Host;
pub struct ProxyConfig {
@@ -26,13 +32,7 @@ pub struct ProxyConfig {
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
pub connect_to_compute: ComputeConfig,
}
pub struct ComputeConfig {
pub retry: RetryConfig,
pub tls: Arc<rustls::ClientConfig>,
pub timeout: Duration,
pub connect_to_compute_retry_config: RetryConfig,
}
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]
@@ -52,6 +52,12 @@ pub struct MetricCollectionConfig {
pub backup_metric_collection_config: MetricBackupCollectionConfig,
}
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: HashSet<String>,
pub cert_resolver: Arc<CertResolver>,
}
pub struct HttpConfig {
pub accept_websockets: bool,
pub pool_options: GlobalConnPoolOptions,
@@ -74,6 +80,272 @@ pub struct AuthenticationConfig {
pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
impl TlsConfig {
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
self.config.clone()
}
}
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Configure TLS for the main endpoint.
pub fn configure_tls(
key_path: &str,
cert_path: &str,
certs_dir: Option<&String>,
allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
let mut cert_resolver = CertResolver::new();
// add default certificate
cert_resolver.add_cert_path(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
for entry in std::fs::read_dir(certs_dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
// file names aligned with default cert-manager names
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
)?;
}
}
}
}
let common_names = cert_resolver.get_common_names();
let cert_resolver = Arc::new(cert_resolver);
// allow TLS 1.2 to be compatible with older client libraries
let mut config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
if allow_tls_keylogfile {
// KeyLogFile will check for the SSLKEYLOGFILE environment variable.
config.key_log = Arc::new(rustls::KeyLogFile::new());
}
Ok(TlsConfig {
config: Arc::new(config),
common_names,
cert_resolver,
})
}
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
pub fn new() -> Self {
Self::default()
}
fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
is_default: bool,
) -> anyhow::Result<()> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};
self.add_cert(priv_key, cert_chain, is_default)
}
pub fn add_cert(
&mut self,
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(first_cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
let common_name = pem.subject().to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some((cert.clone(), tls_server_end_point));
}
self.certs.insert(common_name, (cert, tls_server_end_point));
Ok(())
}
pub fn get_common_names(&self) -> HashSet<String> {
self.certs.keys().map(|s| s.to_string()).collect()
}
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
//
// With the current coding foo.com will match *.foo.com and that
// repeats behavior of the old code.
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return Some(cert.clone());
}
if let Some((_, rest)) = sni_name.split_once('.') {
sni_name = rest;
} else {
return None;
}
}
} else {
// No SNI, use the default certificate, otherwise we can't get to
// options parameter which can be used to set endpoint name too.
// That means that non-SNI flow will not work for CNAME domains in
// verify-full mode.
//
// If that will be a problem we can:
//
// a) Instead of multi-cert approach use single cert with extra
// domains listed in Subject Alternative Name (SAN).
// b) Deploy separate proxy instances for extra domains.
self.default.clone()
}
}
}
#[derive(Debug)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.

View File

@@ -115,7 +115,7 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
@@ -216,7 +216,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
},
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
config.connect_to_compute_retry_config,
)
.or_else(|e| stream.throw_error(e))
.await?;

View File

@@ -10,13 +10,13 @@ pub mod client;
pub(crate) mod errors;
use std::sync::Arc;
use std::time::Duration;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::IpPattern;
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::intern::ProjectIdInt;
@@ -73,9 +73,9 @@ impl NodeInfo {
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
config: &ComputeConfig,
timeout: Duration,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config.connect(ctx, self.aux.clone(), config).await
self.config.connect(ctx, self.aux.clone(), timeout).await
}
pub(crate) fn reuse_settings(&mut self, other: Self) {

View File

@@ -89,6 +89,7 @@ pub mod jemalloc;
pub mod logging;
pub mod metrics;
pub mod parse;
pub mod postgres_rustls;
pub mod protocol2;
pub mod proxy;
pub mod rate_limiter;
@@ -98,7 +99,6 @@ pub mod scram;
pub mod serverless;
pub mod signals;
pub mod stream;
pub mod tls;
pub mod types;
pub mod url;
pub mod usage_metrics;

View File

@@ -18,7 +18,7 @@ mod private {
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;
use crate::tls::TlsServerEndPoint;
use crate::config::TlsServerEndPoint;
pub struct TlsConnectFuture<S> {
inner: tokio_rustls::Connect<S>,
@@ -126,14 +126,16 @@ mod private {
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
#[derive(Clone)]
pub struct MakeRustlsConnect {
pub config: Arc<ClientConfig>,
config: Arc<ClientConfig>,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
#[must_use]
pub fn new(config: Arc<ClientConfig>) -> Self {
Self { config }
pub fn new(config: ClientConfig) -> Self {
Self {
config: Arc::new(config),
}
}
}

View File

@@ -6,7 +6,7 @@ use tracing::{debug, info, warn};
use super::retry::ShouldRetryWakeCompute;
use crate::auth::backend::ComputeCredentialKeys;
use crate::compute::{self, PostgresConnection, COULD_NOT_CONNECT};
use crate::config::{ComputeConfig, RetryConfig};
use crate::config::RetryConfig;
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
@@ -19,6 +19,8 @@ use crate::proxy::retry::{retry_after, should_retry, CouldRetry};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
@@ -47,7 +49,7 @@ pub(crate) trait ConnectMechanism {
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
timeout: time::Duration,
) -> Result<Self::Connection, Self::ConnectError>;
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
@@ -84,11 +86,11 @@ impl ConnectMechanism for TcpMechanism<'_> {
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
permit.release_result(node_info.connect(ctx, config).await)
permit.release_result(node_info.connect(ctx, timeout).await)
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
@@ -103,7 +105,7 @@ pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBac
mechanism: &M,
user_info: &B,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
connect_to_compute_retry_config: RetryConfig,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
@@ -117,7 +119,10 @@ where
mechanism.update_connect_config(&mut node_info.config);
// try once
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
let err = match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
.await
{
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(
@@ -137,7 +142,7 @@ where
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
// Do not need to retrieve a new node_info, just return the old one.
if should_retry(&err, num_retries, compute.retry) {
if should_retry(&err, num_retries, connect_to_compute_retry_config) {
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Failed,
@@ -167,7 +172,10 @@ where
debug!("wake_compute success. attempting to connect");
num_retries = 1;
loop {
match mechanism.connect_once(ctx, &node_info, compute).await {
match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
.await
{
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(
@@ -182,7 +190,7 @@ where
return Ok(res);
}
Err(e) => {
if !should_retry(&e, num_retries, compute.retry) {
if !should_retry(&e, num_retries, connect_to_compute_retry_config) {
// Don't log an error here, caller will print the error
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
@@ -198,7 +206,7 @@ where
}
};
let wait_duration = retry_after(num_retries, compute.retry);
let wait_duration = retry_after(num_retries, connect_to_compute_retry_config);
num_retries += 1;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);

View File

@@ -8,13 +8,12 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use crate::auth::endpoint_sni;
use crate::config::TlsConfig;
use crate::config::{TlsConfig, PG_ALPN_PROTOCOL};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::ERR_INSECURE_CONNECTION;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;
#[derive(Error, Debug)]
pub(crate) enum HandshakeError {

View File

@@ -152,7 +152,7 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
@@ -351,7 +351,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
},
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
config.connect_to_compute_retry_config,
)
.or_else(|e| stream.throw_error(e))
.await?;

View File

@@ -5,7 +5,6 @@ use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::stream::Stream;
@@ -68,17 +67,9 @@ pub(crate) struct ProxyPassthrough<P, S> {
}
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub(crate) async fn proxy_pass(
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self
.compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
}
res

View File

@@ -22,16 +22,14 @@ use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{ComputeConfig, RetryConfig};
use crate::config::{CertResolver, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{
self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache,
};
use crate::error::ErrorKind;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::tls::server_config::CertResolver;
use crate::postgres_rustls::MakeRustlsConnect;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};
@@ -69,7 +67,7 @@ fn generate_certs(
}
struct ClientConfig<'a> {
config: Arc<rustls::ClientConfig>,
config: rustls::ClientConfig,
hostname: &'a str,
}
@@ -112,7 +110,16 @@ fn generate_tls_config<'a>(
};
let client_config = {
let config = Arc::new(compute_client_config_with_certs([ca]));
let config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.context("ring should support the default protocol versions")?
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(ca)?;
store
})
.with_no_client_auth();
ClientConfig { config, hostname }
};
@@ -461,7 +468,7 @@ impl ConnectMechanism for TestConnectMechanism {
&self,
_ctx: &RequestContext,
_node_info: &control_plane::CachedNodeInfo,
_config: &ComputeConfig,
_timeout: std::time::Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let mut counter = self.counter.lock().unwrap();
let action = self.sequence[*counter];
@@ -569,20 +576,6 @@ fn helper_create_connect_info(
user_info
}
fn config() -> ComputeConfig {
let retry = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
ComputeConfig {
retry,
tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
timeout: Duration::from_secs(2),
}
}
#[tokio::test]
async fn connect_to_compute_success() {
let _ = env_logger::try_init();
@@ -590,8 +583,12 @@ async fn connect_to_compute_success() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -604,8 +601,12 @@ async fn connect_to_compute_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -619,8 +620,12 @@ async fn connect_to_compute_non_retry_1() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap_err();
mechanism.verify();
@@ -634,8 +639,12 @@ async fn connect_to_compute_non_retry_2() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -656,13 +665,17 @@ async fn connect_to_compute_non_retry_3() {
max_retries: 1,
backoff_factor: 2.0,
};
let config = config();
let connect_to_compute_retry_config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(
&ctx,
&mechanism,
&user_info,
wake_compute_retry_config,
&config,
connect_to_compute_retry_config,
)
.await
.unwrap_err();
@@ -677,8 +690,12 @@ async fn wake_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -692,8 +709,12 @@ async fn wake_non_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap_err();
mechanism.verify();

View File

@@ -12,7 +12,6 @@ use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::cancellation::{CancelMap, CancellationHandler};
use crate::config::ProxyConfig;
use crate::intern::{ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
@@ -40,27 +39,6 @@ pub(crate) enum Notification {
AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate,
},
#[serde(
rename = "/block_public_or_vpc_access_updated",
deserialize_with = "deserialize_json_string"
)]
BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_org",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_projects",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
},
#[serde(
rename = "/password_updated",
deserialize_with = "deserialize_json_string"
@@ -73,24 +51,6 @@ pub(crate) enum Notification {
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdated {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
// TODO: change type once the implementation is more fully fledged.
// See e.g. https://github.com/neondatabase/neon/pull/10073.
account_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
project_ids: Vec<ProjectIdInt>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
project_id: ProjectIdInt,
@@ -204,11 +164,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
}
}
Notification::AllowedIpsUpdate { .. }
| Notification::PasswordUpdate { .. }
| Notification::BlockPublicOrVpcAccessUpdated { .. }
| Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
| Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
invalidate_cache(self.cache.clone(), msg.clone());
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
@@ -221,8 +177,6 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
.redis_events_count
.inc(RedisEventsCount::PasswordUpdate);
}
// TODO: add additional metrics for the other event types.
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
@@ -249,15 +203,6 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
password_update.role_name,
),
Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
Notification::BlockPublicOrVpcAccessUpdated { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
}
}
@@ -304,7 +249,6 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
/// Handle console's invalidation messages.
#[tracing::instrument(name = "redis_notifications", skip_all)]
pub async fn task_main<C>(
config: &'static ProxyConfig,
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
cancel_map: CancelMap,
@@ -314,7 +258,6 @@ where
C: ProjectInfoCache + Send + Sync + 'static,
{
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
&config.connect_to_compute,
cancel_map,
crate::metrics::CancellationSource::FromRedis,
));

View File

@@ -13,6 +13,7 @@ use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use super::ScramKey;
use crate::config;
use crate::intern::EndpointIdInt;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
@@ -58,14 +59,14 @@ enum ExchangeState {
pub(crate) struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
tls_server_end_point: crate::tls::TlsServerEndPoint,
tls_server_end_point: config::TlsServerEndPoint,
}
impl<'a> Exchange<'a> {
pub(crate) fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
tls_server_end_point: crate::tls::TlsServerEndPoint,
tls_server_end_point: config::TlsServerEndPoint,
) -> Self {
Self {
state: ExchangeState::Initial(SaslInitial { nonce }),
@@ -119,7 +120,7 @@ impl SaslInitial {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &crate::tls::TlsServerEndPoint,
tls_server_end_point: &config::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
let client_first_message = ClientFirstMessage::parse(input)
@@ -154,7 +155,7 @@ impl SaslSentInner {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &crate::tls::TlsServerEndPoint,
tls_server_end_point: &config::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
let Self {
@@ -167,8 +168,8 @@ impl SaslSentInner {
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
crate::tls::TlsServerEndPoint::Sha256(x) => Ok(x),
crate::tls::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
config::TlsServerEndPoint::Sha256(x) => Ok(x),
config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
})?;
// This might've been caused by a MITM attack

View File

@@ -77,8 +77,11 @@ mod tests {
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
let mut exchange =
Exchange::new(&secret, || NONCE, crate::tls::TlsServerEndPoint::Undefined);
let mut exchange = Exchange::new(
&secret,
|| NONCE,
crate::config::TlsServerEndPoint::Undefined,
);
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";

View File

@@ -22,7 +22,7 @@ use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
use crate::config::{ComputeConfig, ProxyConfig};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
@@ -196,7 +196,7 @@ impl PoolingBackend {
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
self.config.connect_to_compute_retry_config,
)
.await
}
@@ -237,7 +237,7 @@ impl PoolingBackend {
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
self.config.connect_to_compute_retry_config,
)
.await
}
@@ -502,7 +502,7 @@ impl ConnectMechanism for TokioMechanism {
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
@@ -511,7 +511,7 @@ impl ConnectMechanism for TokioMechanism {
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
.connect_timeout(timeout);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(postgres_client::NoTls).await;
@@ -552,7 +552,7 @@ impl ConnectMechanism for HyperMechanism {
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
@@ -560,7 +560,7 @@ impl ConnectMechanism for HyperMechanism {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let port = node_info.config.get_port();
let res = connect_http2(&host, port, config.timeout).await;
let res = connect_http2(&host, port, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;

View File

@@ -124,6 +124,9 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}

View File

@@ -227,6 +227,9 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}

View File

@@ -168,7 +168,7 @@ pub(crate) async fn serve_websocket(
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
match p.proxy_pass().await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),

View File

@@ -11,9 +11,9 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tracing::debug;
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::tls::TlsServerEndPoint;
/// Stream wrapper which implements libpq's protocol.
///

View File

@@ -1,42 +0,0 @@
use std::sync::Arc;
use anyhow::bail;
use rustls::crypto::ring;
pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
let der_certs = rustls_native_certs::load_native_certs();
if !der_certs.errors.is_empty() {
bail!("could not parse certificates: {:?}", der_certs.errors);
}
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(der_certs.certs);
Ok(Arc::new(store))
}
/// Loads the root certificates and constructs a client config suitable for connecting to the neon compute.
/// This function is blocking.
pub fn compute_client_config_with_root_certs() -> anyhow::Result<rustls::ClientConfig> {
Ok(
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(load_certs()?)
.with_no_client_auth(),
)
}
#[cfg(test)]
pub fn compute_client_config_with_certs(
certs: impl IntoIterator<Item = rustls::pki_types::CertificateDer<'static>>,
) -> rustls::ClientConfig {
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(certs);
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(store)
.with_no_client_auth()
}

View File

@@ -1,72 +0,0 @@
pub mod client_config;
pub mod postgres_rustls;
pub mod server_config;
use anyhow::Context;
use rustls::pki_types::CertificateDer;
use sha2::{Digest, Sha256};
use tracing::{error, info};
use x509_parser::oid_registry;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}

View File

@@ -1,218 +0,0 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use anyhow::{bail, Context};
use itertools::Itertools;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use super::{TlsServerEndPoint, PG_ALPN_PROTOCOL};
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: HashSet<String>,
pub cert_resolver: Arc<CertResolver>,
}
impl TlsConfig {
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
self.config.clone()
}
}
/// Configure TLS for the main endpoint.
pub fn configure_tls(
key_path: &str,
cert_path: &str,
certs_dir: Option<&String>,
allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
let mut cert_resolver = CertResolver::new();
// add default certificate
cert_resolver.add_cert_path(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
for entry in std::fs::read_dir(certs_dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
// file names aligned with default cert-manager names
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
)?;
}
}
}
}
let common_names = cert_resolver.get_common_names();
let cert_resolver = Arc::new(cert_resolver);
// allow TLS 1.2 to be compatible with older client libraries
let mut config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
if allow_tls_keylogfile {
// KeyLogFile will check for the SSLKEYLOGFILE environment variable.
config.key_log = Arc::new(rustls::KeyLogFile::new());
}
Ok(TlsConfig {
config: Arc::new(config),
common_names,
cert_resolver,
})
}
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
pub fn new() -> Self {
Self::default()
}
fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
is_default: bool,
) -> anyhow::Result<()> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};
self.add_cert(priv_key, cert_chain, is_default)
}
pub fn add_cert(
&mut self,
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(first_cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
let common_name = pem.subject().to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some((cert.clone(), tls_server_end_point));
}
self.certs.insert(common_name, (cert, tls_server_end_point));
Ok(())
}
pub fn get_common_names(&self) -> HashSet<String> {
self.certs.keys().map(|s| s.to_string()).collect()
}
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
//
// With the current coding foo.com will match *.foo.com and that
// repeats behavior of the old code.
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return Some(cert.clone());
}
if let Some((_, rest)) = sni_name.split_once('.') {
sni_name = rest;
} else {
return None;
}
}
} else {
// No SNI, use the default certificate, otherwise we can't get to
// options parameter which can be used to set endpoint name too.
// That means that non-SNI flow will not work for CNAME domains in
// verify-full mode.
//
// If that will be a problem we can:
//
// a) Instead of multi-cert approach use single cert with extra
// domains listed in Subject Alternative Name (SAN).
// b) Deploy separate proxy instances for extra domains.
self.default.clone()
}
}
}

View File

@@ -398,7 +398,6 @@ def test_timeline_archival_chaos(neon_env_builder: NeonEnvBuilder):
# Offloading is off by default at time of writing: remove this line when it's on by default
neon_env_builder.pageserver_config_override = "timeline_offloading = true"
neon_env_builder.storage_controller_config = {"heartbeat_interval": "100msec"}
neon_env_builder.enable_pageserver_remote_storage(s3_storage())
# We will exercise migrations, so need multiple pageservers