Compare commits

...

12 Commits

Author SHA1 Message Date
Conrad Ludgate
1d644b826d chore(libs/proxy): make pending_repsonses an option instead of vecdeque, as it can only store 1 item 2025-01-02 16:31:27 +00:00
Conrad Ludgate
db1de41a43 chore(libs/proxy): do not propagate notices as pending_responses to postgres connection 2025-01-02 16:31:27 +00:00
Conrad Ludgate
79b8392d30 chore(libs/proxy): remove unused notice message type 2025-01-02 16:31:27 +00:00
Conrad Ludgate
5c6dc2f695 fix(proxy): remove postgres notice logs 2025-01-02 16:31:27 +00:00
Conrad Ludgate
56e6ebfe17 chore: building compute_tools and local_proxy together (#10257)
## Problem

Building local_proxy and compute_tools features the same dependency
tree, but as they are currently built in separate clean layers all that
progress is wasted. For our arm builds that's an extra 10 minutes.

## Summary of changes

Combines the compute_tools and local_proxy build layers.
2025-01-02 16:05:14 +00:00
Raphael 'kena' Poss
1622fd8bda proxy: recognize but ignore the 3 new redis message types (#10197)
## Problem

https://neondb.slack.com/archives/C085MBDUSS2/p1734604792755369

## Summary of changes

Recognize and ignore the 3 new broadcast messages:
- `/block_public_or_vpc_access_updated`
- `/allowed_vpc_endpoints_updated_for_org`
- `/allowed_vpc_endpoints_updated_for_projects`
2025-01-02 16:02:48 +00:00
Konstantin Knizhnik
8c7dcd2598 Set heartbeat interval for chaos test (#10222)
## Problem

See https://neondb.slack.com/archives/C033RQ5SPDH/p1734707873215729

test_timeline_archival_chaos becomes more flaky with increased heartbeat
interval

Resolves #10250.

## Summary of changes

Override heatbeat interval for `test_timelirn_archival_chaos.py`

---------

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2025-01-02 14:14:18 +00:00
Folke Behrens
ee22d4c9ef proxy: Set TCP_NODELAY for compute connections (#10240)
neondatabase/cloud#19184
2025-01-02 13:32:24 +00:00
JC Grünhage
26600f2973 Skip running clippy without default features (#10098)
## Problem

Running clippy with `cargo hack --feature-powerset` in CI isn't
particularly fast. This PR follows-up on
https://github.com/neondatabase/neon/pull/8912 to improve the speed of
our clippy runs.

Parallelism as suggested in
https://github.com/neondatabase/neon/issues/9901 was tested, but didn't
show consistent enough improvements to be worth it. It actually
increased the amount of work done, as there's less cache hits when
clippy runs are spread out over multiple target directories.
Additionally, parallelism makes it so caching needs to be thought about
more actively and copying around target directories to enable
parallelism eats the rest of the performance gains from parallel
execution.

After some discussion, the decision was to instead cut down on the
number of jobs that are running further. The easiest way to do this is
to not run clippy *without* default features. The list of default
features is empty for all crates, and I haven't found anything using
`cfg(feature = "default")` either, so this is likely not going to change
anything except speeding the runs up.

## Summary of changes

Reduce the amount of feature combinations tried by `cargo hack` (as
suggested in
https://github.com/neondatabase/neon/pull/8912#pullrequestreview-2286482368)
by never disabling default features.

## Alternatives

- We can split things out into different jobs which reduces the time
until everything is finished by running more things in parallel. This
does however decreases the amount of cache hits and increases the amount
of time spent on overhead tasks like repo cloning and restoring caches
by doing those multiple times instead of once.
- We could replace `cargo hack [...] clippy` with `cargo clippy [...];
cargo clippy --features testing`. I'm not 100% sure how this compares to
the change here in the PR, but it does seem to run a bit faster. That
likely means it's doing less work, but without understanding what
exactly we loose by that I'd rather not do that for now. I'd appreciate
input on this though.
2025-01-02 11:33:42 +00:00
Konstantin Knizhnik
b3cd883f93 Unlock LFC mutex when LFC cache is disabled (#10235)
## Problem

See https://github.com/neondatabase/neon/issues/10233
`lfc_containsv` returns with holding lock when LFC was disabled.
This bug was introduced in commit  78938d1b59

## Summary of changes

Release lock before return.

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2025-01-02 11:28:15 +00:00
Conrad Ludgate
38c7a2abfc chore(proxy): pre-load native tls certificates and propagate compute client config (#10182)
Now that we construct the TLS client config for cancellation as well as
connect, it feels appropriate to construct the same config once and
re-use it elsewhere. It might also help should #7500 require any extra
setup, so we can easily add it to all the appropriate call sites.
2025-01-02 09:36:13 +00:00
Conrad Ludgate
f94248a594 chore(libs/proxy): refactor tokio-postgres connection control flow (#10247)
In #10207 it was clear there was some confusion with the current
connection logic. To analyse the flow to make sure there was no poll
stalling, I ended up with the following refactor.

Notable changes:
1. Now all functions called `poll_xyz` and that have a `cx: &mut
Context` argument must return a `Poll<_>` type, and can only return
`Pending` iff an internal poll call also returned `Pending`
2. State management is handled entirely by `poll_messages`. There are
now only 2 states which makes it much easier to keep track of.

Each commit should be self-reviewable and should be simple to verify
that it keeps the same behaviour
2025-01-02 09:35:28 +00:00
34 changed files with 668 additions and 592 deletions

View File

@@ -212,7 +212,7 @@ jobs:
fi
echo "CLIPPY_COMMON_ARGS=${CLIPPY_COMMON_ARGS}" >> $GITHUB_ENV
- name: Run cargo clippy (debug)
run: cargo hack --feature-powerset clippy $CLIPPY_COMMON_ARGS
run: cargo hack --features default --ignore-unknown-features --feature-powerset clippy $CLIPPY_COMMON_ARGS
- name: Check documentation generation
run: cargo doc --workspace --no-deps --document-private-items

View File

@@ -1285,7 +1285,7 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) \
#########################################################################################
#
# Compile and run the Neon-specific `compute_ctl` and `fast_import` binaries
# Compile the Neon-specific `compute_ctl`, `fast_import`, and `local_proxy` binaries
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS compute-tools
@@ -1295,7 +1295,7 @@ ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN cd compute_tools && mold -run cargo build --locked --profile release-line-debug-size-lto
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin compute_ctl --bin fast_import --bin local_proxy
#########################################################################################
#
@@ -1338,20 +1338,6 @@ RUN set -e \
&& make -j $(nproc) dist_man_MANS= \
&& make install dist_man_MANS=
#########################################################################################
#
# Compile the Neon-specific `local_proxy` binary
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS local_proxy
ARG BUILD_TAG
ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin local_proxy
#########################################################################################
#
# Layers "postgres-exporter" and "sql-exporter"
@@ -1491,7 +1477,7 @@ COPY --from=pgbouncer /usr/local/pgbouncer/bin/pgbouncer /usr/local/bin/
COPY --chmod=0666 --chown=postgres compute/etc/pgbouncer.ini /etc/pgbouncer.ini
# local_proxy and its config
COPY --from=local_proxy --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
RUN mkdir -p /etc/local_proxy && chown postgres:postgres /etc/local_proxy
# Metrics exporter binaries and configuration files

View File

@@ -1,11 +1,9 @@
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, RawConnection};
use postgres_protocol2::message::backend::Message;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
@@ -43,7 +41,7 @@ where
let RawConnection {
stream,
parameters,
delayed_notice,
delayed_notice: _,
process_id,
secret_key,
} = connect_raw(socket, tls, config).await?;
@@ -63,13 +61,7 @@ where
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, receiver);
let connection = Connection::new(stream, parameters, receiver);
Ok((client, connection))
}

View File

@@ -1,11 +1,10 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, Stream};
use log::{info, trace};
use log::trace;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::collections::{HashMap, VecDeque};
@@ -33,10 +32,14 @@ pub struct Response {
#[derive(PartialEq, Debug)]
enum State {
Active,
Terminating,
Closing,
}
enum WriteReady {
Terminating,
WaitingOnRead,
}
/// A connection to a PostgreSQL database.
///
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
@@ -51,8 +54,7 @@ pub struct Connection<S, T> {
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_responses: VecDeque<BackendMessage>,
pending_responses: Option<BackendMessage>,
responses: VecDeque<Response>,
state: State,
}
@@ -64,7 +66,6 @@ where
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S, T> {
@@ -72,8 +73,7 @@ where
stream,
parameters,
receiver,
pending_request: None,
pending_responses,
pending_responses: None,
responses: VecDeque::new(),
state: State::Active,
}
@@ -83,7 +83,7 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_responses.pop_front() {
if let Some(message) = self.pending_responses.take() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
@@ -93,26 +93,22 @@ where
.map(|o| o.map(|r| r.map_err(Error::io)))
}
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(None);
}
/// Read and process messages from the connection to postgres.
/// client <- postgres
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Err(Error::closed()),
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Ok(None);
return Poll::Pending;
}
};
let (mut messages, request_complete) = match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Ok(Some(AsyncMessage::Notice(error)));
BackendMessage::Async(Message::NoticeResponse(_)) => {
continue;
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
@@ -120,7 +116,7 @@ where
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
};
return Ok(Some(AsyncMessage::Notification(notification)));
return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
@@ -139,8 +135,10 @@ where
let mut response = match self.responses.pop_front() {
Some(response) => response,
None => match messages.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
_ => return Err(Error::unexpected_message()),
Some(Message::ErrorResponse(error)) => {
return Poll::Ready(Err(Error::db(error)))
}
_ => return Poll::Ready(Err(Error::unexpected_message())),
},
};
@@ -159,23 +157,19 @@ where
}
Poll::Pending => {
self.responses.push_front(response);
self.pending_responses.push_back(BackendMessage::Normal {
self.pending_responses = Some(BackendMessage::Normal {
messages,
request_complete,
});
trace!("poll_read: waiting on sender");
return Ok(None);
return Poll::Pending;
}
}
}
}
/// Fetch the next client request and enqueue the response sender.
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
if let Some(messages) = self.pending_request.take() {
trace!("retrying pending request");
return Poll::Ready(Some(messages));
}
if self.receiver.is_closed() {
return Poll::Ready(None);
}
@@ -193,74 +187,80 @@ where
}
}
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
/// Process client requests and write them to the postgres connection, flushing if necessary.
/// client -> postgres
fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<WriteReady, Error>> {
loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(false);
}
if Pin::new(&mut self.stream)
.poll_ready(cx)
.map_err(Error::io)?
.is_pending()
{
trace!("poll_write: waiting on socket");
return Ok(false);
// poll_ready is self-flushing.
return Poll::Pending;
}
let request = match self.poll_request(cx) {
Poll::Ready(Some(request)) => request,
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
match self.poll_request(cx) {
// send the message to postgres
Poll::Ready(Some(RequestMessages::Single(request))) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
}
// No more messages from the client, and no more responses to wait for.
// Send a terminate message to postgres
Poll::Ready(None) if self.responses.is_empty() => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating;
let mut request = BytesMut::new();
frontend::terminate(&mut request);
RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
let request = FrontendMessage::Raw(request.freeze());
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
trace!("poll_write: sent eof, closing");
trace!("poll_write: done");
return Poll::Ready(Ok(WriteReady::Terminating));
}
// No more messages from the client, but there are still some responses to wait for.
Poll::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len()
);
return Ok(true);
ready!(self.poll_flush(cx))?;
return Poll::Ready(Ok(WriteReady::WaitingOnRead));
}
// Still waiting for a message from the client.
Poll::Pending => {
trace!("poll_write: waiting on request");
return Ok(true);
}
};
match request {
RequestMessages::Single(request) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
ready!(self.poll_flush(cx))?;
return Poll::Pending;
}
}
}
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match Pin::new(&mut self.stream)
.poll_flush(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => trace!("poll_flush: flushed"),
Poll::Pending => trace!("poll_flush: waiting on socket"),
Poll::Ready(()) => {
trace!("poll_flush: flushed");
Poll::Ready(Ok(()))
}
Poll::Pending => {
trace!("poll_flush: waiting on socket");
Poll::Pending
}
}
Ok(())
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.state != State::Closing {
return Poll::Pending;
}
match Pin::new(&mut self.stream)
.poll_close(cx)
.map_err(Error::io)?
@@ -289,18 +289,30 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
let message = self.poll_read(cx)?;
let want_flush = self.poll_write(cx)?;
if want_flush {
self.poll_flush(cx)?;
if self.state != State::Closing {
// if the state is still active, try read from and write to postgres.
let message = self.poll_read(cx)?;
let closing = self.poll_write(cx)?;
if let Poll::Ready(WriteReady::Terminating) = closing {
self.state = State::Closing;
}
if let Poll::Ready(message) = message {
return Poll::Ready(Some(Ok(message)));
}
// poll_read returned Pending.
// poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
// if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
if self.state != State::Closing {
return Poll::Pending;
}
}
match message {
Some(message) => Poll::Ready(Some(Ok(message))),
None => match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
},
match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
}
}
}
@@ -313,11 +325,7 @@ where
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
while let Some(message) = ready!(self.poll_message(cx)?) {
if let AsyncMessage::Notice(notice) = message {
info!("{}: {}", notice.severity(), notice.message());
}
}
while ready!(self.poll_message(cx)?).is_some() {}
Poll::Ready(Ok(()))
}
}

View File

@@ -6,7 +6,6 @@ pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
@@ -100,10 +99,6 @@ impl Notification {
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AsyncMessage {
/// A notice.
///
/// Notices use the same format as errors, but aren't "errors" per-se.
Notice(DbError),
/// A notification.
///
/// Connections can subscribe to notifications with the `LISTEN` command.

View File

@@ -541,6 +541,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
else
{
LWLockRelease(lfc_lock);
return found;
}

View File

@@ -10,7 +10,6 @@ use tracing::info;
use super::backend::ComputeCredentialKeys;
use super::{AuthError, PasswordHackPayload};
use crate::config::TlsServerEndPoint;
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
@@ -18,6 +17,7 @@ use crate::sasl;
use crate::scram::threadpool::ThreadPool;
use crate::scram::{self};
use crate::stream::{PqStream, Stream};
use crate::tls::TlsServerEndPoint;
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {

View File

@@ -13,7 +13,9 @@ use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
use proxy::auth::{self};
use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig};
use proxy::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
};
use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use proxy::http::health_server::AppMetrics;
@@ -25,6 +27,7 @@ use proxy::rate_limiter::{
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::{self, GlobalConnPoolOptions};
use proxy::tls::client_config::compute_client_config_with_root_certs;
use proxy::types::RoleName;
use proxy::url::ApiUrl;
@@ -209,6 +212,7 @@ async fn main() -> anyhow::Result<()> {
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
&config.connect_to_compute,
Arc::new(DashMap::new()),
None,
proxy::metrics::CancellationSource::Local,
@@ -268,6 +272,12 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let compute_config = ComputeConfig {
retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
tls: Arc::new(compute_client_config_with_root_certs()?),
timeout: Duration::from_secs(2),
};
Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
metric_collection: None,
@@ -289,9 +299,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute_retry_config: RetryConfig::parse(
RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES,
)?,
connect_to_compute: compute_config,
})))
}

View File

@@ -10,12 +10,12 @@ use clap::Arg;
use futures::future::Either;
use futures::TryFutureExt;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestContext;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use proxy::stream::{PqStream, Stream};
use proxy::tls::TlsServerEndPoint;
use rustls::crypto::ring;
use rustls::pki_types::PrivateKeyDer;
use tokio::io::{AsyncRead, AsyncWrite};

View File

@@ -1,6 +1,7 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use anyhow::bail;
use futures::future::Either;
@@ -8,7 +9,7 @@ use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use proxy::cancellation::{CancelMap, CancellationHandler};
use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig,
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
};
use proxy::context::parquet::ParquetUploadArgs;
@@ -23,6 +24,7 @@ use proxy::redis::{elasticache, notifications};
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::tls::client_config::compute_client_config_with_root_certs;
use proxy::{auth, control_plane, http, serverless, usage_metrics};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
@@ -397,6 +399,7 @@ async fn main() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<Mutex<RedisPublisherClient>>>,
>::new(
&config.connect_to_compute,
cancel_map.clone(),
redis_publisher,
proxy::metrics::CancellationSource::FromClient,
@@ -492,6 +495,7 @@ async fn main() -> anyhow::Result<()> {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
config,
client,
cache.clone(),
cancel_map.clone(),
@@ -500,6 +504,7 @@ async fn main() -> anyhow::Result<()> {
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
config,
client,
cache.clone(),
cancel_map.clone(),
@@ -632,6 +637,12 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
console_redirect_confirmation_timeout: args.webauth_confirmation_timeout,
};
let compute_config = ComputeConfig {
retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?,
tls: Arc::new(compute_client_config_with_root_certs()?),
timeout: Duration::from_secs(2),
};
let config = ProxyConfig {
tls_config,
metric_collection,
@@ -642,9 +653,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute_retry_config: config::RetryConfig::parse(
&args.connect_to_compute_retry,
)?,
connect_to_compute: compute_config,
};
let config = Box::leak(Box::new(config));

View File

@@ -3,11 +3,9 @@ use std::sync::Arc;
use dashmap::DashMap;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use once_cell::sync::OnceCell;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::CancelToken;
use pq_proto::CancelKeyData;
use rustls::crypto::ring;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
@@ -15,15 +13,15 @@ use tracing::{debug, info};
use uuid::Uuid;
use crate::auth::{check_peer_addr_is_in_list, IpPattern};
use crate::compute::load_certs;
use crate::config::ComputeConfig;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
use crate::postgres_rustls::MakeRustlsConnect;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
};
use crate::tls::postgres_rustls::MakeRustlsConnect;
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
@@ -35,6 +33,7 @@ type IpSubnetKey = IpNet;
///
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler<P> {
compute_config: &'static ComputeConfig,
map: CancelMap,
client: P,
/// This field used for the monitoring purposes.
@@ -183,7 +182,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
"cancelling query per user's request using key {key}, hostname {}, address: {}",
cancel_closure.hostname, cancel_closure.socket_addr
);
cancel_closure.try_cancel_query().await
cancel_closure.try_cancel_query(self.compute_config).await
}
#[cfg(test)]
@@ -198,8 +197,13 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}
impl CancellationHandler<()> {
pub fn new(map: CancelMap, from: CancellationSource) -> Self {
pub fn new(
compute_config: &'static ComputeConfig,
map: CancelMap,
from: CancellationSource,
) -> Self {
Self {
compute_config,
map,
client: (),
from,
@@ -214,8 +218,14 @@ impl CancellationHandler<()> {
}
impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
pub fn new(
compute_config: &'static ComputeConfig,
map: CancelMap,
client: Option<Arc<Mutex<P>>>,
from: CancellationSource,
) -> Self {
Self {
compute_config,
map,
client,
from,
@@ -229,8 +239,6 @@ impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
}
}
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
@@ -257,27 +265,14 @@ impl CancelClosure {
}
}
/// Cancels the query running on user's compute node.
pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
pub(crate) async fn try_cancel_query(
self,
compute_config: &ComputeConfig,
) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let root_store = TLS_ROOTS
.get_or_try_init(load_certs)
.map_err(|_e| {
CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
"TLS root store initialization failed".to_string(),
))
})?
.clone();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
let mut mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
&self.hostname,
@@ -329,11 +324,30 @@ impl<P> Drop for Session<P> {
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
use std::time::Duration;
use super::*;
use crate::config::RetryConfig;
use crate::tls::client_config::compute_client_config_with_certs;
fn config() -> ComputeConfig {
let retry = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
ComputeConfig {
retry,
tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
timeout: Duration::from_secs(2),
}
}
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
Box::leak(Box::new(config())),
CancelMap::default(),
CancellationSource::FromRedis,
));
@@ -349,8 +363,11 @@ mod tests {
#[tokio::test]
async fn cancel_session_noop_regression() {
let handler =
CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
let handler = CancellationHandler::<()>::new(
Box::leak(Box::new(config())),
CancelMap::default(),
CancellationSource::Local,
);
handler
.cancel_session(
CancelKeyData {

View File

@@ -1,16 +1,13 @@
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use pq_proto::StartupMessageParams;
use rustls::crypto::ring;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::TcpStream;
@@ -18,14 +15,15 @@ use tracing::{debug, error, info, warn};
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::postgres_rustls::MakeRustlsConnect;
use crate::proxy::neon_option;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::types::Host;
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -40,9 +38,6 @@ pub(crate) enum ConnectionError {
#[error("{COULD_NOT_CONNECT}: {0}")]
CouldNotConnect(#[from] io::Error),
#[error("Couldn't load native TLS certificates: {0:?}")]
TlsCertificateError(Vec<rustls_native_certs::Error>),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] InvalidDnsNameError),
@@ -89,7 +84,6 @@ impl ReportableError for ConnectionError {
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
@@ -199,11 +193,15 @@ impl ConnCfg {
let connect_once = |host, port| {
debug!("trying to connect to compute node at {host}:{port}");
connect_with_timeout(host, port).and_then(|socket| async {
let socket_addr = socket.peer_addr()?;
connect_with_timeout(host, port).and_then(|stream| async {
let socket_addr = stream.peer_addr()?;
let socket = socket2::SockRef::from(&stream);
// Disable Nagle's algorithm to not introduce latency between
// client and compute.
socket.set_nodelay(true)?;
// This prevents load balancer from severing the connection.
socket2::SockRef::from(&socket).set_keepalive(true)?;
Ok((socket_addr, socket))
socket.set_keepalive(true)?;
Ok((socket_addr, stream))
})
};
@@ -251,25 +249,13 @@ impl ConnCfg {
&self,
ctx: &RequestContext,
aux: MetricsAuxInfo,
timeout: Duration,
config: &ComputeConfig,
) -> Result<PostgresConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
drop(pause);
let root_store = TLS_ROOTS
.get_or_try_init(load_certs)
.map_err(ConnectionError::TlsCertificateError)?
.clone();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
host,
@@ -341,19 +327,6 @@ fn filtered_options(options: &str) -> Option<String> {
Some(options)
}
pub(crate) fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
let der_certs = rustls_native_certs::load_native_certs();
if !der_certs.errors.is_empty() {
return Err(der_certs.errors);
}
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(der_certs.certs);
Ok(Arc::new(store))
}
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,17 +1,10 @@
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{bail, ensure, Context, Ok};
use clap::ValueEnum;
use itertools::Itertools;
use remote_storage::RemoteStorageConfig;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use sha2::{Digest, Sha256};
use tracing::{error, info};
use x509_parser::oid_registry;
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::AuthRateLimiter;
@@ -20,6 +13,7 @@ use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::GlobalConnPoolOptions;
pub use crate::tls::server_config::{configure_tls, TlsConfig};
use crate::types::Host;
pub struct ProxyConfig {
@@ -32,7 +26,13 @@ pub struct ProxyConfig {
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
pub connect_to_compute_retry_config: RetryConfig,
pub connect_to_compute: ComputeConfig,
}
pub struct ComputeConfig {
pub retry: RetryConfig,
pub tls: Arc<rustls::ClientConfig>,
pub timeout: Duration,
}
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]
@@ -52,12 +52,6 @@ pub struct MetricCollectionConfig {
pub backup_metric_collection_config: MetricBackupCollectionConfig,
}
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: HashSet<String>,
pub cert_resolver: Arc<CertResolver>,
}
pub struct HttpConfig {
pub accept_websockets: bool,
pub pool_options: GlobalConnPoolOptions,
@@ -80,272 +74,6 @@ pub struct AuthenticationConfig {
pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
impl TlsConfig {
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
self.config.clone()
}
}
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Configure TLS for the main endpoint.
pub fn configure_tls(
key_path: &str,
cert_path: &str,
certs_dir: Option<&String>,
allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
let mut cert_resolver = CertResolver::new();
// add default certificate
cert_resolver.add_cert_path(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
for entry in std::fs::read_dir(certs_dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
// file names aligned with default cert-manager names
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
)?;
}
}
}
}
let common_names = cert_resolver.get_common_names();
let cert_resolver = Arc::new(cert_resolver);
// allow TLS 1.2 to be compatible with older client libraries
let mut config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
if allow_tls_keylogfile {
// KeyLogFile will check for the SSLKEYLOGFILE environment variable.
config.key_log = Arc::new(rustls::KeyLogFile::new());
}
Ok(TlsConfig {
config: Arc::new(config),
common_names,
cert_resolver,
})
}
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
pub fn new() -> Self {
Self::default()
}
fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
is_default: bool,
) -> anyhow::Result<()> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};
self.add_cert(priv_key, cert_chain, is_default)
}
pub fn add_cert(
&mut self,
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(first_cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
let common_name = pem.subject().to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some((cert.clone(), tls_server_end_point));
}
self.certs.insert(common_name, (cert, tls_server_end_point));
Ok(())
}
pub fn get_common_names(&self) -> HashSet<String> {
self.certs.keys().map(|s| s.to_string()).collect()
}
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
//
// With the current coding foo.com will match *.foo.com and that
// repeats behavior of the old code.
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return Some(cert.clone());
}
if let Some((_, rest)) = sni_name.split_once('.') {
sni_name = rest;
} else {
return None;
}
}
} else {
// No SNI, use the default certificate, otherwise we can't get to
// options parameter which can be used to set endpoint name too.
// That means that non-SNI flow will not work for CNAME domains in
// verify-full mode.
//
// If that will be a problem we can:
//
// a) Instead of multi-cert approach use single cert with extra
// domains listed in Subject Alternative Name (SAN).
// b) Deploy separate proxy instances for extra domains.
self.default.clone()
}
}
}
#[derive(Debug)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.

View File

@@ -115,7 +115,7 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass().await {
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
@@ -216,7 +216,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
},
&user_info,
config.wake_compute_retry_config,
config.connect_to_compute_retry_config,
&config.connect_to_compute,
)
.or_else(|e| stream.throw_error(e))
.await?;

View File

@@ -10,13 +10,13 @@ pub mod client;
pub(crate) mod errors;
use std::sync::Arc;
use std::time::Duration;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::IpPattern;
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::intern::ProjectIdInt;
@@ -73,9 +73,9 @@ impl NodeInfo {
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
timeout: Duration,
config: &ComputeConfig,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config.connect(ctx, self.aux.clone(), timeout).await
self.config.connect(ctx, self.aux.clone(), config).await
}
pub(crate) fn reuse_settings(&mut self, other: Self) {

View File

@@ -89,7 +89,6 @@ pub mod jemalloc;
pub mod logging;
pub mod metrics;
pub mod parse;
pub mod postgres_rustls;
pub mod protocol2;
pub mod proxy;
pub mod rate_limiter;
@@ -99,6 +98,7 @@ pub mod scram;
pub mod serverless;
pub mod signals;
pub mod stream;
pub mod tls;
pub mod types;
pub mod url;
pub mod usage_metrics;

View File

@@ -6,7 +6,7 @@ use tracing::{debug, info, warn};
use super::retry::ShouldRetryWakeCompute;
use crate::auth::backend::ComputeCredentialKeys;
use crate::compute::{self, PostgresConnection, COULD_NOT_CONNECT};
use crate::config::RetryConfig;
use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
@@ -19,8 +19,6 @@ use crate::proxy::retry::{retry_after, should_retry, CouldRetry};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
@@ -49,7 +47,7 @@ pub(crate) trait ConnectMechanism {
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
timeout: time::Duration,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError>;
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
@@ -86,11 +84,11 @@ impl ConnectMechanism for TcpMechanism<'_> {
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
timeout: time::Duration,
config: &ComputeConfig,
) -> Result<PostgresConnection, Self::Error> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
permit.release_result(node_info.connect(ctx, timeout).await)
permit.release_result(node_info.connect(ctx, config).await)
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
@@ -105,7 +103,7 @@ pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBac
mechanism: &M,
user_info: &B,
wake_compute_retry_config: RetryConfig,
connect_to_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
@@ -119,10 +117,7 @@ where
mechanism.update_connect_config(&mut node_info.config);
// try once
let err = match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
.await
{
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(
@@ -142,7 +137,7 @@ where
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
// Do not need to retrieve a new node_info, just return the old one.
if should_retry(&err, num_retries, connect_to_compute_retry_config) {
if should_retry(&err, num_retries, compute.retry) {
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Failed,
@@ -172,10 +167,7 @@ where
debug!("wake_compute success. attempting to connect");
num_retries = 1;
loop {
match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
.await
{
match mechanism.connect_once(ctx, &node_info, compute).await {
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(
@@ -190,7 +182,7 @@ where
return Ok(res);
}
Err(e) => {
if !should_retry(&e, num_retries, connect_to_compute_retry_config) {
if !should_retry(&e, num_retries, compute.retry) {
// Don't log an error here, caller will print the error
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
@@ -206,7 +198,7 @@ where
}
};
let wait_duration = retry_after(num_retries, connect_to_compute_retry_config);
let wait_duration = retry_after(num_retries, compute.retry);
num_retries += 1;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);

View File

@@ -8,12 +8,13 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use crate::auth::endpoint_sni;
use crate::config::{TlsConfig, PG_ALPN_PROTOCOL};
use crate::config::TlsConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::ERR_INSECURE_CONNECTION;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;
#[derive(Error, Debug)]
pub(crate) enum HandshakeError {

View File

@@ -152,7 +152,7 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass().await {
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
@@ -351,7 +351,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
},
&user_info,
config.wake_compute_retry_config,
config.connect_to_compute_retry_config,
&config.connect_to_compute,
)
.or_else(|e| stream.throw_error(e))
.await?;

View File

@@ -5,6 +5,7 @@ use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::stream::Stream;
@@ -67,9 +68,17 @@ pub(crate) struct ProxyPassthrough<P, S> {
}
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
pub(crate) async fn proxy_pass(
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
if let Err(err) = self
.compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
}
res

View File

@@ -22,14 +22,16 @@ use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{CertResolver, RetryConfig};
use crate::config::{ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{
self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache,
};
use crate::error::ErrorKind;
use crate::postgres_rustls::MakeRustlsConnect;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};
@@ -67,7 +69,7 @@ fn generate_certs(
}
struct ClientConfig<'a> {
config: rustls::ClientConfig,
config: Arc<rustls::ClientConfig>,
hostname: &'a str,
}
@@ -110,16 +112,7 @@ fn generate_tls_config<'a>(
};
let client_config = {
let config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.context("ring should support the default protocol versions")?
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(ca)?;
store
})
.with_no_client_auth();
let config = Arc::new(compute_client_config_with_certs([ca]));
ClientConfig { config, hostname }
};
@@ -468,7 +461,7 @@ impl ConnectMechanism for TestConnectMechanism {
&self,
_ctx: &RequestContext,
_node_info: &control_plane::CachedNodeInfo,
_timeout: std::time::Duration,
_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let mut counter = self.counter.lock().unwrap();
let action = self.sequence[*counter];
@@ -576,6 +569,20 @@ fn helper_create_connect_info(
user_info
}
fn config() -> ComputeConfig {
let retry = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
ComputeConfig {
retry,
tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
timeout: Duration::from_secs(2),
}
}
#[tokio::test]
async fn connect_to_compute_success() {
let _ = env_logger::try_init();
@@ -583,12 +590,8 @@ async fn connect_to_compute_success() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -601,12 +604,8 @@ async fn connect_to_compute_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -620,12 +619,8 @@ async fn connect_to_compute_non_retry_1() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -639,12 +634,8 @@ async fn connect_to_compute_non_retry_2() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -665,17 +656,13 @@ async fn connect_to_compute_non_retry_3() {
max_retries: 1,
backoff_factor: 2.0,
};
let connect_to_compute_retry_config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
let config = config();
connect_to_compute(
&ctx,
&mechanism,
&user_info,
wake_compute_retry_config,
connect_to_compute_retry_config,
&config,
)
.await
.unwrap_err();
@@ -690,12 +677,8 @@ async fn wake_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -709,12 +692,8 @@ async fn wake_non_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();

View File

@@ -12,6 +12,7 @@ use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::cancellation::{CancelMap, CancellationHandler};
use crate::config::ProxyConfig;
use crate::intern::{ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
@@ -39,6 +40,27 @@ pub(crate) enum Notification {
AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate,
},
#[serde(
rename = "/block_public_or_vpc_access_updated",
deserialize_with = "deserialize_json_string"
)]
BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_org",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_projects",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
},
#[serde(
rename = "/password_updated",
deserialize_with = "deserialize_json_string"
@@ -51,6 +73,24 @@ pub(crate) enum Notification {
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdated {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
// TODO: change type once the implementation is more fully fledged.
// See e.g. https://github.com/neondatabase/neon/pull/10073.
account_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
project_ids: Vec<ProjectIdInt>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
project_id: ProjectIdInt,
@@ -164,7 +204,11 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
}
}
Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
Notification::AllowedIpsUpdate { .. }
| Notification::PasswordUpdate { .. }
| Notification::BlockPublicOrVpcAccessUpdated { .. }
| Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
| Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
invalidate_cache(self.cache.clone(), msg.clone());
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
@@ -177,6 +221,8 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
.redis_events_count
.inc(RedisEventsCount::PasswordUpdate);
}
// TODO: add additional metrics for the other event types.
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
@@ -203,6 +249,15 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
password_update.role_name,
),
Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
Notification::BlockPublicOrVpcAccessUpdated { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
}
}
@@ -249,6 +304,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
/// Handle console's invalidation messages.
#[tracing::instrument(name = "redis_notifications", skip_all)]
pub async fn task_main<C>(
config: &'static ProxyConfig,
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
cancel_map: CancelMap,
@@ -258,6 +314,7 @@ where
C: ProjectInfoCache + Send + Sync + 'static,
{
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
&config.connect_to_compute,
cancel_map,
crate::metrics::CancellationSource::FromRedis,
));

View File

@@ -13,7 +13,6 @@ use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use super::ScramKey;
use crate::config;
use crate::intern::EndpointIdInt;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
@@ -59,14 +58,14 @@ enum ExchangeState {
pub(crate) struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
tls_server_end_point: config::TlsServerEndPoint,
tls_server_end_point: crate::tls::TlsServerEndPoint,
}
impl<'a> Exchange<'a> {
pub(crate) fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
tls_server_end_point: config::TlsServerEndPoint,
tls_server_end_point: crate::tls::TlsServerEndPoint,
) -> Self {
Self {
state: ExchangeState::Initial(SaslInitial { nonce }),
@@ -120,7 +119,7 @@ impl SaslInitial {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &config::TlsServerEndPoint,
tls_server_end_point: &crate::tls::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
let client_first_message = ClientFirstMessage::parse(input)
@@ -155,7 +154,7 @@ impl SaslSentInner {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &config::TlsServerEndPoint,
tls_server_end_point: &crate::tls::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
let Self {
@@ -168,8 +167,8 @@ impl SaslSentInner {
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
config::TlsServerEndPoint::Sha256(x) => Ok(x),
config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
crate::tls::TlsServerEndPoint::Sha256(x) => Ok(x),
crate::tls::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
})?;
// This might've been caused by a MITM attack

View File

@@ -77,11 +77,8 @@ mod tests {
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
let mut exchange = Exchange::new(
&secret,
|| NONCE,
crate::config::TlsServerEndPoint::Undefined,
);
let mut exchange =
Exchange::new(&secret, || NONCE, crate::tls::TlsServerEndPoint::Undefined);
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";

View File

@@ -22,7 +22,7 @@ use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
use crate::config::ProxyConfig;
use crate::config::{ComputeConfig, ProxyConfig};
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
@@ -196,7 +196,7 @@ impl PoolingBackend {
},
&backend,
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
&self.config.connect_to_compute,
)
.await
}
@@ -237,7 +237,7 @@ impl PoolingBackend {
},
&backend,
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
&self.config.connect_to_compute,
)
.await
}
@@ -502,7 +502,7 @@ impl ConnectMechanism for TokioMechanism {
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
timeout: Duration,
compute_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
@@ -511,7 +511,7 @@ impl ConnectMechanism for TokioMechanism {
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(timeout);
.connect_timeout(compute_config.timeout);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(postgres_client::NoTls).await;
@@ -552,7 +552,7 @@ impl ConnectMechanism for HyperMechanism {
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
timeout: Duration,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
@@ -560,7 +560,7 @@ impl ConnectMechanism for HyperMechanism {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let port = node_info.config.get_port();
let res = connect_http2(&host, port, timeout).await;
let res = connect_http2(&host, port, config.timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;

View File

@@ -124,9 +124,6 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}

View File

@@ -227,9 +227,6 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}

View File

@@ -168,7 +168,7 @@ pub(crate) async fn serve_websocket(
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass().await {
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),

View File

@@ -11,9 +11,9 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tracing::debug;
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::tls::TlsServerEndPoint;
/// Stream wrapper which implements libpq's protocol.
///

View File

@@ -0,0 +1,42 @@
use std::sync::Arc;
use anyhow::bail;
use rustls::crypto::ring;
pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
let der_certs = rustls_native_certs::load_native_certs();
if !der_certs.errors.is_empty() {
bail!("could not parse certificates: {:?}", der_certs.errors);
}
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(der_certs.certs);
Ok(Arc::new(store))
}
/// Loads the root certificates and constructs a client config suitable for connecting to the neon compute.
/// This function is blocking.
pub fn compute_client_config_with_root_certs() -> anyhow::Result<rustls::ClientConfig> {
Ok(
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(load_certs()?)
.with_no_client_auth(),
)
}
#[cfg(test)]
pub fn compute_client_config_with_certs(
certs: impl IntoIterator<Item = rustls::pki_types::CertificateDer<'static>>,
) -> rustls::ClientConfig {
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(certs);
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(store)
.with_no_client_auth()
}

72
proxy/src/tls/mod.rs Normal file
View File

@@ -0,0 +1,72 @@
pub mod client_config;
pub mod postgres_rustls;
pub mod server_config;
use anyhow::Context;
use rustls::pki_types::CertificateDer;
use sha2::{Digest, Sha256};
use tracing::{error, info};
use x509_parser::oid_registry;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}

View File

@@ -18,7 +18,7 @@ mod private {
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;
use crate::config::TlsServerEndPoint;
use crate::tls::TlsServerEndPoint;
pub struct TlsConnectFuture<S> {
inner: tokio_rustls::Connect<S>,
@@ -126,16 +126,14 @@ mod private {
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
#[derive(Clone)]
pub struct MakeRustlsConnect {
config: Arc<ClientConfig>,
pub config: Arc<ClientConfig>,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
#[must_use]
pub fn new(config: ClientConfig) -> Self {
Self {
config: Arc::new(config),
}
pub fn new(config: Arc<ClientConfig>) -> Self {
Self { config }
}
}

View File

@@ -0,0 +1,218 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use anyhow::{bail, Context};
use itertools::Itertools;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use super::{TlsServerEndPoint, PG_ALPN_PROTOCOL};
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: HashSet<String>,
pub cert_resolver: Arc<CertResolver>,
}
impl TlsConfig {
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
self.config.clone()
}
}
/// Configure TLS for the main endpoint.
pub fn configure_tls(
key_path: &str,
cert_path: &str,
certs_dir: Option<&String>,
allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
let mut cert_resolver = CertResolver::new();
// add default certificate
cert_resolver.add_cert_path(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
for entry in std::fs::read_dir(certs_dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
// file names aligned with default cert-manager names
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
)?;
}
}
}
}
let common_names = cert_resolver.get_common_names();
let cert_resolver = Arc::new(cert_resolver);
// allow TLS 1.2 to be compatible with older client libraries
let mut config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
if allow_tls_keylogfile {
// KeyLogFile will check for the SSLKEYLOGFILE environment variable.
config.key_log = Arc::new(rustls::KeyLogFile::new());
}
Ok(TlsConfig {
config: Arc::new(config),
common_names,
cert_resolver,
})
}
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
pub fn new() -> Self {
Self::default()
}
fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
is_default: bool,
) -> anyhow::Result<()> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};
self.add_cert(priv_key, cert_chain, is_default)
}
pub fn add_cert(
&mut self,
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(first_cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
let common_name = pem.subject().to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some((cert.clone(), tls_server_end_point));
}
self.certs.insert(common_name, (cert, tls_server_end_point));
Ok(())
}
pub fn get_common_names(&self) -> HashSet<String> {
self.certs.keys().map(|s| s.to_string()).collect()
}
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
//
// With the current coding foo.com will match *.foo.com and that
// repeats behavior of the old code.
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return Some(cert.clone());
}
if let Some((_, rest)) = sni_name.split_once('.') {
sni_name = rest;
} else {
return None;
}
}
} else {
// No SNI, use the default certificate, otherwise we can't get to
// options parameter which can be used to set endpoint name too.
// That means that non-SNI flow will not work for CNAME domains in
// verify-full mode.
//
// If that will be a problem we can:
//
// a) Instead of multi-cert approach use single cert with extra
// domains listed in Subject Alternative Name (SAN).
// b) Deploy separate proxy instances for extra domains.
self.default.clone()
}
}
}

View File

@@ -398,6 +398,7 @@ def test_timeline_archival_chaos(neon_env_builder: NeonEnvBuilder):
# Offloading is off by default at time of writing: remove this line when it's on by default
neon_env_builder.pageserver_config_override = "timeline_offloading = true"
neon_env_builder.storage_controller_config = {"heartbeat_interval": "100msec"}
neon_env_builder.enable_pageserver_remote_storage(s3_storage())
# We will exercise migrations, so need multiple pageservers