Compare commits

..

1 Commits

Author SHA1 Message Date
Em Sharnoff
8b11e3bc9c compute/sql_exporter: Bump max WSS window from 1h -> 3h
Concretely, this:

1. Changes the normal (public) exporter to add a new 3-hour working set
   size label onto the existing 5-minute, 15-minute, and 1-hour values.

2. Extends the range on the autoscaling exporter from 1..60 minutes to
   1..180 minutes -- keeping the same density, just 3x longer.
2024-12-24 17:27:48 -08:00
68 changed files with 690 additions and 1342 deletions

View File

@@ -212,7 +212,7 @@ jobs:
fi
echo "CLIPPY_COMMON_ARGS=${CLIPPY_COMMON_ARGS}" >> $GITHUB_ENV
- name: Run cargo clippy (debug)
run: cargo hack --features default --ignore-unknown-features --feature-powerset clippy $CLIPPY_COMMON_ARGS
run: cargo hack --feature-powerset clippy $CLIPPY_COMMON_ARGS
- name: Check documentation generation
run: cargo doc --workspace --no-deps --document-private-items

9
Cargo.lock generated
View File

@@ -1274,7 +1274,6 @@ dependencies = [
"chrono",
"clap",
"compute_api",
"fail",
"flate2",
"futures",
"hyper 0.14.30",
@@ -1733,9 +1732,9 @@ checksum = "ab03c107fafeb3ee9f5925686dbb7a73bc76e3932abb0d2b365cb64b169cf04c"
[[package]]
name = "diesel"
version = "2.2.6"
version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ccf1bedf64cdb9643204a36dd15b19a6ce8e7aa7f7b105868e9f1fad5ffa7d12"
checksum = "65e13bab2796f412722112327f3e575601a3e9cdcbe426f0d30dbf43f3f5dc71"
dependencies = [
"bitflags 2.4.1",
"byteorder",
@@ -4494,9 +4493,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "pq-sys"
version = "0.6.3"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6cc05d7ea95200187117196eee9edd0644424911821aeb28a18ce60ea0b8793"
checksum = "31c0052426df997c0cbd30789eb44ca097e3541717a7b8fa36b1c464ee7edebd"
dependencies = [
"vcpkg",
]

View File

@@ -1285,7 +1285,7 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) \
#########################################################################################
#
# Compile the Neon-specific `compute_ctl`, `fast_import`, and `local_proxy` binaries
# Compile and run the Neon-specific `compute_ctl` and `fast_import` binaries
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS compute-tools
@@ -1295,7 +1295,7 @@ ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin compute_ctl --bin fast_import --bin local_proxy
RUN cd compute_tools && mold -run cargo build --locked --profile release-line-debug-size-lto
#########################################################################################
#
@@ -1338,6 +1338,20 @@ RUN set -e \
&& make -j $(nproc) dist_man_MANS= \
&& make install dist_man_MANS=
#########################################################################################
#
# Compile the Neon-specific `local_proxy` binary
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS local_proxy
ARG BUILD_TAG
ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin local_proxy
#########################################################################################
#
# Layers "postgres-exporter" and "sql-exporter"
@@ -1477,7 +1491,7 @@ COPY --from=pgbouncer /usr/local/pgbouncer/bin/pgbouncer /usr/local/bin/
COPY --chmod=0666 --chown=postgres compute/etc/pgbouncer.ini /etc/pgbouncer.ini
# local_proxy and its config
COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
COPY --from=local_proxy --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
RUN mkdir -p /etc/local_proxy && chown postgres:postgres /etc/local_proxy
# Metrics exporter binaries and configuration files

View File

@@ -1,8 +1,8 @@
-- NOTE: This is the "internal" / "machine-readable" version. This outputs the
-- working set size looking back 1..60 minutes, labeled with the number of
-- working set size looking back 1..180 minutes, labeled with the number of
-- minutes.
SELECT
x::text as duration_seconds,
neon.approximate_working_set_size_seconds(x) AS size
FROM (SELECT generate_series * 60 AS x FROM generate_series(1, 60)) AS t (x);
FROM (SELECT generate_series * 60 AS x FROM generate_series(1, 180)) AS t (x);

View File

@@ -4,5 +4,5 @@
SELECT
x AS duration,
neon.approximate_working_set_size_seconds(extract('epoch' FROM x::interval)::int) AS size FROM (
VALUES ('5m'), ('15m'), ('1h')
VALUES ('5m'), ('15m'), ('1h'), ('3h')
) AS t (x);

View File

@@ -7,7 +7,7 @@ license.workspace = true
[features]
default = []
# Enables test specific features.
testing = ["fail/failpoints"]
testing = []
[dependencies]
base64.workspace = true
@@ -19,7 +19,6 @@ camino.workspace = true
chrono.workspace = true
cfg-if.workspace = true
clap.workspace = true
fail.workspace = true
flate2.workspace = true
futures.workspace = true
hyper0 = { workspace = true, features = ["full"] }

View File

@@ -67,15 +67,12 @@ use compute_tools::params::*;
use compute_tools::spec::*;
use compute_tools::swap::resize_swap;
use rlimit::{setrlimit, Resource};
use utils::failpoint_support;
// this is an arbitrary build tag. Fine as a default / for testing purposes
// in-case of not-set environment var
const BUILD_TAG_DEFAULT: &str = "latest";
fn main() -> Result<()> {
let scenario = failpoint_support::init();
let (build_tag, clap_args) = init()?;
// enable core dumping for all child processes
@@ -103,8 +100,6 @@ fn main() -> Result<()> {
maybe_delay_exit(delay_exit);
scenario.teardown();
deinit_and_exit(wait_pg_result);
}
@@ -424,13 +419,9 @@ fn start_postgres(
"running compute with features: {:?}",
state.pspec.as_ref().unwrap().spec.features
);
// before we release the mutex, fetch some parameters for later.
let &ComputeSpec {
swap_size_bytes,
disk_quota_bytes,
disable_lfc_resizing,
..
} = &state.pspec.as_ref().unwrap().spec;
// before we release the mutex, fetch the swap size (if any) for later.
let swap_size_bytes = state.pspec.as_ref().unwrap().spec.swap_size_bytes;
let disk_quota_bytes = state.pspec.as_ref().unwrap().spec.disk_quota_bytes;
drop(state);
// Launch remaining service threads
@@ -535,18 +526,11 @@ fn start_postgres(
// This token is used internally by the monitor to clean up all threads
let token = CancellationToken::new();
// don't pass postgres connection string to vm-monitor if we don't want it to resize LFC
let pgconnstr = if disable_lfc_resizing.unwrap_or(false) {
None
} else {
file_cache_connstr.cloned()
};
let vm_monitor = rt.as_ref().map(|rt| {
rt.spawn(vm_monitor::start(
Box::leak(Box::new(vm_monitor::Args {
cgroup: cgroup.cloned(),
pgconnstr,
pgconnstr: file_cache_connstr.cloned(),
addr: vm_monitor_addr.clone(),
})),
token.clone(),

View File

@@ -1181,19 +1181,8 @@ impl ComputeNode {
let mut conf = postgres::config::Config::from(conf);
conf.application_name("compute_ctl:migrations");
match conf.connect(NoTls) {
Ok(mut client) => {
if let Err(e) = handle_migrations(&mut client) {
error!("Failed to run migrations: {}", e);
}
}
Err(e) => {
error!(
"Failed to connect to the compute for running migrations: {}",
e
);
}
};
let mut client = conf.connect(NoTls)?;
handle_migrations(&mut client).context("apply_config handle_migrations")
});
Ok::<(), anyhow::Error>(())

View File

@@ -24,11 +24,8 @@ use metrics::proto::MetricFamily;
use metrics::Encoder;
use metrics::TextEncoder;
use tokio::task;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use tracing_utils::http::OtelName;
use utils::failpoint_support::failpoints_handler;
use utils::http::error::ApiError;
use utils::http::request::must_get_query_param;
fn status_response_from_state(state: &ComputeState) -> ComputeStatusResponse {
@@ -313,18 +310,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
(&Method::POST, "/failpoints") if cfg!(feature = "testing") => {
match failpoints_handler(req, CancellationToken::new()).await {
Ok(r) => r,
Err(ApiError::BadRequest(e)) => {
render_json_error(&e.to_string(), StatusCode::BAD_REQUEST)
}
Err(_) => {
render_json_error("Internal server error", StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
// download extension files from remote extension storage on demand
(&Method::POST, route) if route.starts_with("/extension_server/") => {
info!("serving {:?} POST request", route);

View File

@@ -1,16 +1,13 @@
use anyhow::{Context, Result};
use fail::fail_point;
use postgres::Client;
use tracing::info;
/// Runs a series of migrations on a target database
pub(crate) struct MigrationRunner<'m> {
client: &'m mut Client,
migrations: &'m [&'m str],
}
impl<'m> MigrationRunner<'m> {
/// Create a new migration runner
pub fn new(client: &'m mut Client, migrations: &'m [&'m str]) -> Self {
// The neon_migration.migration_id::id column is a bigint, which is equivalent to an i64
assert!(migrations.len() + 1 < i64::MAX as usize);
@@ -18,7 +15,6 @@ impl<'m> MigrationRunner<'m> {
Self { client, migrations }
}
/// Get the current value neon_migration.migration_id
fn get_migration_id(&mut self) -> Result<i64> {
let query = "SELECT id FROM neon_migration.migration_id";
let row = self
@@ -29,61 +25,37 @@ impl<'m> MigrationRunner<'m> {
Ok(row.get::<&str, i64>("id"))
}
/// Update the neon_migration.migration_id value
///
/// This function has a fail point called compute-migration, which can be
/// used if you would like to fail the application of a series of migrations
/// at some point.
fn update_migration_id(&mut self, migration_id: i64) -> Result<()> {
// We use this fail point in order to check that failing in the
// middle of applying a series of migrations fails in an expected
// manner
if cfg!(feature = "testing") {
let fail = (|| {
fail_point!("compute-migration", |fail_migration_id| {
migration_id == fail_migration_id.unwrap().parse::<i64>().unwrap()
});
false
})();
if fail {
return Err(anyhow::anyhow!(format!(
"migration {} was configured to fail because of a failpoint",
migration_id
)));
}
}
let setval = format!("UPDATE neon_migration.migration_id SET id={}", migration_id);
self.client
.query(
"UPDATE neon_migration.migration_id SET id = $1",
&[&migration_id],
)
.simple_query(&setval)
.context("run_migrations update id")?;
Ok(())
}
/// Prepare the migrations the target database for handling migrations
fn prepare_database(&mut self) -> Result<()> {
self.client
.simple_query("CREATE SCHEMA IF NOT EXISTS neon_migration")?;
self.client.simple_query("CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)")?;
self.client.simple_query(
"INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING",
)?;
self.client
.simple_query("ALTER SCHEMA neon_migration OWNER TO cloud_admin")?;
self.client
.simple_query("REVOKE ALL ON SCHEMA neon_migration FROM PUBLIC")?;
fn prepare_migrations(&mut self) -> Result<()> {
let query = "CREATE SCHEMA IF NOT EXISTS neon_migration";
self.client.simple_query(query)?;
let query = "CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)";
self.client.simple_query(query)?;
let query = "INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING";
self.client.simple_query(query)?;
let query = "ALTER SCHEMA neon_migration OWNER TO cloud_admin";
self.client.simple_query(query)?;
let query = "REVOKE ALL ON SCHEMA neon_migration FROM PUBLIC";
self.client.simple_query(query)?;
Ok(())
}
/// Run the configrured set of migrations
pub fn run_migrations(mut self) -> Result<()> {
self.prepare_database()?;
self.prepare_migrations()?;
let mut current_migration = self.get_migration_id()? as usize;
while current_migration < self.migrations.len() {
@@ -97,11 +69,6 @@ impl<'m> MigrationRunner<'m> {
if migration.starts_with("-- SKIP") {
info!("Skipping migration id={}", migration_id!(current_migration));
// Even though we are skipping the migration, updating the
// migration ID should help keep logic easy to understand when
// trying to understand the state of a cluster.
self.update_migration_id(migration_id!(current_migration))?;
} else {
info!(
"Running migration id={}:\n{}\n",
@@ -120,6 +87,7 @@ impl<'m> MigrationRunner<'m> {
)
})?;
// Migration IDs start at 1
self.update_migration_id(migration_id!(current_migration))?;
self.client

View File

@@ -1,9 +0,0 @@
DO $$
DECLARE
bypassrls boolean;
BEGIN
SELECT rolbypassrls INTO bypassrls FROM pg_roles WHERE rolname = 'neon_superuser';
IF NOT bypassrls THEN
RAISE EXCEPTION 'neon_superuser cannot bypass RLS';
END IF;
END $$;

View File

@@ -1,25 +0,0 @@
DO $$
DECLARE
role record;
BEGIN
FOR role IN
SELECT rolname AS name, rolinherit AS inherit
FROM pg_roles
WHERE pg_has_role(rolname, 'neon_superuser', 'member')
LOOP
IF NOT role.inherit THEN
RAISE EXCEPTION '% cannot inherit', quote_ident(role.name);
END IF;
END LOOP;
FOR role IN
SELECT rolname AS name, rolbypassrls AS bypassrls
FROM pg_roles
WHERE NOT pg_has_role(rolname, 'neon_superuser', 'member')
AND NOT starts_with(rolname, 'pg_')
LOOP
IF role.bypassrls THEN
RAISE EXCEPTION '% can bypass RLS', quote_ident(role.name);
END IF;
END LOOP;
END $$;

View File

@@ -1,10 +0,0 @@
DO $$
BEGIN
IF (SELECT current_setting('server_version_num')::numeric < 160000) THEN
RETURN;
END IF;
IF NOT (SELECT pg_has_role('neon_superuser', 'pg_create_subscription', 'member')) THEN
RAISE EXCEPTION 'neon_superuser cannot execute pg_create_subscription';
END IF;
END $$;

View File

@@ -1,19 +0,0 @@
DO $$
DECLARE
monitor record;
BEGIN
SELECT pg_has_role('neon_superuser', 'pg_monitor', 'member') AS member,
admin_option AS admin
INTO monitor
FROM pg_auth_members
WHERE roleid = 'pg_monitor'::regrole
AND member = 'pg_monitor'::regrole;
IF NOT monitor.member THEN
RAISE EXCEPTION 'neon_superuser is not a member of pg_monitor';
END IF;
IF NOT monitor.admin THEN
RAISE EXCEPTION 'neon_superuser cannot grant pg_monitor';
END IF;
END $$;

View File

@@ -1,2 +0,0 @@
-- This test was never written becuase at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written becuase at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written becuase at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written becuase at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written becuase at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,13 +0,0 @@
DO $$
DECLARE
can_execute boolean;
BEGIN
SELECT bool_and(has_function_privilege('neon_superuser', oid, 'execute'))
INTO can_execute
FROM pg_proc
WHERE proname IN ('pg_export_snapshot', 'pg_log_standby_snapshot')
AND pronamespace = 'pg_catalog'::regnamespace;
IF NOT can_execute THEN
RAISE EXCEPTION 'neon_superuser cannot execute both pg_export_snapshot and pg_log_standby_snapshot';
END IF;
END $$;

View File

@@ -1,13 +0,0 @@
DO $$
DECLARE
can_execute boolean;
BEGIN
SELECT has_function_privilege('neon_superuser', oid, 'execute')
INTO can_execute
FROM pg_proc
WHERE proname = 'pg_show_replication_origin_status'
AND pronamespace = 'pg_catalog'::regnamespace;
IF NOT can_execute THEN
RAISE EXCEPTION 'neon_superuser cannot execute pg_show_replication_origin_status';
END IF;
END $$;

View File

@@ -585,7 +585,6 @@ impl Endpoint {
features: self.features.clone(),
swap_size_bytes: None,
disk_quota_bytes: None,
disable_lfc_resizing: None,
cluster: Cluster {
cluster_id: None, // project ID: not used
name: None, // project name: not used

View File

@@ -67,15 +67,6 @@ pub struct ComputeSpec {
#[serde(default)]
pub disk_quota_bytes: Option<u64>,
/// Disables the vm-monitor behavior that resizes LFC on upscale/downscale, instead relying on
/// the initial size of LFC.
///
/// This is intended for use when the LFC size is being overridden from the default but
/// autoscaling is still enabled, and we don't want the vm-monitor to interfere with the custom
/// LFC sizing.
#[serde(default)]
pub disable_lfc_resizing: Option<bool>,
/// Expected cluster state at the end of transition process.
pub cluster: Cluster,
pub delta_operations: Option<Vec<DeltaOp>>,

View File

@@ -33,12 +33,8 @@ pub struct Response {
#[derive(PartialEq, Debug)]
enum State {
Active,
Closing,
}
enum WriteReady {
Terminating,
WaitingOnRead,
Closing,
}
/// A connection to a PostgreSQL database.
@@ -55,6 +51,7 @@ pub struct Connection<S, T> {
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_responses: VecDeque<BackendMessage>,
responses: VecDeque<Response>,
state: State,
@@ -75,6 +72,7 @@ where
stream,
parameters,
receiver,
pending_request: None,
pending_responses,
responses: VecDeque::new(),
state: State::Active,
@@ -95,23 +93,26 @@ where
.map(|o| o.map(|r| r.map_err(Error::io)))
}
/// Read and process messages from the connection to postgres.
/// client <- postgres
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(None);
}
loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Ready(None) => return Err(Error::closed()),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Poll::Pending;
return Ok(None);
}
};
let (mut messages, request_complete) = match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Poll::Ready(Ok(AsyncMessage::Notice(error)));
return Ok(Some(AsyncMessage::Notice(error)));
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
@@ -119,7 +120,7 @@ where
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
};
return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
return Ok(Some(AsyncMessage::Notification(notification)));
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
@@ -138,10 +139,8 @@ where
let mut response = match self.responses.pop_front() {
Some(response) => response,
None => match messages.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(error)) => {
return Poll::Ready(Err(Error::db(error)))
}
_ => return Poll::Ready(Err(Error::unexpected_message())),
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
_ => return Err(Error::unexpected_message()),
},
};
@@ -165,14 +164,18 @@ where
request_complete,
});
trace!("poll_read: waiting on sender");
return Poll::Pending;
return Ok(None);
}
}
}
}
/// Fetch the next client request and enqueue the response sender.
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
if let Some(messages) = self.pending_request.take() {
trace!("retrying pending request");
return Poll::Ready(Some(messages));
}
if self.receiver.is_closed() {
return Poll::Ready(None);
}
@@ -190,80 +193,74 @@ where
}
}
/// Process client requests and write them to the postgres connection, flushing if necessary.
/// client -> postgres
fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<WriteReady, Error>> {
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(false);
}
if Pin::new(&mut self.stream)
.poll_ready(cx)
.map_err(Error::io)?
.is_pending()
{
trace!("poll_write: waiting on socket");
// poll_ready is self-flushing.
return Poll::Pending;
return Ok(false);
}
match self.poll_request(cx) {
// send the message to postgres
Poll::Ready(Some(RequestMessages::Single(request))) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
}
// No more messages from the client, and no more responses to wait for.
// Send a terminate message to postgres
Poll::Ready(None) if self.responses.is_empty() => {
let request = match self.poll_request(cx) {
Poll::Ready(Some(request)) => request,
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating;
let mut request = BytesMut::new();
frontend::terminate(&mut request);
let request = FrontendMessage::Raw(request.freeze());
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
trace!("poll_write: sent eof, closing");
trace!("poll_write: done");
return Poll::Ready(Ok(WriteReady::Terminating));
RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
}
// No more messages from the client, but there are still some responses to wait for.
Poll::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len()
);
ready!(self.poll_flush(cx))?;
return Poll::Ready(Ok(WriteReady::WaitingOnRead));
return Ok(true);
}
// Still waiting for a message from the client.
Poll::Pending => {
trace!("poll_write: waiting on request");
ready!(self.poll_flush(cx))?;
return Poll::Pending;
return Ok(true);
}
};
match request {
RequestMessages::Single(request) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
}
}
}
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
match Pin::new(&mut self.stream)
.poll_flush(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => {
trace!("poll_flush: flushed");
Poll::Ready(Ok(()))
}
Poll::Pending => {
trace!("poll_flush: waiting on socket");
Poll::Pending
}
Poll::Ready(()) => trace!("poll_flush: flushed"),
Poll::Pending => trace!("poll_flush: waiting on socket"),
}
Ok(())
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.state != State::Closing {
return Poll::Pending;
}
match Pin::new(&mut self.stream)
.poll_close(cx)
.map_err(Error::io)?
@@ -292,30 +289,18 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
if self.state != State::Closing {
// if the state is still active, try read from and write to postgres.
let message = self.poll_read(cx)?;
let closing = self.poll_write(cx)?;
if let Poll::Ready(WriteReady::Terminating) = closing {
self.state = State::Closing;
}
if let Poll::Ready(message) = message {
return Poll::Ready(Some(Ok(message)));
}
// poll_read returned Pending.
// poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
// if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
if self.state != State::Closing {
return Poll::Pending;
}
let message = self.poll_read(cx)?;
let want_flush = self.poll_write(cx)?;
if want_flush {
self.poll_flush(cx)?;
}
match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
match message {
Some(message) => Poll::Ready(Some(Ok(message))),
None => match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
},
}
}
}

View File

@@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use tracing::*;
/// Declare a failpoint that can use to `pause` failpoint action.
/// Declare a failpoint that can use the `pause` failpoint action.
/// We don't want to block the executor thread, hence, spawn_blocking + await.
#[macro_export]
macro_rules! pausable_failpoint {
@@ -181,7 +181,7 @@ pub async fn failpoints_handler(
) -> Result<Response<Body>, ApiError> {
if !fail::has_failpoints() {
return Err(ApiError::BadRequest(anyhow::anyhow!(
"Cannot manage failpoints because neon was compiled without failpoints support"
"Cannot manage failpoints because storage was compiled without failpoints support"
)));
}

View File

@@ -541,7 +541,6 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
else
{
LWLockRelease(lfc_lock);
return found;
}

View File

@@ -102,39 +102,23 @@ User can pass several optional headers that will affect resulting json.
2. `Neon-Array-Mode: true`. Return postgres rows as arrays instead of objects. That is more compact representation and also helps in some edge
cases where it is hard to use rows represented as objects (e.g. when several fields have the same name).
## Test proxy locally
Proxy determines project name from the subdomain, request to the `round-rice-566201.somedomain.tld` will be routed to the project named `round-rice-566201`. Unfortunately, `/etc/hosts` does not support domain wildcards, so we can use *.localtest.me` which resolves to `127.0.0.1`.
## Using SNI-based routing on localhost
Now proxy determines project name from the subdomain, request to the `round-rice-566201.somedomain.tld` will be routed to the project named `round-rice-566201`. Unfortunately, `/etc/hosts` does not support domain wildcards, so I usually use `*.localtest.me` which resolves to `127.0.0.1`. Now we can create self-signed certificate and play with proxy:
Let's create self-signed certificate by running:
```sh
openssl req -new -x509 -days 365 -nodes -text -out server.crt -keyout server.key -subj "/CN=*.localtest.me"
```
Then we need to build proxy with 'testing' feature and run, e.g.:
start proxy
```sh
RUST_LOG=proxy cargo run -p proxy --bin proxy --features testing -- --auth-backend postgres --auth-endpoint 'postgresql://proxy:password@endpoint.localtest.me:5432/postgres' --is-private-access-proxy true -c server.crt -k server.key
./target/debug/proxy -c server.crt -k server.key
```
We will also need to have a postgres instance. Assuming that we have setted up docker we can set it up as follows:
and connect to it
```sh
docker run \
--detach \
--name proxy-postgres \
--env POSTGRES_PASSWORD=proxy-postgres \
--publish 5432:5432 \
postgres:17-bookworm
PGSSLROOTCERT=./server.crt psql 'postgres://my-cluster-42.localtest.me:1234?sslmode=verify-full'
```
Next step is setting up auth table and schema as well as creating role (without the JWT table):
```sh
docker exec -it proxy-postgres psql -U postgres -c "CREATE SCHEMA IF NOT EXISTS neon_control_plane"
docker exec -it proxy-postgres psql -U postgres -c "CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
docker exec -it proxy-postgres psql -U postgres -c "CREATE ROLE proxy WITH SUPERUSER LOGIN PASSWORD 'password';"
```
Now from client you can start a new session:
```sh
PGSSLROOTCERT=./server.crt psql "postgresql://proxy:password@endpoint.localtest.me:4432/postgres?sslmode=verify-full"
```

View File

@@ -10,6 +10,7 @@ use tracing::info;
use super::backend::ComputeCredentialKeys;
use super::{AuthError, PasswordHackPayload};
use crate::config::TlsServerEndPoint;
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
@@ -17,7 +18,6 @@ use crate::sasl;
use crate::scram::threadpool::ThreadPool;
use crate::scram::{self};
use crate::stream::{PqStream, Stream};
use crate::tls::TlsServerEndPoint;
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {

View File

@@ -13,10 +13,7 @@ use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
use proxy::auth::{self};
use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
};
use proxy::conn::TokioTcpAcceptor;
use proxy::config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig};
use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use proxy::http::health_server::AppMetrics;
@@ -28,7 +25,6 @@ use proxy::rate_limiter::{
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::{self, GlobalConnPoolOptions};
use proxy::tls::client_config::compute_client_config_with_root_certs;
use proxy::types::RoleName;
use proxy::url::ApiUrl;
@@ -37,6 +33,7 @@ project_build_tag!(BUILD_TAG);
use clap::Parser;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -166,8 +163,8 @@ async fn main() -> anyhow::Result<()> {
}
};
let metrics_listener = TokioTcpAcceptor::bind(args.metrics).await?;
let http_listener = TokioTcpAcceptor::bind(args.http).await?;
let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
let http_listener = TcpListener::bind(args.http).await?;
let shutdown = CancellationToken::new();
// todo: should scale with CU
@@ -212,7 +209,6 @@ async fn main() -> anyhow::Result<()> {
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
&config.connect_to_compute,
Arc::new(DashMap::new()),
None,
proxy::metrics::CancellationSource::Local,
@@ -272,12 +268,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let compute_config = ComputeConfig {
retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
tls: Arc::new(compute_client_config_with_root_certs()?),
timeout: Duration::from_secs(2),
};
Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
metric_collection: None,
@@ -299,7 +289,9 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,
connect_to_compute_retry_config: RetryConfig::parse(
RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES,
)?,
})))
}

View File

@@ -10,13 +10,12 @@ use clap::Arg;
use futures::future::Either;
use futures::TryFutureExt;
use itertools::Itertools;
use proxy::conn::{Acceptor, TokioTcpAcceptor};
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestContext;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use proxy::stream::{PqStream, Stream};
use proxy::tls::TlsServerEndPoint;
use rustls::crypto::ring;
use rustls::pki_types::PrivateKeyDer;
use tokio::io::{AsyncRead, AsyncWrite};
@@ -123,7 +122,7 @@ async fn main() -> anyhow::Result<()> {
// Start listening for incoming client connections
let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
info!("Starting sni router on {proxy_address}");
let proxy_listener = TokioTcpAcceptor::bind(proxy_address).await?;
let proxy_listener = TcpListener::bind(proxy_address).await?;
let cancellation_token = CancellationToken::new();
@@ -153,13 +152,17 @@ async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
acceptor: TokioTcpAcceptor,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(acceptor.accept(), &cancellation_token).await
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
@@ -169,6 +172,10 @@ async fn task_main(
connections.spawn(
async move {
socket
.set_nodelay(true)
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
let ctx = RequestContext::new(
session_id,
@@ -190,7 +197,7 @@ async fn task_main(
}
connections.close();
drop(acceptor);
drop(listener);
connections.wait().await;

View File

@@ -1,7 +1,6 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use anyhow::bail;
use futures::future::Either;
@@ -9,10 +8,9 @@ use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use proxy::cancellation::{CancelMap, CancellationHandler};
use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
};
use proxy::conn::TokioTcpAcceptor;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
@@ -25,9 +23,9 @@ use proxy::redis::{elasticache, notifications};
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::tls::client_config::compute_client_config_with_root_certs;
use proxy::{auth, control_plane, http, serverless, usage_metrics};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -353,17 +351,17 @@ async fn main() -> anyhow::Result<()> {
// Check that we can bind to address before further initialization
let http_address: SocketAddr = args.http.parse()?;
info!("Starting http on {http_address}");
let http_listener = TokioTcpAcceptor::bind(http_address).await?;
let http_listener = TcpListener::bind(http_address).await?.into_std()?;
let mgmt_address: SocketAddr = args.mgmt.parse()?;
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TokioTcpAcceptor::bind(mgmt_address).await?;
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
let proxy_listener = if !args.is_auth_broker {
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
Some(TokioTcpAcceptor::bind(proxy_address).await?)
Some(TcpListener::bind(proxy_address).await?)
} else {
None
};
@@ -373,7 +371,7 @@ async fn main() -> anyhow::Result<()> {
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TokioTcpAcceptor::bind(serverless_address).await?)
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {
@@ -399,7 +397,6 @@ async fn main() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<Mutex<RedisPublisherClient>>>,
>::new(
&config.connect_to_compute,
cancel_map.clone(),
redis_publisher,
proxy::metrics::CancellationSource::FromClient,
@@ -495,7 +492,6 @@ async fn main() -> anyhow::Result<()> {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
config,
client,
cache.clone(),
cancel_map.clone(),
@@ -504,7 +500,6 @@ async fn main() -> anyhow::Result<()> {
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
config,
client,
cache.clone(),
cancel_map.clone(),
@@ -637,12 +632,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
console_redirect_confirmation_timeout: args.webauth_confirmation_timeout,
};
let compute_config = ComputeConfig {
retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?,
tls: Arc::new(compute_client_config_with_root_certs()?),
timeout: Duration::from_secs(2),
};
let config = ProxyConfig {
tls_config,
metric_collection,
@@ -653,7 +642,9 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
connect_to_compute_retry_config: config::RetryConfig::parse(
&args.connect_to_compute_retry,
)?,
};
let config = Box::leak(Box::new(config));

View File

@@ -3,9 +3,11 @@ use std::sync::Arc;
use dashmap::DashMap;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use once_cell::sync::OnceCell;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::CancelToken;
use pq_proto::CancelKeyData;
use rustls::crypto::ring;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
@@ -13,15 +15,15 @@ use tracing::{debug, info};
use uuid::Uuid;
use crate::auth::{check_peer_addr_is_in_list, IpPattern};
use crate::config::ComputeConfig;
use crate::compute::load_certs;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
use crate::postgres_rustls::MakeRustlsConnect;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
};
use crate::tls::postgres_rustls::MakeRustlsConnect;
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
@@ -33,7 +35,6 @@ type IpSubnetKey = IpNet;
///
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler<P> {
compute_config: &'static ComputeConfig,
map: CancelMap,
client: P,
/// This field used for the monitoring purposes.
@@ -182,7 +183,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
"cancelling query per user's request using key {key}, hostname {}, address: {}",
cancel_closure.hostname, cancel_closure.socket_addr
);
cancel_closure.try_cancel_query(self.compute_config).await
cancel_closure.try_cancel_query().await
}
#[cfg(test)]
@@ -197,13 +198,8 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}
impl CancellationHandler<()> {
pub fn new(
compute_config: &'static ComputeConfig,
map: CancelMap,
from: CancellationSource,
) -> Self {
pub fn new(map: CancelMap, from: CancellationSource) -> Self {
Self {
compute_config,
map,
client: (),
from,
@@ -218,14 +214,8 @@ impl CancellationHandler<()> {
}
impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub fn new(
compute_config: &'static ComputeConfig,
map: CancelMap,
client: Option<Arc<Mutex<P>>>,
from: CancellationSource,
) -> Self {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
Self {
compute_config,
map,
client,
from,
@@ -239,6 +229,8 @@ impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
}
}
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
@@ -265,14 +257,27 @@ impl CancelClosure {
}
}
/// Cancels the query running on user's compute node.
pub(crate) async fn try_cancel_query(
self,
compute_config: &ComputeConfig,
) -> Result<(), CancelError> {
pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let mut mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let root_store = TLS_ROOTS
.get_or_try_init(load_certs)
.map_err(|_e| {
CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
"TLS root store initialization failed".to_string(),
))
})?
.clone();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
&self.hostname,
@@ -324,30 +329,11 @@ impl<P> Drop for Session<P> {
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
use std::time::Duration;
use super::*;
use crate::config::RetryConfig;
use crate::tls::client_config::compute_client_config_with_certs;
fn config() -> ComputeConfig {
let retry = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
ComputeConfig {
retry,
tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
timeout: Duration::from_secs(2),
}
}
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
Box::leak(Box::new(config())),
CancelMap::default(),
CancellationSource::FromRedis,
));
@@ -363,11 +349,8 @@ mod tests {
#[tokio::test]
async fn cancel_session_noop_regression() {
let handler = CancellationHandler::<()>::new(
Box::leak(Box::new(config())),
CancelMap::default(),
CancellationSource::Local,
);
let handler =
CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
handler
.cancel_session(
CancelKeyData {

View File

@@ -1,13 +1,16 @@
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use pq_proto::StartupMessageParams;
use rustls::crypto::ring;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::TcpStream;
@@ -15,15 +18,14 @@ use tracing::{debug, error, info, warn};
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::postgres_rustls::MakeRustlsConnect;
use crate::proxy::neon_option;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::types::Host;
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -38,6 +40,9 @@ pub(crate) enum ConnectionError {
#[error("{COULD_NOT_CONNECT}: {0}")]
CouldNotConnect(#[from] io::Error),
#[error("Couldn't load native TLS certificates: {0:?}")]
TlsCertificateError(Vec<rustls_native_certs::Error>),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] InvalidDnsNameError),
@@ -84,6 +89,7 @@ impl ReportableError for ConnectionError {
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
@@ -193,15 +199,11 @@ impl ConnCfg {
let connect_once = |host, port| {
debug!("trying to connect to compute node at {host}:{port}");
connect_with_timeout(host, port).and_then(|stream| async {
let socket_addr = stream.peer_addr()?;
let socket = socket2::SockRef::from(&stream);
// Disable Nagle's algorithm to not introduce latency between
// client and compute.
socket.set_nodelay(true)?;
connect_with_timeout(host, port).and_then(|socket| async {
let socket_addr = socket.peer_addr()?;
// This prevents load balancer from severing the connection.
socket.set_keepalive(true)?;
Ok((socket_addr, stream))
socket2::SockRef::from(&socket).set_keepalive(true)?;
Ok((socket_addr, socket))
})
};
@@ -249,13 +251,25 @@ impl ConnCfg {
&self,
ctx: &RequestContext,
aux: MetricsAuxInfo,
config: &ComputeConfig,
timeout: Duration,
) -> Result<PostgresConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
drop(pause);
let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let root_store = TLS_ROOTS
.get_or_try_init(load_certs)
.map_err(ConnectionError::TlsCertificateError)?
.clone();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
host,
@@ -327,6 +341,19 @@ fn filtered_options(options: &str) -> Option<String> {
Some(options)
}
pub(crate) fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
let der_certs = rustls_native_certs::load_native_certs();
if !der_certs.errors.is_empty() {
return Err(der_certs.errors);
}
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(der_certs.certs);
Ok(Arc::new(store))
}
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,10 +1,17 @@
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{bail, ensure, Context, Ok};
use clap::ValueEnum;
use itertools::Itertools;
use remote_storage::RemoteStorageConfig;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use sha2::{Digest, Sha256};
use tracing::{error, info};
use x509_parser::oid_registry;
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::AuthRateLimiter;
@@ -13,7 +20,6 @@ use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::GlobalConnPoolOptions;
pub use crate::tls::server_config::{configure_tls, TlsConfig};
use crate::types::Host;
pub struct ProxyConfig {
@@ -26,13 +32,7 @@ pub struct ProxyConfig {
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
pub connect_to_compute: ComputeConfig,
}
pub struct ComputeConfig {
pub retry: RetryConfig,
pub tls: Arc<rustls::ClientConfig>,
pub timeout: Duration,
pub connect_to_compute_retry_config: RetryConfig,
}
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]
@@ -52,6 +52,12 @@ pub struct MetricCollectionConfig {
pub backup_metric_collection_config: MetricBackupCollectionConfig,
}
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: HashSet<String>,
pub cert_resolver: Arc<CertResolver>,
}
pub struct HttpConfig {
pub accept_websockets: bool,
pub pool_options: GlobalConnPoolOptions,
@@ -74,6 +80,272 @@ pub struct AuthenticationConfig {
pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
impl TlsConfig {
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
self.config.clone()
}
}
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Configure TLS for the main endpoint.
pub fn configure_tls(
key_path: &str,
cert_path: &str,
certs_dir: Option<&String>,
allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
let mut cert_resolver = CertResolver::new();
// add default certificate
cert_resolver.add_cert_path(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
for entry in std::fs::read_dir(certs_dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
// file names aligned with default cert-manager names
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
)?;
}
}
}
}
let common_names = cert_resolver.get_common_names();
let cert_resolver = Arc::new(cert_resolver);
// allow TLS 1.2 to be compatible with older client libraries
let mut config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
if allow_tls_keylogfile {
// KeyLogFile will check for the SSLKEYLOGFILE environment variable.
config.key_log = Arc::new(rustls::KeyLogFile::new());
}
Ok(TlsConfig {
config: Arc::new(config),
common_names,
cert_resolver,
})
}
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
pub fn new() -> Self {
Self::default()
}
fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
is_default: bool,
) -> anyhow::Result<()> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};
self.add_cert(priv_key, cert_chain, is_default)
}
pub fn add_cert(
&mut self,
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(first_cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
let common_name = pem.subject().to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some((cert.clone(), tls_server_end_point));
}
self.certs.insert(common_name, (cert, tls_server_end_point));
Ok(())
}
pub fn get_common_names(&self) -> HashSet<String> {
self.certs.keys().map(|s| s.to_string()).collect()
}
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
//
// With the current coding foo.com will match *.foo.com and that
// repeats behavior of the old code.
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return Some(cert.clone());
}
if let Some((_, rest)) = sni_name.split_once('.') {
sni_name = rest;
} else {
return None;
}
}
} else {
// No SNI, use the default certificate, otherwise we can't get to
// options parameter which can be used to set endpoint name too.
// That means that non-SNI flow will not work for CNAME domains in
// verify-full mode.
//
// If that will be a problem we can:
//
// a) Instead of multi-cert approach use single cert with extra
// domains listed in Subject Alternative Name (SAN).
// b) Deploy separate proxy instances for extra domains.
self.default.clone()
}
}
}
#[derive(Debug)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.

View File

@@ -1,221 +0,0 @@
use std::future::{poll_fn, Future};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
pub trait Acceptor {
type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static;
type Error: std::error::Error + Send + Sync + 'static;
#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let _ = cx;
Poll::Ready(Ok(()))
}
fn accept(
&self,
) -> impl Future<Output = Result<(Self::Connection, SocketAddr), Self::Error>> + Send;
}
pub trait Connector {
type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static;
type Error: std::error::Error + Send + Sync + 'static;
#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let _ = cx;
Poll::Ready(Ok(()))
}
fn connect(
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<Self::Connection, Self::Error>> + Send;
}
pub struct TokioTcpAcceptor {
listener: TcpListener,
tcp_nodelay: Option<bool>,
tcp_keepalive: Option<bool>,
}
impl TokioTcpAcceptor {
pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
let listener = TcpListener::bind(addr).await?;
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
Ok(Self {
listener,
tcp_nodelay: Some(true),
tcp_keepalive: None,
})
}
pub fn into_std(self) -> io::Result<std::net::TcpListener> {
self.listener.into_std()
}
}
impl Acceptor for TokioTcpAcceptor {
type Connection = TcpStream;
type Error = io::Error;
fn accept(&self) -> impl Future<Output = Result<(Self::Connection, SocketAddr), Self::Error>> {
async move {
let (stream, addr) = self.listener.accept().await?;
let socket = socket2::SockRef::from(&stream);
if let Some(nodelay) = self.tcp_nodelay {
socket.set_nodelay(nodelay)?;
}
if let Some(keepalive) = self.tcp_keepalive {
socket.set_keepalive(keepalive)?;
}
Ok((stream, addr))
}
}
}
pub struct TokioTcpConnector;
impl Connector for TokioTcpConnector {
type Connection = TcpStream;
type Error = io::Error;
fn connect(
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<Self::Connection, Self::Error>> {
async move {
let socket = TcpStream::connect(addr).await?;
socket.set_nodelay(true)?;
Ok(socket)
}
}
}
pub trait Stream: AsyncRead + AsyncWrite + Send + Unpin + 'static {}
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> Stream for T {}
pub trait AsyncRead {
fn readable(&self) -> impl Future<Output = io::Result<()>> + Send
where
Self: Send + Sync,
{
poll_fn(move |cx| self.poll_read_ready(cx))
}
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>>;
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [io::IoSliceMut<'_>],
) -> Poll<io::Result<usize>>;
}
pub trait AsyncWrite {
fn writable(&self) -> impl Future<Output = io::Result<()>> + Send
where
Self: Send + Sync,
{
poll_fn(move |cx| self.poll_write_ready(cx))
}
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>>;
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>;
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}
impl AsyncRead for tokio::net::TcpStream {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::net::TcpStream::poll_read_ready(self, cx)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match tokio::net::TcpStream::try_read(Pin::new(&mut *self).get_mut(), buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_read_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [io::IoSliceMut<'_>],
) -> Poll<io::Result<usize>> {
match tokio::net::TcpStream::try_read_vectored(Pin::new(&mut *self).get_mut(), bufs) {
Ok(n) => Poll::Ready(Ok(n)),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
}
impl AsyncWrite for tokio::net::TcpStream {
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::net::TcpStream::poll_write_ready(self, cx)
}
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
<Self as tokio::io::AsyncWrite>::poll_write(self, cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
<Self as tokio::io::AsyncWrite>::poll_write_vectored(self, cx, bufs)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
<Self as tokio::io::AsyncWrite>::poll_flush(self, cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
<Self as tokio::io::AsyncWrite>::poll_shutdown(self, cx)
}
}

View File

@@ -8,7 +8,6 @@ use tracing::{debug, error, info, Instrument};
use crate::auth::backend::ConsoleRedirectBackend;
use crate::cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
@@ -23,7 +22,7 @@ use crate::proxy::{
pub async fn task_main(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
acceptor: TokioTcpAcceptor,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
@@ -31,11 +30,15 @@ pub async fn task_main(
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(acceptor.accept(), &cancellation_token).await
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
@@ -112,7 +115,7 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
@@ -128,7 +131,7 @@ pub async fn task_main(
connections.close();
cancellations.close();
drop(acceptor);
drop(listener);
// Drain connections
connections.wait().await;
@@ -213,7 +216,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
},
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
config.connect_to_compute_retry_config,
)
.or_else(|e| stream.throw_error(e))
.await?;

View File

@@ -4,11 +4,10 @@ use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use tokio::net::TcpStream;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::control_plane::messages::{DatabaseInfo, KickSession};
use crate::waiters::{self, Waiter, Waiters};
@@ -27,15 +26,19 @@ pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), wai
/// Management API listener task.
/// It spawns management response handlers needed for the console redirect auth flow.
pub async fn task_main(acceptor: TokioTcpAcceptor) -> anyhow::Result<Infallible> {
pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
scopeguard::defer! {
info!("mgmt has shut down");
}
loop {
let (socket, peer_addr) = acceptor.accept().await?;
let (socket, peer_addr) = listener.accept().await?;
info!("accepted connection from {peer_addr}");
socket
.set_nodelay(true)
.context("failed to set client socket option")?;
let span = info_span!("mgmt", peer = %peer_addr);
tokio::task::spawn(

View File

@@ -10,13 +10,13 @@ pub mod client;
pub(crate) mod errors;
use std::sync::Arc;
use std::time::Duration;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::IpPattern;
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::intern::ProjectIdInt;
@@ -73,9 +73,9 @@ impl NodeInfo {
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
config: &ComputeConfig,
timeout: Duration,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config.connect(ctx, self.aux.clone(), config).await
self.config.connect(ctx, self.aux.clone(), timeout).await
}
pub(crate) fn reuse_settings(&mut self, other: Self) {

View File

@@ -1,4 +1,5 @@
use std::convert::Infallible;
use std::net::TcpListener;
use std::sync::{Arc, Mutex};
use anyhow::{anyhow, bail};
@@ -13,7 +14,6 @@ use utils::http::error::ApiError;
use utils::http::json::json_response;
use utils::http::{RouterBuilder, RouterService};
use crate::conn::TokioTcpAcceptor;
use crate::ext::{LockExt, TaskExt};
use crate::jemalloc;
@@ -36,7 +36,7 @@ fn make_router(metrics: AppMetrics) -> RouterBuilder<hyper0::Body, ApiError> {
}
pub async fn task_main(
http_acceptor: TokioTcpAcceptor,
http_listener: TcpListener,
metrics: AppMetrics,
) -> anyhow::Result<Infallible> {
scopeguard::defer! {
@@ -45,7 +45,7 @@ pub async fn task_main(
let service = || RouterService::new(make_router(metrics).build()?);
hyper0::Server::from_tcp(http_acceptor.into_std()?)?
hyper0::Server::from_tcp(http_listener)?
.serve(service().map_err(|e| anyhow!(e))?)
.await?;

View File

@@ -78,7 +78,6 @@ pub mod cancellation;
pub mod compute;
pub mod compute_ctl;
pub mod config;
pub mod conn;
pub mod console_redirect_proxy;
pub mod context;
pub mod control_plane;
@@ -90,6 +89,7 @@ pub mod jemalloc;
pub mod logging;
pub mod metrics;
pub mod parse;
pub mod postgres_rustls;
pub mod protocol2;
pub mod proxy;
pub mod rate_limiter;
@@ -99,7 +99,6 @@ pub mod scram;
pub mod serverless;
pub mod signals;
pub mod stream;
pub mod tls;
pub mod types;
pub mod url;
pub mod usage_metrics;

View File

@@ -18,7 +18,7 @@ mod private {
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;
use crate::tls::TlsServerEndPoint;
use crate::config::TlsServerEndPoint;
pub struct TlsConnectFuture<S> {
inner: tokio_rustls::Connect<S>,
@@ -126,14 +126,16 @@ mod private {
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
#[derive(Clone)]
pub struct MakeRustlsConnect {
pub config: Arc<ClientConfig>,
config: Arc<ClientConfig>,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
#[must_use]
pub fn new(config: Arc<ClientConfig>) -> Self {
Self { config }
pub fn new(config: ClientConfig) -> Self {
Self {
config: Arc::new(config),
}
}
}

View File

@@ -6,7 +6,7 @@ use tracing::{debug, info, warn};
use super::retry::ShouldRetryWakeCompute;
use crate::auth::backend::ComputeCredentialKeys;
use crate::compute::{self, PostgresConnection, COULD_NOT_CONNECT};
use crate::config::{ComputeConfig, RetryConfig};
use crate::config::RetryConfig;
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
@@ -19,6 +19,8 @@ use crate::proxy::retry::{retry_after, should_retry, CouldRetry};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
@@ -47,7 +49,7 @@ pub(crate) trait ConnectMechanism {
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
timeout: time::Duration,
) -> Result<Self::Connection, Self::ConnectError>;
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
@@ -84,11 +86,11 @@ impl ConnectMechanism for TcpMechanism<'_> {
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
permit.release_result(node_info.connect(ctx, config).await)
permit.release_result(node_info.connect(ctx, timeout).await)
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
@@ -103,7 +105,7 @@ pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBac
mechanism: &M,
user_info: &B,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
connect_to_compute_retry_config: RetryConfig,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
@@ -117,7 +119,10 @@ where
mechanism.update_connect_config(&mut node_info.config);
// try once
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
let err = match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
.await
{
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(
@@ -137,7 +142,7 @@ where
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
// Do not need to retrieve a new node_info, just return the old one.
if should_retry(&err, num_retries, compute.retry) {
if should_retry(&err, num_retries, connect_to_compute_retry_config) {
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Failed,
@@ -167,7 +172,10 @@ where
debug!("wake_compute success. attempting to connect");
num_retries = 1;
loop {
match mechanism.connect_once(ctx, &node_info, compute).await {
match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
.await
{
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(
@@ -182,7 +190,7 @@ where
return Ok(res);
}
Err(e) => {
if !should_retry(&e, num_retries, compute.retry) {
if !should_retry(&e, num_retries, connect_to_compute_retry_config) {
// Don't log an error here, caller will print the error
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
@@ -198,7 +206,7 @@ where
}
};
let wait_duration = retry_after(num_retries, compute.retry);
let wait_duration = retry_after(num_retries, connect_to_compute_retry_config);
num_retries += 1;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);

View File

@@ -8,13 +8,12 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use crate::auth::endpoint_sni;
use crate::config::TlsConfig;
use crate::config::{TlsConfig, PG_ALPN_PROTOCOL};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::ERR_INSECURE_CONNECTION;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;
#[derive(Error, Debug)]
pub(crate) enum HandshakeError {

View File

@@ -25,7 +25,6 @@ use self::connect_compute::{connect_to_compute, TcpMechanism};
use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
@@ -56,7 +55,7 @@ pub async fn run_until_cancelled<F: std::future::Future>(
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
acceptor: TokioTcpAcceptor,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -65,11 +64,15 @@ pub async fn task_main(
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(acceptor.accept(), &cancellation_token).await
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
@@ -149,7 +152,7 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
@@ -165,7 +168,7 @@ pub async fn task_main(
connections.close();
cancellations.close();
drop(acceptor);
drop(listener);
// Drain connections
connections.wait().await;
@@ -348,7 +351,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
},
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
config.connect_to_compute_retry_config,
)
.or_else(|e| stream.throw_error(e))
.await?;

View File

@@ -5,7 +5,6 @@ use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::stream::Stream;
@@ -68,17 +67,9 @@ pub(crate) struct ProxyPassthrough<P, S> {
}
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub(crate) async fn proxy_pass(
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self
.compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
}
res

View File

@@ -22,16 +22,14 @@ use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{ComputeConfig, RetryConfig};
use crate::config::{CertResolver, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{
self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache,
};
use crate::error::ErrorKind;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::tls::server_config::CertResolver;
use crate::postgres_rustls::MakeRustlsConnect;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};
@@ -69,7 +67,7 @@ fn generate_certs(
}
struct ClientConfig<'a> {
config: Arc<rustls::ClientConfig>,
config: rustls::ClientConfig,
hostname: &'a str,
}
@@ -112,7 +110,16 @@ fn generate_tls_config<'a>(
};
let client_config = {
let config = Arc::new(compute_client_config_with_certs([ca]));
let config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.context("ring should support the default protocol versions")?
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(ca)?;
store
})
.with_no_client_auth();
ClientConfig { config, hostname }
};
@@ -461,7 +468,7 @@ impl ConnectMechanism for TestConnectMechanism {
&self,
_ctx: &RequestContext,
_node_info: &control_plane::CachedNodeInfo,
_config: &ComputeConfig,
_timeout: std::time::Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let mut counter = self.counter.lock().unwrap();
let action = self.sequence[*counter];
@@ -569,20 +576,6 @@ fn helper_create_connect_info(
user_info
}
fn config() -> ComputeConfig {
let retry = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
ComputeConfig {
retry,
tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
timeout: Duration::from_secs(2),
}
}
#[tokio::test]
async fn connect_to_compute_success() {
let _ = env_logger::try_init();
@@ -590,8 +583,12 @@ async fn connect_to_compute_success() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -604,8 +601,12 @@ async fn connect_to_compute_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -619,8 +620,12 @@ async fn connect_to_compute_non_retry_1() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap_err();
mechanism.verify();
@@ -634,8 +639,12 @@ async fn connect_to_compute_non_retry_2() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -656,13 +665,17 @@ async fn connect_to_compute_non_retry_3() {
max_retries: 1,
backoff_factor: 2.0,
};
let config = config();
let connect_to_compute_retry_config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(
&ctx,
&mechanism,
&user_info,
wake_compute_retry_config,
&config,
connect_to_compute_retry_config,
)
.await
.unwrap_err();
@@ -677,8 +690,12 @@ async fn wake_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap();
mechanism.verify();
@@ -692,8 +709,12 @@ async fn wake_non_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
.await
.unwrap_err();
mechanism.verify();

View File

@@ -12,7 +12,6 @@ use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::cancellation::{CancelMap, CancellationHandler};
use crate::config::ProxyConfig;
use crate::intern::{ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
@@ -40,27 +39,6 @@ pub(crate) enum Notification {
AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate,
},
#[serde(
rename = "/block_public_or_vpc_access_updated",
deserialize_with = "deserialize_json_string"
)]
BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_org",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_projects",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
},
#[serde(
rename = "/password_updated",
deserialize_with = "deserialize_json_string"
@@ -73,24 +51,6 @@ pub(crate) enum Notification {
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdated {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
// TODO: change type once the implementation is more fully fledged.
// See e.g. https://github.com/neondatabase/neon/pull/10073.
account_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
project_ids: Vec<ProjectIdInt>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
project_id: ProjectIdInt,
@@ -204,11 +164,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
}
}
Notification::AllowedIpsUpdate { .. }
| Notification::PasswordUpdate { .. }
| Notification::BlockPublicOrVpcAccessUpdated { .. }
| Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
| Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
invalidate_cache(self.cache.clone(), msg.clone());
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
@@ -221,8 +177,6 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
.redis_events_count
.inc(RedisEventsCount::PasswordUpdate);
}
// TODO: add additional metrics for the other event types.
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
@@ -249,15 +203,6 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
password_update.role_name,
),
Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
Notification::BlockPublicOrVpcAccessUpdated { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
}
}
@@ -304,7 +249,6 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
/// Handle console's invalidation messages.
#[tracing::instrument(name = "redis_notifications", skip_all)]
pub async fn task_main<C>(
config: &'static ProxyConfig,
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
cancel_map: CancelMap,
@@ -314,7 +258,6 @@ where
C: ProjectInfoCache + Send + Sync + 'static,
{
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
&config.connect_to_compute,
cancel_map,
crate::metrics::CancellationSource::FromRedis,
));

View File

@@ -13,6 +13,7 @@ use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use super::ScramKey;
use crate::config;
use crate::intern::EndpointIdInt;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
@@ -58,14 +59,14 @@ enum ExchangeState {
pub(crate) struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
tls_server_end_point: crate::tls::TlsServerEndPoint,
tls_server_end_point: config::TlsServerEndPoint,
}
impl<'a> Exchange<'a> {
pub(crate) fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
tls_server_end_point: crate::tls::TlsServerEndPoint,
tls_server_end_point: config::TlsServerEndPoint,
) -> Self {
Self {
state: ExchangeState::Initial(SaslInitial { nonce }),
@@ -119,7 +120,7 @@ impl SaslInitial {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &crate::tls::TlsServerEndPoint,
tls_server_end_point: &config::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
let client_first_message = ClientFirstMessage::parse(input)
@@ -154,7 +155,7 @@ impl SaslSentInner {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &crate::tls::TlsServerEndPoint,
tls_server_end_point: &config::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
let Self {
@@ -167,8 +168,8 @@ impl SaslSentInner {
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
crate::tls::TlsServerEndPoint::Sha256(x) => Ok(x),
crate::tls::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
config::TlsServerEndPoint::Sha256(x) => Ok(x),
config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
})?;
// This might've been caused by a MITM attack

View File

@@ -77,8 +77,11 @@ mod tests {
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
let mut exchange =
Exchange::new(&secret, || NONCE, crate::tls::TlsServerEndPoint::Undefined);
let mut exchange = Exchange::new(
&secret,
|| NONCE,
crate::config::TlsServerEndPoint::Undefined,
);
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";

View File

@@ -22,7 +22,7 @@ use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
use crate::config::{ComputeConfig, ProxyConfig};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
@@ -196,7 +196,7 @@ impl PoolingBackend {
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
self.config.connect_to_compute_retry_config,
)
.await
}
@@ -237,7 +237,7 @@ impl PoolingBackend {
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
self.config.connect_to_compute_retry_config,
)
.await
}
@@ -502,7 +502,7 @@ impl ConnectMechanism for TokioMechanism {
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
@@ -511,7 +511,7 @@ impl ConnectMechanism for TokioMechanism {
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
.connect_timeout(timeout);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(postgres_client::NoTls).await;
@@ -552,7 +552,7 @@ impl ConnectMechanism for HyperMechanism {
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
@@ -560,7 +560,7 @@ impl ConnectMechanism for HyperMechanism {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let port = node_info.config.get_port();
let res = connect_http2(&host, port, config.timeout).await;
let res = connect_http2(&host, port, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;

View File

@@ -35,7 +35,7 @@ use rand::rngs::StdRng;
use rand::SeedableRng;
use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
@@ -45,7 +45,6 @@ use utils::http::error::ApiError;
use crate::cancellation::CancellationHandlerMain;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
@@ -60,7 +59,7 @@ pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
ws_acceptor: TokioTcpAcceptor,
ws_listener: TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -135,7 +134,7 @@ pub async fn task_main(
connections.close(); // allows `connections.wait to complete`
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(res) = run_until_cancelled(ws_acceptor.accept(), &cancellation_token).await {
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
tracing::error!("could not set nodelay: {e}");

View File

@@ -168,7 +168,7 @@ pub(crate) async fn serve_websocket(
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
match p.proxy_pass().await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),

View File

@@ -11,9 +11,9 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tracing::debug;
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::tls::TlsServerEndPoint;
/// Stream wrapper which implements libpq's protocol.
///

View File

@@ -1,42 +0,0 @@
use std::sync::Arc;
use anyhow::bail;
use rustls::crypto::ring;
pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
let der_certs = rustls_native_certs::load_native_certs();
if !der_certs.errors.is_empty() {
bail!("could not parse certificates: {:?}", der_certs.errors);
}
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(der_certs.certs);
Ok(Arc::new(store))
}
/// Loads the root certificates and constructs a client config suitable for connecting to the neon compute.
/// This function is blocking.
pub fn compute_client_config_with_root_certs() -> anyhow::Result<rustls::ClientConfig> {
Ok(
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(load_certs()?)
.with_no_client_auth(),
)
}
#[cfg(test)]
pub fn compute_client_config_with_certs(
certs: impl IntoIterator<Item = rustls::pki_types::CertificateDer<'static>>,
) -> rustls::ClientConfig {
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(certs);
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(store)
.with_no_client_auth()
}

View File

@@ -1,72 +0,0 @@
pub mod client_config;
pub mod postgres_rustls;
pub mod server_config;
use anyhow::Context;
use rustls::pki_types::CertificateDer;
use sha2::{Digest, Sha256};
use tracing::{error, info};
use x509_parser::oid_registry;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}

View File

@@ -1,218 +0,0 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use anyhow::{bail, Context};
use itertools::Itertools;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use super::{TlsServerEndPoint, PG_ALPN_PROTOCOL};
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: HashSet<String>,
pub cert_resolver: Arc<CertResolver>,
}
impl TlsConfig {
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
self.config.clone()
}
}
/// Configure TLS for the main endpoint.
pub fn configure_tls(
key_path: &str,
cert_path: &str,
certs_dir: Option<&String>,
allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
let mut cert_resolver = CertResolver::new();
// add default certificate
cert_resolver.add_cert_path(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
for entry in std::fs::read_dir(certs_dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
// file names aligned with default cert-manager names
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
)?;
}
}
}
}
let common_names = cert_resolver.get_common_names();
let cert_resolver = Arc::new(cert_resolver);
// allow TLS 1.2 to be compatible with older client libraries
let mut config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
if allow_tls_keylogfile {
// KeyLogFile will check for the SSLKEYLOGFILE environment variable.
config.key_log = Arc::new(rustls::KeyLogFile::new());
}
Ok(TlsConfig {
config: Arc::new(config),
common_names,
cert_resolver,
})
}
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
pub fn new() -> Self {
Self::default()
}
fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
is_default: bool,
) -> anyhow::Result<()> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};
self.add_cert(priv_key, cert_chain, is_default)
}
pub fn add_cert(
&mut self,
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(first_cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
let common_name = pem.subject().to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some((cert.clone(), tls_server_end_point));
}
self.certs.insert(common_name, (cert, tls_server_end_point));
Ok(())
}
pub fn get_common_names(&self) -> HashSet<String> {
self.certs.keys().map(|s| s.to_string()).collect()
}
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
//
// With the current coding foo.com will match *.foo.com and that
// repeats behavior of the old code.
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return Some(cert.clone());
}
if let Some((_, rest)) = sni_name.split_once('.') {
sni_name = rest;
} else {
return None;
}
}
} else {
// No SNI, use the default certificate, otherwise we can't get to
// options parameter which can be used to set endpoint name too.
// That means that non-SNI flow will not work for CNAME domains in
// verify-full mode.
//
// If that will be a problem we can:
//
// a) Instead of multi-cert approach use single cert with extra
// domains listed in Subject Alternative Name (SAN).
// b) Deploy separate proxy instances for extra domains.
self.default.clone()
}
}
}

View File

@@ -43,13 +43,13 @@ scopeguard.workspace = true
strum.workspace = true
strum_macros.workspace = true
diesel = { version = "2.2.6", features = [
diesel = { version = "2.1.4", features = [
"serde_json",
"postgres",
"r2d2",
"chrono",
] }
diesel_migrations = { version = "2.2.0" }
diesel_migrations = { version = "2.1.0" }
r2d2 = { version = "0.8.10" }
utils = { path = "../libs/utils/" }

View File

@@ -3572,11 +3572,6 @@ impl Service {
.iter()
.any(|i| i.generation.is_none() || i.generation_pageserver.is_none())
{
let shard_generations = generations
.into_iter()
.map(|i| (i.tenant_shard_id, (i.generation, i.generation_pageserver)))
.collect::<HashMap<_, _>>();
// One or more shards has not been attached to a pageserver. Check if this is because it's configured
// to be detached (409: caller should give up), or because it's meant to be attached but isn't yet (503: caller should retry)
let locked = self.inner.read().unwrap();
@@ -3587,28 +3582,6 @@ impl Service {
PlacementPolicy::Attached(_) => {
// This shard is meant to be attached: the caller is not wrong to try and
// use this function, but we can't service the request right now.
let Some(generation) = shard_generations.get(shard_id) else {
// This can only happen if there is a split brain controller modifying the database. This should
// never happen when testing, and if it happens in production we can only log the issue.
debug_assert!(false);
tracing::error!("Shard {shard_id} not found in generation state! Is another rogue controller running?");
continue;
};
let (generation, generation_pageserver) = generation;
if let Some(generation) = generation {
if generation_pageserver.is_none() {
// This is legitimate only in a very narrow window where the shard was only just configured into
// Attached mode after being created in Secondary or Detached mode, and it has had its generation
// set but not yet had a Reconciler run (reconciler is the only thing that sets generation_pageserver).
tracing::warn!("Shard {shard_id} generation is set ({generation:?}) but generation_pageserver is None, reconciler not run yet?");
}
} else {
// This should never happen: a shard with no generation is only permitted when it was created in some state
// other than PlacementPolicy::Attached (and generation is always written to DB before setting Attached in memory)
debug_assert!(false);
tracing::error!("Shard {shard_id} generation is None, but it is in PlacementPolicy::Attached mode!");
continue;
}
}
PlacementPolicy::Secondary | PlacementPolicy::Detached => {
return Err(ApiError::Conflict(format!(

View File

@@ -8,7 +8,6 @@ pytest_plugins = (
"fixtures.compute_reconfigure",
"fixtures.storage_controller_proxy",
"fixtures.paths",
"fixtures.compute_migrations",
"fixtures.neon_fixtures",
"fixtures.benchmark_fixture",
"fixtures.pg_stats",

View File

@@ -1,34 +0,0 @@
from __future__ import annotations
import os
from typing import TYPE_CHECKING
import pytest
from fixtures.paths import BASE_DIR
if TYPE_CHECKING:
from collections.abc import Iterator
from pathlib import Path
COMPUTE_MIGRATIONS_DIR = BASE_DIR / "compute_tools" / "src" / "migrations"
COMPUTE_MIGRATIONS_TEST_DIR = COMPUTE_MIGRATIONS_DIR / "tests"
COMPUTE_MIGRATIONS = sorted(next(os.walk(COMPUTE_MIGRATIONS_DIR))[2])
NUM_COMPUTE_MIGRATIONS = len(COMPUTE_MIGRATIONS)
@pytest.fixture(scope="session")
def compute_migrations_dir() -> Iterator[Path]:
"""
Retrieve the path to the compute migrations directory.
"""
yield COMPUTE_MIGRATIONS_DIR
@pytest.fixture(scope="session")
def compute_migrations_test_dir() -> Iterator[Path]:
"""
Retrieve the path to the compute migrations test directory.
"""
yield COMPUTE_MIGRATIONS_TEST_DIR

View File

@@ -55,17 +55,3 @@ class EndpointHttpClient(requests.Session):
res = self.get(f"http://localhost:{self.port}/metrics")
res.raise_for_status()
return res.text
def configure_failpoints(self, *args: tuple[str, str]) -> None:
body: list[dict[str, str]] = []
for fp in args:
body.append(
{
"name": fp[0],
"action": fp[1],
}
)
res = self.post(f"http://localhost:{self.port}/failpoints", json=body)
res.raise_for_status()

View File

@@ -522,15 +522,14 @@ class NeonLocalCli(AbstractNeonCli):
safekeepers: list[int] | None = None,
remote_ext_config: str | None = None,
pageserver_id: int | None = None,
allow_multiple: bool = False,
allow_multiple=False,
basebackup_request_tries: int | None = None,
env: dict[str, str] | None = None,
) -> subprocess.CompletedProcess[str]:
args = [
"endpoint",
"start",
]
extra_env_vars = env or {}
extra_env_vars = {}
if basebackup_request_tries is not None:
extra_env_vars["NEON_COMPUTE_TESTING_BASEBACKUP_TRIES"] = str(basebackup_request_tries)
if remote_ext_config is not None:

View File

@@ -54,7 +54,6 @@ from fixtures.common_types import (
TimelineArchivalState,
TimelineId,
)
from fixtures.compute_migrations import NUM_COMPUTE_MIGRATIONS
from fixtures.endpoint.http import EndpointHttpClient
from fixtures.h2server import H2Server
from fixtures.log_helper import log
@@ -3856,7 +3855,6 @@ class Endpoint(PgProtocol, LogUtils):
safekeepers: list[int] | None = None,
allow_multiple: bool = False,
basebackup_request_tries: int | None = None,
env: dict[str, str] | None = None,
) -> Self:
"""
Start the Postgres instance.
@@ -3877,7 +3875,6 @@ class Endpoint(PgProtocol, LogUtils):
pageserver_id=pageserver_id,
allow_multiple=allow_multiple,
basebackup_request_tries=basebackup_request_tries,
env=env,
)
self._running.release(1)
self.log_config_value("shared_buffers")
@@ -3991,17 +3988,14 @@ class Endpoint(PgProtocol, LogUtils):
log.info("Updating compute spec to: %s", json.dumps(data_dict, indent=4))
json.dump(data_dict, file, indent=4)
def wait_for_migrations(self, wait_for: int = NUM_COMPUTE_MIGRATIONS) -> None:
"""
Wait for all compute migrations to be ran. Remember that migrations only
run if "pg_skip_catalog_updates" is set in the compute spec to false.
"""
# Please note: Migrations only run if pg_skip_catalog_updates is false
def wait_for_migrations(self, num_migrations: int = 11):
with self.cursor() as cur:
def check_migrations_done():
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id: int = cur.fetchall()[0][0]
assert migration_id >= wait_for
assert migration_id >= num_migrations
wait_until(check_migrations_done)

View File

@@ -21,8 +21,8 @@ if TYPE_CHECKING:
BASE_DIR = Path(__file__).parents[2]
DEFAULT_OUTPUT_DIR: str = "test_output"
COMPUTE_CONFIG_DIR = BASE_DIR / "compute" / "etc"
DEFAULT_OUTPUT_DIR: str = "test_output"
def get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str | None = None) -> Path:

View File

@@ -1,90 +0,0 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, cast
import pytest
from fixtures.compute_migrations import COMPUTE_MIGRATIONS, NUM_COMPUTE_MIGRATIONS
if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnv
def test_compute_migrations_retry(neon_simple_env: NeonEnv, compute_migrations_dir: Path):
"""
Test that compute_ctl can recover from migration failures next time it
starts, and that the persisted migration ID is correct in such cases.
"""
env = neon_simple_env
endpoint = env.endpoints.create("main")
endpoint.respec(skip_pg_catalog_updates=False)
for i in range(1, NUM_COMPUTE_MIGRATIONS + 1):
endpoint.start(env={"FAILPOINTS": f"compute-migration=return({i})"})
# Make sure that the migrations ran
endpoint.wait_for_migrations(wait_for=i - 1)
# Confirm that we correctly recorded that in the
# neon_migration.migration_id table
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cast("int", cur.fetchall()[0][0])
assert migration_id == i - 1
endpoint.stop()
endpoint.start()
# Now wait for the rest of the migrations
endpoint.wait_for_migrations()
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cast("int", cur.fetchall()[0][0])
assert migration_id == NUM_COMPUTE_MIGRATIONS
for i, m in enumerate(COMPUTE_MIGRATIONS, start=1):
migration_query = (compute_migrations_dir / m).read_text(encoding="utf-8")
if not migration_query.startswith("-- SKIP"):
pattern = rf"Skipping migration id={i}"
else:
pattern = rf"Running migration id={i}"
endpoint.log_contains(pattern)
@pytest.mark.parametrize(
"migration",
(pytest.param((i, m), id=str(i)) for i, m in enumerate(COMPUTE_MIGRATIONS, start=1)),
)
def test_compute_migrations_e2e(
neon_simple_env: NeonEnv,
compute_migrations_dir: Path,
compute_migrations_test_dir: Path,
migration: tuple[int, str],
):
"""
Test that the migrations perform as advertised.
"""
env = neon_simple_env
migration_id = migration[0]
migration_filename = migration[1]
migration_query = (compute_migrations_dir / migration_filename).read_text(encoding="utf-8")
if migration_query.startswith("-- SKIP"):
pytest.skip("The migration is marked as SKIP")
endpoint = env.endpoints.create("main")
endpoint.respec(skip_pg_catalog_updates=False)
# Stop applying migrations after the one we want to test, so that we can
# test the state of the cluster at the given migration ID
endpoint.start(env={"FAILPOINTS": f"compute-migration=return({migration_id + 1})"})
endpoint.wait_for_migrations(wait_for=migration_id)
check_query = (compute_migrations_test_dir / migration_filename).read_text(encoding="utf-8")
endpoint.safe_psql(check_query)

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnv
def test_migrations(neon_simple_env: NeonEnv):
env = neon_simple_env
endpoint = env.endpoints.create("main")
endpoint.respec(skip_pg_catalog_updates=False)
endpoint.start()
num_migrations = 11
endpoint.wait_for_migrations(num_migrations=num_migrations)
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cur.fetchall()
assert migration_id[0][0] == num_migrations
endpoint.stop()
endpoint.start()
# We don't have a good way of knowing that the migrations code path finished executing
# in compute_ctl in the case that no migrations are being run
time.sleep(1)
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cur.fetchall()
assert migration_id[0][0] == num_migrations

View File

@@ -266,9 +266,7 @@ def test_scrubber_physical_gc_ancestors(neon_env_builder: NeonEnvBuilder, shard_
for shard in shards:
ps = env.get_tenant_pageserver(shard)
assert ps is not None
ps.http_client().timeline_compact(
shard, timeline_id, force_image_layer_creation=True, wait_until_uploaded=True
)
ps.http_client().timeline_compact(shard, timeline_id, force_image_layer_creation=True)
ps.http_client().timeline_gc(shard, timeline_id, 0)
# We will use a min_age_secs=1 threshold for deletion, let it pass

View File

@@ -398,7 +398,6 @@ def test_timeline_archival_chaos(neon_env_builder: NeonEnvBuilder):
# Offloading is off by default at time of writing: remove this line when it's on by default
neon_env_builder.pageserver_config_override = "timeline_offloading = true"
neon_env_builder.storage_controller_config = {"heartbeat_interval": "100msec"}
neon_env_builder.enable_pageserver_remote_storage(s3_storage())
# We will exercise migrations, so need multiple pageservers