Mirror of https://github.com/neondatabase/neon.git, synced 2026-02-09 05:30:37 +00:00

Compare commits: conrad/pro ... sk/set_fas (1 commit, 5325f1e9d0)

.github/workflows/build_and_test.yml (vendored)
@@ -212,7 +212,7 @@ jobs:
fi
echo "CLIPPY_COMMON_ARGS=${CLIPPY_COMMON_ARGS}" >> $GITHUB_ENV
- name: Run cargo clippy (debug)
run: cargo hack --features default --ignore-unknown-features --feature-powerset clippy $CLIPPY_COMMON_ARGS
run: cargo hack --feature-powerset clippy $CLIPPY_COMMON_ARGS

- name: Check documentation generation
run: cargo doc --workspace --no-deps --document-private-items
@@ -1285,7 +1285,7 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) \

#########################################################################################
#
# Compile the Neon-specific `compute_ctl`, `fast_import`, and `local_proxy` binaries
# Compile and run the Neon-specific `compute_ctl` and `fast_import` binaries
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS compute-tools
@@ -1295,7 +1295,7 @@ ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin compute_ctl --bin fast_import --bin local_proxy
RUN cd compute_tools && mold -run cargo build --locked --profile release-line-debug-size-lto

#########################################################################################
#
@@ -1338,6 +1338,20 @@ RUN set -e \
&& make -j $(nproc) dist_man_MANS= \
&& make install dist_man_MANS=

#########################################################################################
#
# Compile the Neon-specific `local_proxy` binary
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS local_proxy
ARG BUILD_TAG
ENV BUILD_TAG=$BUILD_TAG

USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin local_proxy

#########################################################################################
#
# Layers "postgres-exporter" and "sql-exporter"
@@ -1477,7 +1491,7 @@ COPY --from=pgbouncer /usr/local/pgbouncer/bin/pgbouncer /usr/local/bin/
COPY --chmod=0666 --chown=postgres compute/etc/pgbouncer.ini /etc/pgbouncer.ini

# local_proxy and its config
COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
COPY --from=local_proxy --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
RUN mkdir -p /etc/local_proxy && chown postgres:postgres /etc/local_proxy

# Metrics exporter binaries and configuration files
@@ -172,6 +172,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
.args(["-c", &format!("max_worker_processes={nproc}")])
.args(["-c", "effective_io_concurrency=100"])
.env_clear()
.env("LD_LIBRARY_PATH", "/usr/local/lib/")
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
@@ -256,6 +257,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
.arg(&source_connection_string)
// how we run it
.env_clear()
.env("LD_LIBRARY_PATH", "/usr/local/lib/")
.kill_on_drop(true)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
@@ -289,6 +291,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
.arg(&dumpdir)
// how we run it
.env_clear()
.env("LD_LIBRARY_PATH", "/usr/local/lib/")
.kill_on_drop(true)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
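A note on the `env_clear` hunks above: the spawned `postgres`/`pg_dump`/`pg_restore` children now start from an empty environment with only `LD_LIBRARY_PATH` set explicitly. A minimal stand-alone sketch of the same pattern with `std::process::Command` (`/usr/bin/env` is a placeholder binary, not what fast_import actually runs):

```rust
use std::process::{Command, Stdio};

fn main() -> std::io::Result<()> {
    // With `env_clear()`, the child inherits nothing from the parent process;
    // only the variables set explicitly afterwards survive.
    let child = Command::new("/usr/bin/env")
        .env_clear()
        .env("LD_LIBRARY_PATH", "/usr/local/lib/")
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;

    let output = child.wait_with_output()?;
    // Expect exactly one line of output: the LD_LIBRARY_PATH set above.
    print!("{}", String::from_utf8_lossy(&output.stdout));
    Ok(())
}
```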
@@ -1,9 +1,11 @@
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, RawConnection};
use postgres_protocol2::message::backend::Message;
use tokio::net::TcpStream;
use tokio::sync::mpsc;

@@ -41,7 +43,7 @@ where
let RawConnection {
stream,
parameters,
delayed_notice: _,
delayed_notice,
process_id,
secret_key,
} = connect_raw(socket, tls, config).await?;
@@ -61,7 +63,13 @@ where
secret_key,
);

let connection = Connection::new(stream, parameters, receiver);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();

let connection = Connection::new(stream, delayed, parameters, receiver);

Ok((client, connection))
}
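Why `Connection::new` now takes an extra queue: notices captured during startup are converted into async messages and replayed ahead of anything read later from the socket. A small self-contained sketch of that seeding pattern, using hypothetical stand-in types rather than the crate's real `BackendMessage`/`Message`:

```rust
use std::collections::VecDeque;

// Hypothetical stand-ins for the crate-internal message types.
#[derive(Debug)]
enum BackendMessage {
    Async(String),  // e.g. a NoticeResponse captured during startup
    Normal(String), // a regular response read from the socket later
}

fn main() {
    // Startup notices are collected first, like `delayed_notice` above...
    let delayed_notices = vec!["notice: role created".to_string()];

    // ...then converted into Async messages and used to seed the pending queue.
    let mut pending: VecDeque<BackendMessage> = delayed_notices
        .into_iter()
        .map(BackendMessage::Async)
        .collect();

    // Messages read later from the stream go to the back of the same queue.
    pending.push_back(BackendMessage::Normal("ReadyForQuery".to_string()));

    // Draining from the front guarantees the delayed notices come out first.
    while let Some(message) = pending.pop_front() {
        println!("{message:?}");
    }
}
```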
@@ -1,10 +1,11 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, Stream};
use log::trace;
use log::{info, trace};
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::collections::{HashMap, VecDeque};
@@ -32,12 +33,8 @@ pub struct Response {
#[derive(PartialEq, Debug)]
enum State {
Active,
Closing,
}

enum WriteReady {
Terminating,
WaitingOnRead,
Closing,
}

/// A connection to a PostgreSQL database.
@@ -54,7 +51,8 @@ pub struct Connection<S, T> {
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_responses: Option<BackendMessage>,
pending_request: Option<RequestMessages>,
pending_responses: VecDeque<BackendMessage>,
responses: VecDeque<Response>,
state: State,
}
@@ -66,6 +64,7 @@ where
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S, T> {
@@ -73,7 +72,8 @@ where
stream,
parameters,
receiver,
pending_responses: None,
pending_request: None,
pending_responses,
responses: VecDeque::new(),
state: State::Active,
}
@@ -83,7 +83,7 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_responses.take() {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
@@ -93,22 +93,26 @@ where
.map(|o| o.map(|r| r.map_err(Error::io)))
}

/// Read and process messages from the connection to postgres.
/// client <- postgres
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(None);
}

loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Ready(None) => return Err(Error::closed()),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Poll::Pending;
return Ok(None);
}
};

let (mut messages, request_complete) = match message {
BackendMessage::Async(Message::NoticeResponse(_)) => {
continue;
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Ok(Some(AsyncMessage::Notice(error)));
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
@@ -116,7 +120,7 @@ where
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
};
return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
return Ok(Some(AsyncMessage::Notification(notification)));
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
@@ -135,10 +139,8 @@ where
let mut response = match self.responses.pop_front() {
Some(response) => response,
None => match messages.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(error)) => {
return Poll::Ready(Err(Error::db(error)))
}
_ => return Poll::Ready(Err(Error::unexpected_message())),
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
_ => return Err(Error::unexpected_message()),
},
};

@@ -157,19 +159,23 @@ where
}
Poll::Pending => {
self.responses.push_front(response);
self.pending_responses = Some(BackendMessage::Normal {
self.pending_responses.push_back(BackendMessage::Normal {
messages,
request_complete,
});
trace!("poll_read: waiting on sender");
return Poll::Pending;
return Ok(None);
}
}
}
}

/// Fetch the next client request and enqueue the response sender.
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
if let Some(messages) = self.pending_request.take() {
trace!("retrying pending request");
return Poll::Ready(Some(messages));
}

if self.receiver.is_closed() {
return Poll::Ready(None);
}
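The `poll_read` hunks change its return type from `Poll<Result<AsyncMessage, Error>>` to `Result<Option<AsyncMessage>, Error>`, so "nothing to surface yet" becomes `Ok(None)` rather than `Poll::Pending`. A tiny sketch of that mapping, with placeholder unit types standing in for the crate's own:

```rust
use std::task::Poll;

// Placeholder types standing in for the crate's AsyncMessage and Error.
struct AsyncMessage;
struct Error;

// Old shape -> new shape: Pending collapses into Ok(None), because the caller
// has already registered a waker and only needs to know nothing surfaced yet.
fn adapt(old: Poll<Result<AsyncMessage, Error>>) -> Result<Option<AsyncMessage>, Error> {
    match old {
        Poll::Ready(Ok(message)) => Ok(Some(message)),
        Poll::Ready(Err(err)) => Err(err),
        Poll::Pending => Ok(None),
    }
}

fn main() {
    // With the new shape, "no message" and "connection still healthy" are the
    // same value, which is what lets poll_message decide about flush/shutdown.
    assert!(matches!(adapt(Poll::Pending), Ok(None)));
}
```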
@@ -187,80 +193,74 @@ where
}
}

/// Process client requests and write them to the postgres connection, flushing if necessary.
/// client -> postgres
fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<WriteReady, Error>> {
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(false);
}

if Pin::new(&mut self.stream)
.poll_ready(cx)
.map_err(Error::io)?
.is_pending()
{
trace!("poll_write: waiting on socket");

// poll_ready is self-flushing.
return Poll::Pending;
return Ok(false);
}

match self.poll_request(cx) {
// send the message to postgres
Poll::Ready(Some(RequestMessages::Single(request))) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
}
// No more messages from the client, and no more responses to wait for.
// Send a terminate message to postgres
Poll::Ready(None) if self.responses.is_empty() => {
let request = match self.poll_request(cx) {
Poll::Ready(Some(request)) => request,
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating;
let mut request = BytesMut::new();
frontend::terminate(&mut request);
let request = FrontendMessage::Raw(request.freeze());

Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;

trace!("poll_write: sent eof, closing");
trace!("poll_write: done");
return Poll::Ready(Ok(WriteReady::Terminating));
RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
}
// No more messages from the client, but there are still some responses to wait for.
Poll::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len()
);
ready!(self.poll_flush(cx))?;
return Poll::Ready(Ok(WriteReady::WaitingOnRead));
return Ok(true);
}
// Still waiting for a message from the client.
Poll::Pending => {
trace!("poll_write: waiting on request");
ready!(self.poll_flush(cx))?;
return Poll::Pending;
return Ok(true);
}
};

match request {
RequestMessages::Single(request) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
}
}
}
}

fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
match Pin::new(&mut self.stream)
.poll_flush(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => {
trace!("poll_flush: flushed");
Poll::Ready(Ok(()))
}
Poll::Pending => {
trace!("poll_flush: waiting on socket");
Poll::Pending
}
Poll::Ready(()) => trace!("poll_flush: flushed"),
Poll::Pending => trace!("poll_flush: waiting on socket"),
}
Ok(())
}

fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.state != State::Closing {
return Poll::Pending;
}

match Pin::new(&mut self.stream)
.poll_close(cx)
.map_err(Error::io)?
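In the rewritten `poll_write`, hitting client EOF with no responses outstanding produces an ordinary request carrying the Terminate frame instead of a special `WriteReady::Terminating` return value. For reference, the frame that `frontend::terminate` encodes is just a tag byte plus a self-inclusive length; a from-scratch sketch of those bytes (not the crate's encoder):

```rust
// PostgreSQL wire protocol: Terminate is the byte 'X' followed by an Int32
// length that includes the length field itself (4) and carries no body.
fn terminate_frame() -> Vec<u8> {
    let mut buf = Vec::with_capacity(5);
    buf.push(b'X');
    buf.extend_from_slice(&4i32.to_be_bytes());
    buf
}

fn main() {
    assert_eq!(terminate_frame(), [b'X', 0, 0, 0, 4]);
    println!("Terminate frame: {:?}", terminate_frame());
}
```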
@@ -289,30 +289,18 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
if self.state != State::Closing {
// if the state is still active, try read from and write to postgres.
let message = self.poll_read(cx)?;
let closing = self.poll_write(cx)?;
if let Poll::Ready(WriteReady::Terminating) = closing {
self.state = State::Closing;
}

if let Poll::Ready(message) = message {
return Poll::Ready(Some(Ok(message)));
}

// poll_read returned Pending.
// poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
// if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
if self.state != State::Closing {
return Poll::Pending;
}
let message = self.poll_read(cx)?;
let want_flush = self.poll_write(cx)?;
if want_flush {
self.poll_flush(cx)?;
}

match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
match message {
Some(message) => Poll::Ready(Some(Ok(message))),
None => match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
},
}
}
}
@@ -325,7 +313,11 @@ where
type Output = Result<(), Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
while ready!(self.poll_message(cx)?).is_some() {}
while let Some(message) = ready!(self.poll_message(cx)?) {
if let AsyncMessage::Notice(notice) = message {
info!("{}: {}", notice.severity(), notice.message());
}
}
Poll::Ready(Ok(()))
}
}
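The driver future above now inspects each message it drains and logs notices via `info!` instead of discarding them. A self-contained sketch of that polling shape, with simplified stand-ins for the crate's `Connection` and `AsyncMessage` and `println!` in place of the logger (driven here with `futures::executor::block_on` purely for the demo):

```rust
use std::pin::Pin;
use std::task::{Context, Poll};

// Simplified stand-ins for the crate's types; the real ones live in connection.rs.
enum AsyncMessage {
    Notice(String),
    Notification(String),
}
struct Error;

struct Driver {
    queued: Vec<AsyncMessage>,
}

impl Driver {
    // Stand-in for Connection::poll_message: yields queued messages, then ends.
    fn poll_message(&mut self, _cx: &mut Context<'_>) -> Poll<Option<Result<AsyncMessage, Error>>> {
        Poll::Ready(self.queued.pop().map(Ok))
    }
}

impl std::future::Future for Driver {
    type Output = Result<(), Error>;

    // Mirrors the patched driver: drain messages, surface notices in the log,
    // and resolve once poll_message reports that the connection is done.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            match self.poll_message(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(Ok(())),
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
                Poll::Ready(Some(Ok(AsyncMessage::Notice(n)))) => println!("notice: {n}"),
                Poll::Ready(Some(Ok(AsyncMessage::Notification(_)))) => {
                    // Notifications would be forwarded to listeners; ignored here.
                }
            }
        }
    }
}

fn main() {
    let driver = Driver {
        queued: vec![AsyncMessage::Notice("relation already exists, skipping".into())],
    };
    let _ = futures::executor::block_on(driver);
}
```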
@@ -6,6 +6,7 @@ pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
@@ -99,6 +100,10 @@ impl Notification {
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AsyncMessage {
/// A notice.
///
/// Notices use the same format as errors, but aren't "errors" per-se.
Notice(DbError),
/// A notification.
///
/// Connections can subscribe to notifications with the `LISTEN` command.
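Since `AsyncMessage` is `#[non_exhaustive]` and now gains a `Notice(DbError)` variant, downstream matches need a wildcard arm. A hypothetical consumer-side sketch using a local enum that mirrors the shape above (not an import of the real crate):

```rust
// Local stand-in mirroring the shape of the patched AsyncMessage enum.
#[non_exhaustive]
enum AsyncMessage {
    Notice(String),
    Notification(String),
}

fn handle(message: AsyncMessage) {
    match message {
        // Notices carry the same fields as errors but are informational only.
        AsyncMessage::Notice(notice) => println!("NOTICE: {notice}"),
        // Everything else: notifications today, possibly more variants later,
        // which is what #[non_exhaustive] signals to downstream crates.
        _ => {}
    }
}

fn main() {
    handle(AsyncMessage::Notice("schema \"public\" already exists".into()));
    handle(AsyncMessage::Notification("cache_invalidation".into()));
}
```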
@@ -541,7 +541,6 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
else
{
LWLockRelease(lfc_lock);
return found;
}
@@ -10,6 +10,7 @@ use tracing::info;

use super::backend::ComputeCredentialKeys;
use super::{AuthError, PasswordHackPayload};
use crate::config::TlsServerEndPoint;
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
@@ -17,7 +18,6 @@ use crate::sasl;
use crate::scram::threadpool::ThreadPool;
use crate::scram::{self};
use crate::stream::{PqStream, Stream};
use crate::tls::TlsServerEndPoint;

/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {
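These hunks only move the `TlsServerEndPoint` import from `crate::config` to `crate::tls`. As a reminder of what that type encodes, here is a small sketch of the tls-server-end-point channel binding: a SHA-256 over the raw DER certificate bytes when the signature algorithm allows it (uses the `sha2` crate, as the moved code does; the certificate bytes are a dummy placeholder):

```rust
use sha2::{Digest, Sha256};

// Mirrors the shape of the moved type: either a usable SHA-256 binding or
// "undefined" when the certificate's signature algorithm rules it out.
#[derive(Debug)]
enum TlsServerEndPoint {
    Sha256([u8; 32]),
    Undefined,
}

fn channel_binding(cert_der: &[u8], sha256_signature_alg: bool) -> TlsServerEndPoint {
    if sha256_signature_alg {
        // tls-server-end-point (RFC 5929): hash of the server certificate,
        // octet for octet, using the hash tied to its signature algorithm.
        let digest: [u8; 32] = Sha256::new().chain_update(cert_der).finalize().into();
        TlsServerEndPoint::Sha256(digest)
    } else {
        TlsServerEndPoint::Undefined
    }
}

fn main() {
    // Placeholder bytes; real callers pass the DER-encoded leaf certificate.
    let fake_cert = b"not a real certificate";
    println!("{:?}", channel_binding(fake_cert, true));
}
```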
@@ -13,9 +13,7 @@ use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
|
||||
use proxy::auth::{self};
|
||||
use proxy::cancellation::CancellationHandlerMain;
|
||||
use proxy::config::{
|
||||
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
|
||||
};
|
||||
use proxy::config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig};
|
||||
use proxy::control_plane::locks::ApiLocks;
|
||||
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
|
||||
use proxy::http::health_server::AppMetrics;
|
||||
@@ -27,7 +25,6 @@ use proxy::rate_limiter::{
|
||||
use proxy::scram::threadpool::ThreadPool;
|
||||
use proxy::serverless::cancel_set::CancelSet;
|
||||
use proxy::serverless::{self, GlobalConnPoolOptions};
|
||||
use proxy::tls::client_config::compute_client_config_with_root_certs;
|
||||
use proxy::types::RoleName;
|
||||
use proxy::url::ApiUrl;
|
||||
|
||||
@@ -212,7 +209,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
http_listener,
|
||||
shutdown.clone(),
|
||||
Arc::new(CancellationHandlerMain::new(
|
||||
&config.connect_to_compute,
|
||||
Arc::new(DashMap::new()),
|
||||
None,
|
||||
proxy::metrics::CancellationSource::Local,
|
||||
@@ -272,12 +268,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
|
||||
};
|
||||
|
||||
let compute_config = ComputeConfig {
|
||||
retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
|
||||
tls: Arc::new(compute_client_config_with_root_certs()?),
|
||||
timeout: Duration::from_secs(2),
|
||||
};
|
||||
|
||||
Ok(Box::leak(Box::new(ProxyConfig {
|
||||
tls_config: None,
|
||||
metric_collection: None,
|
||||
@@ -299,7 +289,9 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
region: "local".into(),
|
||||
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute: compute_config,
|
||||
connect_to_compute_retry_config: RetryConfig::parse(
|
||||
RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES,
|
||||
)?,
|
||||
})))
|
||||
}
|
||||
|
||||
|
||||
@@ -10,12 +10,12 @@ use clap::Arg;
|
||||
use futures::future::Either;
|
||||
use futures::TryFutureExt;
|
||||
use itertools::Itertools;
|
||||
use proxy::config::TlsServerEndPoint;
|
||||
use proxy::context::RequestContext;
|
||||
use proxy::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use proxy::protocol2::ConnectionInfo;
|
||||
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
|
||||
use proxy::stream::{PqStream, Stream};
|
||||
use proxy::tls::TlsServerEndPoint;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::pki_types::PrivateKeyDer;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::bail;
|
||||
use futures::future::Either;
|
||||
@@ -9,7 +8,7 @@ use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
|
||||
use proxy::cancellation::{CancelMap, CancellationHandler};
|
||||
use proxy::config::{
|
||||
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
|
||||
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig,
|
||||
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
|
||||
};
|
||||
use proxy::context::parquet::ParquetUploadArgs;
|
||||
@@ -24,7 +23,6 @@ use proxy::redis::{elasticache, notifications};
|
||||
use proxy::scram::threadpool::ThreadPool;
|
||||
use proxy::serverless::cancel_set::CancelSet;
|
||||
use proxy::serverless::GlobalConnPoolOptions;
|
||||
use proxy::tls::client_config::compute_client_config_with_root_certs;
|
||||
use proxy::{auth, control_plane, http, serverless, usage_metrics};
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use tokio::net::TcpListener;
|
||||
@@ -399,7 +397,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
let cancellation_handler = Arc::new(CancellationHandler::<
|
||||
Option<Arc<Mutex<RedisPublisherClient>>>,
|
||||
>::new(
|
||||
&config.connect_to_compute,
|
||||
cancel_map.clone(),
|
||||
redis_publisher,
|
||||
proxy::metrics::CancellationSource::FromClient,
|
||||
@@ -495,7 +492,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(client) = client1 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
config,
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
@@ -504,7 +500,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
if let Some(client) = client2 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
config,
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
@@ -637,12 +632,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
console_redirect_confirmation_timeout: args.webauth_confirmation_timeout,
|
||||
};
|
||||
|
||||
let compute_config = ComputeConfig {
|
||||
retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?,
|
||||
tls: Arc::new(compute_client_config_with_root_certs()?),
|
||||
timeout: Duration::from_secs(2),
|
||||
};
|
||||
|
||||
let config = ProxyConfig {
|
||||
tls_config,
|
||||
metric_collection,
|
||||
@@ -653,7 +642,9 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
region: args.region.clone(),
|
||||
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute: compute_config,
|
||||
connect_to_compute_retry_config: config::RetryConfig::parse(
|
||||
&args.connect_to_compute_retry,
|
||||
)?,
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(config));
|
||||
|
||||
@@ -3,9 +3,11 @@ use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use once_cell::sync::OnceCell;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::CancelToken;
|
||||
use pq_proto::CancelKeyData;
|
||||
use rustls::crypto::ring;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -13,15 +15,15 @@ use tracing::{debug, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::{check_peer_addr_is_in_list, IpPattern};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::compute::load_certs;
|
||||
use crate::error::ReportableError;
|
||||
use crate::ext::LockExt;
|
||||
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
|
||||
use crate::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::rate_limiter::LeakyBucketRateLimiter;
|
||||
use crate::redis::cancellation_publisher::{
|
||||
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
|
||||
};
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
|
||||
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
|
||||
@@ -33,7 +35,6 @@ type IpSubnetKey = IpNet;
|
||||
///
|
||||
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
|
||||
pub struct CancellationHandler<P> {
|
||||
compute_config: &'static ComputeConfig,
|
||||
map: CancelMap,
|
||||
client: P,
|
||||
/// This field is used for monitoring purposes.
|
||||
@@ -182,7 +183,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
"cancelling query per user's request using key {key}, hostname {}, address: {}",
|
||||
cancel_closure.hostname, cancel_closure.socket_addr
|
||||
);
|
||||
cancel_closure.try_cancel_query(self.compute_config).await
|
||||
cancel_closure.try_cancel_query().await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -197,13 +198,8 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
}
|
||||
|
||||
impl CancellationHandler<()> {
|
||||
pub fn new(
|
||||
compute_config: &'static ComputeConfig,
|
||||
map: CancelMap,
|
||||
from: CancellationSource,
|
||||
) -> Self {
|
||||
pub fn new(map: CancelMap, from: CancellationSource) -> Self {
|
||||
Self {
|
||||
compute_config,
|
||||
map,
|
||||
client: (),
|
||||
from,
|
||||
@@ -218,14 +214,8 @@ impl CancellationHandler<()> {
|
||||
}
|
||||
|
||||
impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
|
||||
pub fn new(
|
||||
compute_config: &'static ComputeConfig,
|
||||
map: CancelMap,
|
||||
client: Option<Arc<Mutex<P>>>,
|
||||
from: CancellationSource,
|
||||
) -> Self {
|
||||
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
|
||||
Self {
|
||||
compute_config,
|
||||
map,
|
||||
client,
|
||||
from,
|
||||
@@ -239,6 +229,8 @@ impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
|
||||
}
|
||||
}
|
||||
|
||||
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
|
||||
|
||||
/// This should've been a [`std::future::Future`], but
|
||||
/// it's impossible to name a type of an unboxed future
|
||||
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
|
||||
@@ -265,14 +257,27 @@ impl CancelClosure {
|
||||
}
|
||||
}
|
||||
/// Cancels the query running on user's compute node.
|
||||
pub(crate) async fn try_cancel_query(
|
||||
self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), CancelError> {
|
||||
pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
|
||||
let socket = TcpStream::connect(self.socket_addr).await?;
|
||||
|
||||
let mut mk_tls =
|
||||
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
|
||||
let root_store = TLS_ROOTS
|
||||
.get_or_try_init(load_certs)
|
||||
.map_err(|_e| {
|
||||
CancelError::IO(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"TLS root store initialization failed".to_string(),
|
||||
))
|
||||
})?
|
||||
.clone();
|
||||
|
||||
let client_config =
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
&self.hostname,
|
||||
@@ -324,30 +329,11 @@ impl<P> Drop for Session<P> {
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use super::*;
|
||||
use crate::config::RetryConfig;
|
||||
use crate::tls::client_config::compute_client_config_with_certs;
|
||||
|
||||
fn config() -> ComputeConfig {
|
||||
let retry = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
|
||||
ComputeConfig {
|
||||
retry,
|
||||
tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
|
||||
timeout: Duration::from_secs(2),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_session_drop() -> anyhow::Result<()> {
|
||||
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
|
||||
Box::leak(Box::new(config())),
|
||||
CancelMap::default(),
|
||||
CancellationSource::FromRedis,
|
||||
));
|
||||
@@ -363,11 +349,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancel_session_noop_regression() {
|
||||
let handler = CancellationHandler::<()>::new(
|
||||
Box::leak(Box::new(config())),
|
||||
CancelMap::default(),
|
||||
CancellationSource::Local,
|
||||
);
|
||||
let handler =
|
||||
CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
|
||||
handler
|
||||
.cancel_session(
|
||||
CancelKeyData {
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{CancelToken, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
@@ -15,15 +18,14 @@ use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::parse_endpoint_param;
|
||||
use crate::cancellation::CancelClosure;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ApiLockError;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumDbConnectionsGuard};
|
||||
use crate::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::proxy::neon_option;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::types::Host;
|
||||
|
||||
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
@@ -38,6 +40,9 @@ pub(crate) enum ConnectionError {
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
CouldNotConnect(#[from] io::Error),
|
||||
|
||||
#[error("Couldn't load native TLS certificates: {0:?}")]
|
||||
TlsCertificateError(Vec<rustls_native_certs::Error>),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
TlsError(#[from] InvalidDnsNameError),
|
||||
|
||||
@@ -84,6 +89,7 @@ impl ReportableError for ConnectionError {
|
||||
}
|
||||
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
|
||||
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
|
||||
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
@@ -193,15 +199,11 @@ impl ConnCfg {
|
||||
|
||||
let connect_once = |host, port| {
|
||||
debug!("trying to connect to compute node at {host}:{port}");
|
||||
connect_with_timeout(host, port).and_then(|stream| async {
|
||||
let socket_addr = stream.peer_addr()?;
|
||||
let socket = socket2::SockRef::from(&stream);
|
||||
// Disable Nagle's algorithm to not introduce latency between
|
||||
// client and compute.
|
||||
socket.set_nodelay(true)?;
|
||||
connect_with_timeout(host, port).and_then(|socket| async {
|
||||
let socket_addr = socket.peer_addr()?;
|
||||
// This prevents load balancer from severing the connection.
|
||||
socket.set_keepalive(true)?;
|
||||
Ok((socket_addr, stream))
|
||||
socket2::SockRef::from(&socket).set_keepalive(true)?;
|
||||
Ok((socket_addr, socket))
|
||||
})
|
||||
};
|
||||
|
||||
@@ -249,13 +251,25 @@ impl ConnCfg {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
aux: MetricsAuxInfo,
|
||||
config: &ComputeConfig,
|
||||
timeout: Duration,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
|
||||
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
|
||||
drop(pause);
|
||||
|
||||
let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
|
||||
let root_store = TLS_ROOTS
|
||||
.get_or_try_init(load_certs)
|
||||
.map_err(ConnectionError::TlsCertificateError)?
|
||||
.clone();
|
||||
|
||||
let client_config =
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
host,
|
||||
@@ -327,6 +341,19 @@ fn filtered_options(options: &str) -> Option<String> {
|
||||
Some(options)
|
||||
}
|
||||
|
||||
pub(crate) fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
let der_certs = rustls_native_certs::load_native_certs();

if !der_certs.errors.is_empty() {
return Err(der_certs.errors);
}

let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(der_certs.certs);
Ok(Arc::new(store))
}
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
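`load_certs` above is invoked through `TLS_ROOTS.get_or_try_init(load_certs)`, so the fallible scan of the native certificate store runs at most once per process and the resulting root store is shared afterwards. A minimal sketch of that memoization pattern with the `once_cell` crate (the `expensive_load` function is a placeholder, not the proxy's loader):

```rust
use std::sync::Arc;

use once_cell::sync::OnceCell;

// Placeholder for a fallible, expensive initializer such as load_certs().
fn expensive_load() -> Result<Arc<Vec<u8>>, String> {
    println!("loading...");
    Ok(Arc::new(vec![1, 2, 3]))
}

static CACHE: OnceCell<Arc<Vec<u8>>> = OnceCell::new();

fn get_cached() -> Result<Arc<Vec<u8>>, String> {
    // get_or_try_init runs the closure only if the cell is still empty and
    // leaves it empty again on Err, so a later call can retry.
    CACHE.get_or_try_init(expensive_load).map(Arc::clone)
}

fn main() -> Result<(), String> {
    let a = get_cached()?; // prints "loading..." once
    let b = get_cached()?; // served from the cell, no reload
    assert!(Arc::ptr_eq(&a, &b));
    Ok(())
}
```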
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{bail, ensure, Context, Ok};
|
||||
use clap::ValueEnum;
|
||||
use itertools::Itertools;
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use rustls::crypto::ring::{self, sign};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::{error, info};
|
||||
use x509_parser::oid_registry;
|
||||
|
||||
use crate::auth::backend::jwt::JwkCache;
|
||||
use crate::auth::backend::AuthRateLimiter;
|
||||
@@ -13,7 +20,6 @@ use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::serverless::cancel_set::CancelSet;
|
||||
use crate::serverless::GlobalConnPoolOptions;
|
||||
pub use crate::tls::server_config::{configure_tls, TlsConfig};
|
||||
use crate::types::Host;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
@@ -26,13 +32,7 @@ pub struct ProxyConfig {
|
||||
pub handshake_timeout: Duration,
|
||||
pub wake_compute_retry_config: RetryConfig,
|
||||
pub connect_compute_locks: ApiLocks<Host>,
|
||||
pub connect_to_compute: ComputeConfig,
|
||||
}
|
||||
|
||||
pub struct ComputeConfig {
|
||||
pub retry: RetryConfig,
|
||||
pub tls: Arc<rustls::ClientConfig>,
|
||||
pub timeout: Duration,
|
||||
pub connect_to_compute_retry_config: RetryConfig,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]
|
||||
@@ -52,6 +52,12 @@ pub struct MetricCollectionConfig {
|
||||
pub backup_metric_collection_config: MetricBackupCollectionConfig,
|
||||
}
|
||||
|
||||
pub struct TlsConfig {
|
||||
pub config: Arc<rustls::ServerConfig>,
|
||||
pub common_names: HashSet<String>,
|
||||
pub cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
pub struct HttpConfig {
|
||||
pub accept_websockets: bool,
|
||||
pub pool_options: GlobalConnPoolOptions,
|
||||
@@ -74,6 +80,272 @@ pub struct AuthenticationConfig {
|
||||
pub console_redirect_confirmation_timeout: tokio::time::Duration,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
|
||||
self.config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
|
||||
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
|
||||
|
||||
/// Configure TLS for the main endpoint.
|
||||
pub fn configure_tls(
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
certs_dir: Option<&String>,
|
||||
allow_tls_keylogfile: bool,
|
||||
) -> anyhow::Result<TlsConfig> {
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
|
||||
// add default certificate
|
||||
cert_resolver.add_cert_path(key_path, cert_path, true)?;
|
||||
|
||||
// add extra certificates
|
||||
if let Some(certs_dir) = certs_dir {
|
||||
for entry in std::fs::read_dir(certs_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
// file names aligned with default cert-manager names
|
||||
let key_path = path.join("tls.key");
|
||||
let cert_path = path.join("tls.crt");
|
||||
if key_path.exists() && cert_path.exists() {
|
||||
cert_resolver.add_cert_path(
|
||||
&key_path.to_string_lossy(),
|
||||
&cert_path.to_string_lossy(),
|
||||
false,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let common_names = cert_resolver.get_common_names();
|
||||
|
||||
let cert_resolver = Arc::new(cert_resolver);
|
||||
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
let mut config =
|
||||
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
|
||||
.context("ring should support TLS1.2 and TLS1.3")?
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(cert_resolver.clone());
|
||||
|
||||
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
|
||||
|
||||
if allow_tls_keylogfile {
|
||||
// KeyLogFile will check for the SSLKEYLOGFILE environment variable.
|
||||
config.key_log = Arc::new(rustls::KeyLogFile::new());
|
||||
}
|
||||
|
||||
Ok(TlsConfig {
|
||||
config: Arc::new(config),
|
||||
common_names,
|
||||
cert_resolver,
|
||||
})
|
||||
}
|
||||
|
||||
/// Channel binding parameter
|
||||
///
|
||||
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
|
||||
/// Description: The hash of the TLS server's certificate as it
|
||||
/// appears, octet for octet, in the server's Certificate message. Note
|
||||
/// that the Certificate message contains a certificate_list, in which
|
||||
/// the first element is the server's certificate.
|
||||
///
|
||||
/// The hash function is to be selected as follows:
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function and that hash function neither MD5 nor SHA-1, then use
|
||||
/// the hash function associated with the certificate's
|
||||
/// signatureAlgorithm;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses no hash functions or
|
||||
/// uses multiple hash functions, then this channel binding type's
|
||||
/// channel bindings are undefined at this time (updates to this channel
|
||||
/// binding type may occur to address this issue if it ever arises).
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum TlsServerEndPoint {
|
||||
Sha256([u8; 32]),
|
||||
Undefined,
|
||||
}
|
||||
|
||||
impl TlsServerEndPoint {
|
||||
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
|
||||
let sha256_oids = [
|
||||
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
|
||||
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
|
||||
oid_registry::OID_PKCS1_SHA256WITHRSA,
|
||||
];
|
||||
|
||||
let pem = x509_parser::parse_x509_certificate(cert)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
info!(subject = %pem.subject, "parsing TLS certificate");
|
||||
|
||||
let reg = oid_registry::OidRegistry::default().with_all_crypto();
|
||||
let oid = pem.signature_algorithm.oid();
|
||||
let alg = reg.get(oid);
|
||||
if sha256_oids.contains(oid) {
|
||||
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
|
||||
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
|
||||
Ok(Self::Sha256(tls_server_end_point))
|
||||
} else {
|
||||
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
|
||||
Ok(Self::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supported(&self) -> bool {
|
||||
!matches!(self, TlsServerEndPoint::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct CertResolver {
|
||||
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn add_cert_path(
|
||||
&mut self,
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let priv_key = {
|
||||
let key_bytes = std::fs::read(key_path)
|
||||
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
rustls_pemfile::private_key(&mut &key_bytes[..])
|
||||
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
|
||||
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
|
||||
};
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.try_collect()
|
||||
.with_context(|| {
|
||||
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
|
||||
})?
|
||||
};
|
||||
|
||||
self.add_cert(priv_key, cert_chain, is_default)
|
||||
}
|
||||
|
||||
pub fn add_cert(
|
||||
&mut self,
|
||||
priv_key: PrivateKeyDer<'static>,
|
||||
cert_chain: Vec<CertificateDer<'static>>,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let first_cert = &cert_chain[0];
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
let pem = x509_parser::parse_x509_certificate(first_cert)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
let common_name = pem.subject().to_string();
|
||||
|
||||
// We need to get the canonical name for this certificate so we can match them against any domain names
|
||||
// seen within the proxy codebase.
|
||||
//
|
||||
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
|
||||
// We need to remove the wildcard prefix for the purposes of certificate selection.
|
||||
//
|
||||
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
|
||||
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
|
||||
//
|
||||
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
|
||||
// validation, so we can continue with any common-name
|
||||
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=") {
|
||||
s.to_string()
|
||||
} else {
|
||||
bail!("Failed to parse common name from certificate")
|
||||
};
|
||||
|
||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||
|
||||
if is_default {
|
||||
self.default = Some((cert.clone(), tls_server_end_point));
|
||||
}
|
||||
|
||||
self.certs.insert(common_name, (cert, tls_server_end_point));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_common_names(&self) -> HashSet<String> {
|
||||
self.certs.keys().map(|s| s.to_string()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl rustls::server::ResolvesServerCert for CertResolver {
|
||||
fn resolve(
|
||||
&self,
|
||||
client_hello: rustls::server::ClientHello<'_>,
|
||||
) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
self.resolve(client_hello.server_name()).map(|x| x.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn resolve(
|
||||
&self,
|
||||
server_name: Option<&str>,
|
||||
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
|
||||
// loop here and cut off more and more subdomains until we find
|
||||
// a match to get a proper wildcard support. OTOH, we now do not
|
||||
// use nested domains, so keep this simple for now.
|
||||
//
|
||||
// With the current coding foo.com will match *.foo.com and that
|
||||
// repeats behavior of the old code.
|
||||
if let Some(mut sni_name) = server_name {
|
||||
loop {
|
||||
if let Some(cert) = self.certs.get(sni_name) {
|
||||
return Some(cert.clone());
|
||||
}
|
||||
if let Some((_, rest)) = sni_name.split_once('.') {
|
||||
sni_name = rest;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No SNI, use the default certificate, otherwise we can't get to
|
||||
// options parameter which can be used to set endpoint name too.
|
||||
// That means that non-SNI flow will not work for CNAME domains in
|
||||
// verify-full mode.
|
||||
//
|
||||
// If that will be a problem we can:
|
||||
//
|
||||
// a) Instead of multi-cert approach use single cert with extra
|
||||
// domains listed in Subject Alternative Name (SAN).
|
||||
// b) Deploy separate proxy instances for extra domains.
|
||||
self.default.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EndpointCacheConfig {
|
||||
/// Batch size to receive all endpoints on the startup.
|
||||
|
||||
@@ -115,7 +115,7 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
error!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
|
||||
@@ -216,7 +216,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
},
|
||||
&user_info,
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
config.connect_to_compute_retry_config,
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
@@ -10,13 +10,13 @@ pub mod client;
|
||||
pub(crate) mod errors;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::IpPattern;
|
||||
use crate::cache::project_info::ProjectInfoCacheImpl;
|
||||
use crate::cache::{Cached, TimedLru};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
|
||||
use crate::intern::ProjectIdInt;
|
||||
@@ -73,9 +73,9 @@ impl NodeInfo {
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
config: &ComputeConfig,
|
||||
timeout: Duration,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.config.connect(ctx, self.aux.clone(), config).await
|
||||
self.config.connect(ctx, self.aux.clone(), timeout).await
|
||||
}
|
||||
|
||||
pub(crate) fn reuse_settings(&mut self, other: Self) {
|
||||
|
||||
@@ -89,6 +89,7 @@ pub mod jemalloc;
|
||||
pub mod logging;
|
||||
pub mod metrics;
|
||||
pub mod parse;
|
||||
pub mod postgres_rustls;
|
||||
pub mod protocol2;
|
||||
pub mod proxy;
|
||||
pub mod rate_limiter;
|
||||
@@ -98,7 +99,6 @@ pub mod scram;
|
||||
pub mod serverless;
|
||||
pub mod signals;
|
||||
pub mod stream;
|
||||
pub mod tls;
|
||||
pub mod types;
|
||||
pub mod url;
|
||||
pub mod usage_metrics;
|
||||
|
||||
@@ -18,7 +18,7 @@ mod private {
|
||||
use tokio_rustls::client::TlsStream;
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
use crate::config::TlsServerEndPoint;
|
||||
|
||||
pub struct TlsConnectFuture<S> {
|
||||
inner: tokio_rustls::Connect<S>,
|
||||
@@ -126,14 +126,16 @@ mod private {
|
||||
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
|
||||
#[derive(Clone)]
|
||||
pub struct MakeRustlsConnect {
|
||||
pub config: Arc<ClientConfig>,
|
||||
config: Arc<ClientConfig>,
|
||||
}
|
||||
|
||||
impl MakeRustlsConnect {
|
||||
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
|
||||
#[must_use]
|
||||
pub fn new(config: Arc<ClientConfig>) -> Self {
|
||||
Self { config }
|
||||
pub fn new(config: ClientConfig) -> Self {
|
||||
Self {
|
||||
config: Arc::new(config),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use tracing::{debug, info, warn};
|
||||
use super::retry::ShouldRetryWakeCompute;
|
||||
use crate::auth::backend::ComputeCredentialKeys;
|
||||
use crate::compute::{self, PostgresConnection, COULD_NOT_CONNECT};
|
||||
use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::config::RetryConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
@@ -19,6 +19,8 @@ use crate::proxy::retry::{retry_after, should_retry, CouldRetry};
|
||||
use crate::proxy::wake_compute::wake_compute;
|
||||
use crate::types::Host;
|
||||
|
||||
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
|
||||
|
||||
/// If we couldn't connect, a cached connection info might be to blame
|
||||
/// (e.g. the compute node's address might've changed at the wrong time).
|
||||
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||
@@ -47,7 +49,7 @@ pub(crate) trait ConnectMechanism {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
timeout: time::Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError>;
|
||||
|
||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
||||
@@ -84,11 +86,11 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
timeout: time::Duration,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
let host = node_info.config.get_host();
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
permit.release_result(node_info.connect(ctx, config).await)
|
||||
permit.release_result(node_info.connect(ctx, timeout).await)
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||
@@ -103,7 +105,7 @@ pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBac
|
||||
mechanism: &M,
|
||||
user_info: &B,
|
||||
wake_compute_retry_config: RetryConfig,
|
||||
compute: &ComputeConfig,
|
||||
connect_to_compute_retry_config: RetryConfig,
|
||||
) -> Result<M::Connection, M::Error>
|
||||
where
|
||||
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
|
||||
@@ -117,7 +119,10 @@ where
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
// try once
|
||||
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
|
||||
let err = match mechanism
|
||||
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
@@ -137,7 +142,7 @@ where
|
||||
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
|
||||
// If we just received this from cplane and didn't get it from cache, we shouldn't retry.
|
||||
// Do not need to retrieve a new node_info, just return the old one.
|
||||
if should_retry(&err, num_retries, compute.retry) {
|
||||
if should_retry(&err, num_retries, connect_to_compute_retry_config) {
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Failed,
|
||||
@@ -167,7 +172,10 @@ where
|
||||
debug!("wake_compute success. attempting to connect");
|
||||
num_retries = 1;
|
||||
loop {
|
||||
match mechanism.connect_once(ctx, &node_info, compute).await {
|
||||
match mechanism
|
||||
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
@@ -182,7 +190,7 @@ where
|
||||
return Ok(res);
|
||||
}
|
||||
Err(e) => {
|
||||
if !should_retry(&e, num_retries, compute.retry) {
|
||||
if !should_retry(&e, num_retries, connect_to_compute_retry_config) {
|
||||
// Don't log an error here, caller will print the error
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
@@ -198,7 +206,7 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
let wait_duration = retry_after(num_retries, compute.retry);
|
||||
let wait_duration = retry_after(num_retries, connect_to_compute_retry_config);
|
||||
num_retries += 1;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
|
||||
|
||||
@@ -8,13 +8,12 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::endpoint_sni;
|
||||
use crate::config::TlsConfig;
|
||||
use crate::config::{TlsConfig, PG_ALPN_PROTOCOL};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::ERR_INSECURE_CONNECTION;
|
||||
use crate::stream::{PqStream, Stream, StreamUpgradeError};
|
||||
use crate::tls::PG_ALPN_PROTOCOL;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub(crate) enum HandshakeError {
|
||||
|
||||
@@ -152,7 +152,7 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
@@ -351,7 +351,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
},
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
config.connect_to_compute_retry_config,
)
.or_else(|e| stream.throw_error(e))
.await?;
@@ -5,7 +5,6 @@ use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::stream::Stream;
@@ -68,17 +67,9 @@ pub(crate) struct ProxyPassthrough<P, S> {
}

impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub(crate) async fn proxy_pass(
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self
.compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
}
res
@@ -22,16 +22,14 @@ use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{ComputeConfig, RetryConfig};
use crate::config::{CertResolver, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{
self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache,
};
use crate::error::ErrorKind;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::tls::server_config::CertResolver;
use crate::postgres_rustls::MakeRustlsConnect;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};

@@ -69,7 +67,7 @@ fn generate_certs(
}

struct ClientConfig<'a> {
config: Arc<rustls::ClientConfig>,
config: rustls::ClientConfig,
hostname: &'a str,
}

@@ -112,7 +110,16 @@ fn generate_tls_config<'a>(
};

let client_config = {
let config = Arc::new(compute_client_config_with_certs([ca]));
let config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.context("ring should support the default protocol versions")?
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(ca)?;
store
})
.with_no_client_auth();

ClientConfig { config, hostname }
};
@@ -461,7 +468,7 @@ impl ConnectMechanism for TestConnectMechanism {
&self,
_ctx: &RequestContext,
_node_info: &control_plane::CachedNodeInfo,
_config: &ComputeConfig,
_timeout: std::time::Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let mut counter = self.counter.lock().unwrap();
let action = self.sequence[*counter];
@@ -569,20 +576,6 @@ fn helper_create_connect_info(
user_info
}

fn config() -> ComputeConfig {
let retry = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};

ComputeConfig {
retry,
tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
timeout: Duration::from_secs(2),
}
}
#[tokio::test]
async fn connect_to_compute_success() {
let _ = env_logger::try_init();
@@ -590,8 +583,12 @@ async fn connect_to_compute_success() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -604,8 +601,12 @@ async fn connect_to_compute_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -619,8 +620,12 @@ async fn connect_to_compute_non_retry_1() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap_err();
mechanism.verify();
@@ -634,8 +639,12 @@ async fn connect_to_compute_non_retry_2() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -656,13 +665,17 @@ async fn connect_to_compute_non_retry_3() {
max_retries: 1,
backoff_factor: 2.0,
};
let config = config();
let connect_to_compute_retry_config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(
&ctx,
&mechanism,
&user_info,
wake_compute_retry_config,
&config,
connect_to_compute_retry_config,
)
.await
.unwrap_err();
@@ -677,8 +690,12 @@ async fn wake_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -692,8 +709,12 @@ async fn wake_non_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap_err();
mechanism.verify();
@@ -12,7 +12,6 @@ use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::cancellation::{CancelMap, CancellationHandler};
use crate::config::ProxyConfig;
use crate::intern::{ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};

@@ -40,27 +39,6 @@ pub(crate) enum Notification {
AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate,
},
#[serde(
rename = "/block_public_or_vpc_access_updated",
deserialize_with = "deserialize_json_string"
)]
BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_org",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_projects",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
},
#[serde(
rename = "/password_updated",
deserialize_with = "deserialize_json_string"
@@ -73,24 +51,6 @@ pub(crate) enum Notification {
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}

#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdated {
project_id: ProjectIdInt,
}

#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
// TODO: change type once the implementation is more fully fledged.
// See e.g. https://github.com/neondatabase/neon/pull/10073.
account_id: ProjectIdInt,
}

#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
project_ids: Vec<ProjectIdInt>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
project_id: ProjectIdInt,
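The deserialize_with = "deserialize_json_string" attribute on these notification variants suggests the payload arrives double-encoded: a JSON string whose contents are themselves JSON. A minimal, self-contained illustration of that pattern follows; the envelope shape, helper body, and field values here are assumptions made for the example, not the proxy's exact wire format:

    use serde::{Deserialize, Deserializer};

    #[derive(Debug, Deserialize, PartialEq)]
    struct PasswordUpdate {
        project_id: String, // an interned id in the real code; a plain String here
        role_name: String,
    }

    // Parse a field that is a JSON document encoded as a JSON string.
    fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
    where
        D: Deserializer<'de>,
        T: for<'a> Deserialize<'a>,
    {
        let raw = String::deserialize(deserializer)?;
        serde_json::from_str(&raw).map_err(serde::de::Error::custom)
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct Envelope {
        #[serde(deserialize_with = "deserialize_json_string")]
        password_update: PasswordUpdate,
    }

    fn main() {
        // The inner object is escaped: it is a JSON string holding JSON.
        let msg = r#"{"password_update": "{\"project_id\": \"p1\", \"role_name\": \"db_owner\"}"}"#;
        let parsed: Envelope = serde_json::from_str(msg).unwrap();
        println!("{parsed:?}");
    }
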
@@ -204,11 +164,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
}
}
Notification::AllowedIpsUpdate { .. }
| Notification::PasswordUpdate { .. }
| Notification::BlockPublicOrVpcAccessUpdated { .. }
| Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
| Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
invalidate_cache(self.cache.clone(), msg.clone());
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
@@ -221,8 +177,6 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
.redis_events_count
.inc(RedisEventsCount::PasswordUpdate);
}
// TODO: add additional metrics for the other event types.

// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
@@ -249,15 +203,6 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
password_update.role_name,
),
Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
Notification::BlockPublicOrVpcAccessUpdated { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
}
}
@@ -304,7 +249,6 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
/// Handle console's invalidation messages.
#[tracing::instrument(name = "redis_notifications", skip_all)]
pub async fn task_main<C>(
config: &'static ProxyConfig,
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
cancel_map: CancelMap,
@@ -314,7 +258,6 @@ where
C: ProjectInfoCache + Send + Sync + 'static,
{
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
&config.connect_to_compute,
cancel_map,
crate::metrics::CancellationSource::FromRedis,
));
@@ -13,6 +13,7 @@ use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use super::ScramKey;
use crate::config;
use crate::intern::EndpointIdInt;
use crate::sasl::{self, ChannelBinding, Error as SaslError};

@@ -58,14 +59,14 @@ enum ExchangeState {
pub(crate) struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
tls_server_end_point: crate::tls::TlsServerEndPoint,
tls_server_end_point: config::TlsServerEndPoint,
}

impl<'a> Exchange<'a> {
pub(crate) fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
tls_server_end_point: crate::tls::TlsServerEndPoint,
tls_server_end_point: config::TlsServerEndPoint,
) -> Self {
Self {
state: ExchangeState::Initial(SaslInitial { nonce }),
@@ -119,7 +120,7 @@ impl SaslInitial {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &crate::tls::TlsServerEndPoint,
tls_server_end_point: &config::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
let client_first_message = ClientFirstMessage::parse(input)
@@ -154,7 +155,7 @@ impl SaslSentInner {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &crate::tls::TlsServerEndPoint,
tls_server_end_point: &config::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
let Self {
@@ -167,8 +168,8 @@ impl SaslSentInner {
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;

let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
crate::tls::TlsServerEndPoint::Sha256(x) => Ok(x),
crate::tls::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
config::TlsServerEndPoint::Sha256(x) => Ok(x),
config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
})?;

// This might've been caused by a MITM attack

@@ -77,8 +77,11 @@ mod tests {
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
let mut exchange =
Exchange::new(&secret, || NONCE, crate::tls::TlsServerEndPoint::Undefined);
let mut exchange = Exchange::new(
&secret,
|| NONCE,
crate::config::TlsServerEndPoint::Undefined,
);

let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
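In the SCRAM test above, client_final begins with c=biws: the c= attribute carries the base64-encoded GS2 channel-binding data, and for an unbound exchange ("n,,") that encodes to "biws". A quick standalone check of that value using the base64 crate — purely illustrative, not the proxy's code path:

    use base64::Engine as _;

    fn main() {
        // GS2 header for "no channel binding": "n,," -> "biws".
        let gs2_header = b"n,,";
        let encoded = base64::engine::general_purpose::STANDARD.encode(gs2_header);
        assert_eq!(encoded, "biws");
        println!("c={encoded}");
    }

With tls-server-end-point channel binding selected instead, the cbind-input would also include the certificate hash produced by TlsServerEndPoint::Sha256, which is what the Exchange changes above thread through.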
@@ -22,7 +22,7 @@ use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
use crate::config::{ComputeConfig, ProxyConfig};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
@@ -196,7 +196,7 @@ impl PoolingBackend {
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
self.config.connect_to_compute_retry_config,
)
.await
}
@@ -237,7 +237,7 @@ impl PoolingBackend {
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
self.config.connect_to_compute_retry_config,
)
.await
}
@@ -502,7 +502,7 @@ impl ConnectMechanism for TokioMechanism {
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
@@ -511,7 +511,7 @@ impl ConnectMechanism for TokioMechanism {
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
.connect_timeout(timeout);

let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(postgres_client::NoTls).await;
@@ -552,7 +552,7 @@ impl ConnectMechanism for HyperMechanism {
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
@@ -560,7 +560,7 @@ impl ConnectMechanism for HyperMechanism {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);

let port = node_info.config.get_port();
let res = connect_http2(&host, port, config.timeout).await;
let res = connect_http2(&host, port, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
@@ -124,6 +124,9 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let message = ready!(connection.poll_message(cx));

match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}

@@ -227,6 +227,9 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let message = ready!(connection.poll_message(cx));

match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
@@ -168,7 +168,7 @@ pub(crate) async fn serve_websocket(
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
match p.proxy_pass().await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
@@ -11,9 +11,9 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tracing::debug;

use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::tls::TlsServerEndPoint;

/// Stream wrapper which implements libpq's protocol.
///
@@ -1,42 +0,0 @@
use std::sync::Arc;

use anyhow::bail;
use rustls::crypto::ring;

pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
let der_certs = rustls_native_certs::load_native_certs();

if !der_certs.errors.is_empty() {
bail!("could not parse certificates: {:?}", der_certs.errors);
}

let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(der_certs.certs);
Ok(Arc::new(store))
}

/// Loads the root certificates and constructs a client config suitable for connecting to the neon compute.
/// This function is blocking.
pub fn compute_client_config_with_root_certs() -> anyhow::Result<rustls::ClientConfig> {
Ok(
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(load_certs()?)
.with_no_client_auth(),
)
}

#[cfg(test)]
pub fn compute_client_config_with_certs(
certs: impl IntoIterator<Item = rustls::pki_types::CertificateDer<'static>>,
) -> rustls::ClientConfig {
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(certs);

rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(store)
.with_no_client_auth()
}
@@ -1,72 +0,0 @@
pub mod client_config;
pub mod postgres_rustls;
pub mod server_config;

use anyhow::Context;
use rustls::pki_types::CertificateDer;
use sha2::{Digest, Sha256};
use tracing::{error, info};
use x509_parser::oid_registry;

/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";

/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function is neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}

impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];

let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from certificate")?
.1;

info!(subject = %pem.subject, "parsing TLS certificate");

let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}

pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}
@@ -1,218 +0,0 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use anyhow::{bail, Context};
use itertools::Itertools;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};

use super::{TlsServerEndPoint, PG_ALPN_PROTOCOL};

pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: HashSet<String>,
pub cert_resolver: Arc<CertResolver>,
}

impl TlsConfig {
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
self.config.clone()
}
}

/// Configure TLS for the main endpoint.
pub fn configure_tls(
key_path: &str,
cert_path: &str,
certs_dir: Option<&String>,
allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
let mut cert_resolver = CertResolver::new();

// add default certificate
cert_resolver.add_cert_path(key_path, cert_path, true)?;

// add extra certificates
if let Some(certs_dir) = certs_dir {
for entry in std::fs::read_dir(certs_dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
// file names aligned with default cert-manager names
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
)?;
}
}
}
}

let common_names = cert_resolver.get_common_names();

let cert_resolver = Arc::new(cert_resolver);

// allow TLS 1.2 to be compatible with older client libraries
let mut config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());

config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];

if allow_tls_keylogfile {
// KeyLogFile will check for the SSLKEYLOGFILE environment variable.
config.key_log = Arc::new(rustls::KeyLogFile::new());
}

Ok(TlsConfig {
config: Arc::new(config),
common_names,
cert_resolver,
})
}
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}

impl CertResolver {
pub fn new() -> Self {
Self::default()
}

fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
is_default: bool,
) -> anyhow::Result<()> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};

let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;

let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};

self.add_cert(priv_key, cert_chain, is_default)
}

pub fn add_cert(
&mut self,
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;

let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(first_cert)
.context("Failed to parse PEM object from certificate")?
.1;

let common_name = pem.subject().to_string();

// We need to get the canonical name for this certificate so we can match it against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};

let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));

if is_default {
self.default = Some((cert.clone(), tls_server_end_point));
}

self.certs.insert(common_name, (cert, tls_server_end_point));

Ok(())
}

pub fn get_common_names(&self) -> HashSet<String> {
self.certs.keys().map(|s| s.to_string()).collect()
}
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}

impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
//
// With the current coding foo.com will match *.foo.com and that
// repeats behavior of the old code.
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return Some(cert.clone());
}
if let Some((_, rest)) = sni_name.split_once('.') {
sni_name = rest;
} else {
return None;
}
}
} else {
// No SNI, use the default certificate, otherwise we can't get to
// options parameter which can be used to set endpoint name too.
// That means that non-SNI flow will not work for CNAME domains in
// verify-full mode.
//
// If that will be a problem we can:
//
// a) Instead of multi-cert approach use single cert with extra
// domains listed in Subject Alternative Name (SAN).
// b) Deploy separate proxy instances for extra domains.
self.default.clone()
}
}
}
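CertResolver::resolve above walks the SNI name, stripping one label at a time, so a certificate stored under its base domain also serves wildcard subdomains. A small self-contained sketch of that lookup strategy, using plain strings instead of CertifiedKey and made-up domain names:

    use std::collections::HashMap;

    // Walk "a.b.c" -> "b.c" -> "c" until a stored entry matches, mimicking the
    // subdomain-stripping loop in CertResolver::resolve.
    fn resolve<'a>(certs: &'a HashMap<String, &'a str>, mut sni_name: &str) -> Option<&'a str> {
        loop {
            if let Some(&cert) = certs.get(sni_name) {
                return Some(cert);
            }
            match sni_name.split_once('.') {
                Some((_, rest)) => sni_name = rest,
                None => return None,
            }
        }
    }

    fn main() {
        let mut certs = HashMap::new();
        // A wildcard cert for *.region.example.com would be stored under its base domain.
        certs.insert("region.example.com".to_string(), "wildcard-cert");

        assert_eq!(
            resolve(&certs, "ep-cool-name-123456.region.example.com"),
            Some("wildcard-cert")
        );
        assert_eq!(resolve(&certs, "other.com"), None);
        println!("lookup ok");
    }
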
@@ -398,7 +398,6 @@ def test_timeline_archival_chaos(neon_env_builder: NeonEnvBuilder):

# Offloading is off by default at time of writing: remove this line when it's on by default
neon_env_builder.pageserver_config_override = "timeline_offloading = true"
neon_env_builder.storage_controller_config = {"heartbeat_interval": "100msec"}
neon_env_builder.enable_pageserver_remote_storage(s3_storage())

# We will exercise migrations, so need multiple pageservers