Compare commits


1 Commit

Author: John Spray
SHA1: e09b33b49b
Message: build: link controller with system libpq
Date: 2025-01-02 12:46:51 +00:00
49 changed files with 155 additions and 750 deletions

View File

@@ -212,7 +212,7 @@ jobs:
fi
echo "CLIPPY_COMMON_ARGS=${CLIPPY_COMMON_ARGS}" >> $GITHUB_ENV
- name: Run cargo clippy (debug)
run: cargo hack --features default --ignore-unknown-features --feature-powerset clippy $CLIPPY_COMMON_ARGS
run: cargo hack --feature-powerset clippy $CLIPPY_COMMON_ARGS
- name: Check documentation generation
run: cargo doc --workspace --no-deps --document-private-items

Cargo.lock (generated)
View File

@@ -1274,7 +1274,6 @@ dependencies = [
"chrono",
"clap",
"compute_api",
"fail",
"flate2",
"futures",
"hyper 0.14.30",
@@ -1733,9 +1732,9 @@ checksum = "ab03c107fafeb3ee9f5925686dbb7a73bc76e3932abb0d2b365cb64b169cf04c"
[[package]]
name = "diesel"
version = "2.2.6"
version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ccf1bedf64cdb9643204a36dd15b19a6ce8e7aa7f7b105868e9f1fad5ffa7d12"
checksum = "65e13bab2796f412722112327f3e575601a3e9cdcbe426f0d30dbf43f3f5dc71"
dependencies = [
"bitflags 2.4.1",
"byteorder",
@@ -4494,9 +4493,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "pq-sys"
version = "0.6.3"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6cc05d7ea95200187117196eee9edd0644424911821aeb28a18ce60ea0b8793"
checksum = "31c0052426df997c0cbd30789eb44ca097e3541717a7b8fa36b1c464ee7edebd"
dependencies = [
"vcpkg",
]

View File

@@ -69,6 +69,8 @@ RUN set -e \
libreadline-dev \
libseccomp-dev \
ca-certificates \
# System postgres for use with client libraries (e.g. in storage controller)
postgresql-15 \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
&& useradd -d /data neon \
&& chown -R neon:neon /data

View File

@@ -67,8 +67,6 @@ CARGO_BUILD_FLAGS += $(filter -j1,$(MAKEFLAGS))
CARGO_CMD_PREFIX += $(if $(filter n,$(MAKEFLAGS)),,+)
# Force cargo not to print progress bar
CARGO_CMD_PREFIX += CARGO_TERM_PROGRESS_WHEN=never CI=1
# Set PQ_LIB_DIR to make sure `storage_controller` gets linked with the bundled libpq (through diesel)
CARGO_CMD_PREFIX += PQ_LIB_DIR=$(POSTGRES_INSTALL_DIR)/v16/lib
CACHEDIR_TAG_CONTENTS := "Signature: 8a477f597d28d172789f06886806bc55"

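For context: `pq-sys` (pulled in by diesel's `postgres` feature) honors `PQ_LIB_DIR` as an override for where to find libpq, which is what the removed Makefile line relied on; without it, linking falls back to the system libpq, hence the `postgresql-15` package added to the build image earlier in this diff. Below is a rough build-script sketch of that kind of override, under the assumption that this is only an illustration and not pq-sys's actual code:

```rust
// build.rs sketch: prefer an explicitly provided libpq directory, otherwise fall back to
// whatever the system provides (pq-sys additionally probes pkg-config/pg_config; this
// only illustrates the PQ_LIB_DIR override).
fn main() {
    println!("cargo:rerun-if-env-changed=PQ_LIB_DIR");
    if let Ok(dir) = std::env::var("PQ_LIB_DIR") {
        // Point the linker at the bundled libpq build.
        println!("cargo:rustc-link-search=native={dir}");
    }
    // In either case, link against libpq.
    println!("cargo:rustc-link-lib=pq");
}
```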
View File

@@ -1285,7 +1285,7 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) \
#########################################################################################
#
# Compile the Neon-specific `compute_ctl`, `fast_import`, and `local_proxy` binaries
# Compile and run the Neon-specific `compute_ctl` and `fast_import` binaries
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS compute-tools
@@ -1295,7 +1295,7 @@ ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin compute_ctl --bin fast_import --bin local_proxy
RUN cd compute_tools && mold -run cargo build --locked --profile release-line-debug-size-lto
#########################################################################################
#
@@ -1338,6 +1338,20 @@ RUN set -e \
&& make -j $(nproc) dist_man_MANS= \
&& make install dist_man_MANS=
#########################################################################################
#
# Compile the Neon-specific `local_proxy` binary
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS local_proxy
ARG BUILD_TAG
ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin local_proxy
#########################################################################################
#
# Layers "postgres-exporter" and "sql-exporter"
@@ -1477,7 +1491,7 @@ COPY --from=pgbouncer /usr/local/pgbouncer/bin/pgbouncer /usr/local/bin/
COPY --chmod=0666 --chown=postgres compute/etc/pgbouncer.ini /etc/pgbouncer.ini
# local_proxy and its config
COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
COPY --from=local_proxy --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
RUN mkdir -p /etc/local_proxy && chown postgres:postgres /etc/local_proxy
# Metrics exporter binaries and configuration files

View File

@@ -7,7 +7,7 @@ license.workspace = true
[features]
default = []
# Enables test specific features.
testing = ["fail/failpoints"]
testing = []
[dependencies]
base64.workspace = true
@@ -19,7 +19,6 @@ camino.workspace = true
chrono.workspace = true
cfg-if.workspace = true
clap.workspace = true
fail.workspace = true
flate2.workspace = true
futures.workspace = true
hyper0 = { workspace = true, features = ["full"] }

View File

@@ -67,15 +67,12 @@ use compute_tools::params::*;
use compute_tools::spec::*;
use compute_tools::swap::resize_swap;
use rlimit::{setrlimit, Resource};
use utils::failpoint_support;
// this is an arbitrary build tag. Fine as a default / for testing purposes
// in case the environment variable is not set
const BUILD_TAG_DEFAULT: &str = "latest";
fn main() -> Result<()> {
let scenario = failpoint_support::init();
let (build_tag, clap_args) = init()?;
// enable core dumping for all child processes
@@ -103,8 +100,6 @@ fn main() -> Result<()> {
maybe_delay_exit(delay_exit);
scenario.teardown();
deinit_and_exit(wait_pg_result);
}
@@ -424,13 +419,9 @@ fn start_postgres(
"running compute with features: {:?}",
state.pspec.as_ref().unwrap().spec.features
);
// before we release the mutex, fetch some parameters for later.
let &ComputeSpec {
swap_size_bytes,
disk_quota_bytes,
disable_lfc_resizing,
..
} = &state.pspec.as_ref().unwrap().spec;
// before we release the mutex, fetch the swap size (if any) for later.
let swap_size_bytes = state.pspec.as_ref().unwrap().spec.swap_size_bytes;
let disk_quota_bytes = state.pspec.as_ref().unwrap().spec.disk_quota_bytes;
drop(state);
// Launch remaining service threads
@@ -535,18 +526,11 @@ fn start_postgres(
// This token is used internally by the monitor to clean up all threads
let token = CancellationToken::new();
// don't pass postgres connection string to vm-monitor if we don't want it to resize LFC
let pgconnstr = if disable_lfc_resizing.unwrap_or(false) {
None
} else {
file_cache_connstr.cloned()
};
let vm_monitor = rt.as_ref().map(|rt| {
rt.spawn(vm_monitor::start(
Box::leak(Box::new(vm_monitor::Args {
cgroup: cgroup.cloned(),
pgconnstr,
pgconnstr: file_cache_connstr.cloned(),
addr: vm_monitor_addr.clone(),
})),
token.clone(),

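The lines removed above wired `compute_ctl` into the failpoints machinery (`failpoint_support::init()` / `scenario.teardown()`), and the removed tests further down drive it through the `FAILPOINTS` environment variable. A minimal sketch of that lifecycle, assuming the `fail` crate directly rather than the project's `failpoint_support` wrapper:

```rust
// FailScenario::setup() reads the FAILPOINTS env var, e.g.
//   FAILPOINTS="compute-migration=return(3)"
// and arms the named failpoints for the lifetime of the scenario.
fn main() {
    let scenario = fail::FailScenario::setup();

    // ... normal program flow; any fail::fail_point!("compute-migration", ...) site
    // consults the configured actions while the scenario is alive ...

    scenario.teardown();
}
```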
View File

@@ -1181,19 +1181,8 @@ impl ComputeNode {
let mut conf = postgres::config::Config::from(conf);
conf.application_name("compute_ctl:migrations");
match conf.connect(NoTls) {
Ok(mut client) => {
if let Err(e) = handle_migrations(&mut client) {
error!("Failed to run migrations: {}", e);
}
}
Err(e) => {
error!(
"Failed to connect to the compute for running migrations: {}",
e
);
}
};
let mut client = conf.connect(NoTls)?;
handle_migrations(&mut client).context("apply_config handle_migrations")
});
Ok::<(), anyhow::Error>(())

View File

@@ -24,11 +24,8 @@ use metrics::proto::MetricFamily;
use metrics::Encoder;
use metrics::TextEncoder;
use tokio::task;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use tracing_utils::http::OtelName;
use utils::failpoint_support::failpoints_handler;
use utils::http::error::ApiError;
use utils::http::request::must_get_query_param;
fn status_response_from_state(state: &ComputeState) -> ComputeStatusResponse {
@@ -313,18 +310,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
(&Method::POST, "/failpoints") if cfg!(feature = "testing") => {
match failpoints_handler(req, CancellationToken::new()).await {
Ok(r) => r,
Err(ApiError::BadRequest(e)) => {
render_json_error(&e.to_string(), StatusCode::BAD_REQUEST)
}
Err(_) => {
render_json_error("Internal server error", StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
// download extension files from remote extension storage on demand
(&Method::POST, route) if route.starts_with("/extension_server/") => {
info!("serving {:?} POST request", route);

View File

@@ -1,16 +1,13 @@
use anyhow::{Context, Result};
use fail::fail_point;
use postgres::Client;
use tracing::info;
/// Runs a series of migrations on a target database
pub(crate) struct MigrationRunner<'m> {
client: &'m mut Client,
migrations: &'m [&'m str],
}
impl<'m> MigrationRunner<'m> {
/// Create a new migration runner
pub fn new(client: &'m mut Client, migrations: &'m [&'m str]) -> Self {
// The neon_migration.migration_id::id column is a bigint, which is equivalent to an i64
assert!(migrations.len() + 1 < i64::MAX as usize);
@@ -18,7 +15,6 @@ impl<'m> MigrationRunner<'m> {
Self { client, migrations }
}
/// Get the current value of neon_migration.migration_id
fn get_migration_id(&mut self) -> Result<i64> {
let query = "SELECT id FROM neon_migration.migration_id";
let row = self
@@ -29,61 +25,37 @@ impl<'m> MigrationRunner<'m> {
Ok(row.get::<&str, i64>("id"))
}
/// Update the neon_migration.migration_id value
///
/// This function has a fail point called compute-migration, which can be
/// used if you would like to fail the application of a series of migrations
/// at some point.
fn update_migration_id(&mut self, migration_id: i64) -> Result<()> {
// We use this fail point in order to check that failing in the
// middle of applying a series of migrations fails in an expected
// manner
if cfg!(feature = "testing") {
let fail = (|| {
fail_point!("compute-migration", |fail_migration_id| {
migration_id == fail_migration_id.unwrap().parse::<i64>().unwrap()
});
false
})();
if fail {
return Err(anyhow::anyhow!(format!(
"migration {} was configured to fail because of a failpoint",
migration_id
)));
}
}
let setval = format!("UPDATE neon_migration.migration_id SET id={}", migration_id);
self.client
.query(
"UPDATE neon_migration.migration_id SET id = $1",
&[&migration_id],
)
.simple_query(&setval)
.context("run_migrations update id")?;
Ok(())
}
/// Prepare the target database for handling migrations
fn prepare_database(&mut self) -> Result<()> {
self.client
.simple_query("CREATE SCHEMA IF NOT EXISTS neon_migration")?;
self.client.simple_query("CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)")?;
self.client.simple_query(
"INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING",
)?;
self.client
.simple_query("ALTER SCHEMA neon_migration OWNER TO cloud_admin")?;
self.client
.simple_query("REVOKE ALL ON SCHEMA neon_migration FROM PUBLIC")?;
fn prepare_migrations(&mut self) -> Result<()> {
let query = "CREATE SCHEMA IF NOT EXISTS neon_migration";
self.client.simple_query(query)?;
let query = "CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)";
self.client.simple_query(query)?;
let query = "INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING";
self.client.simple_query(query)?;
let query = "ALTER SCHEMA neon_migration OWNER TO cloud_admin";
self.client.simple_query(query)?;
let query = "REVOKE ALL ON SCHEMA neon_migration FROM PUBLIC";
self.client.simple_query(query)?;
Ok(())
}
/// Run the configured set of migrations
pub fn run_migrations(mut self) -> Result<()> {
self.prepare_database()?;
self.prepare_migrations()?;
let mut current_migration = self.get_migration_id()? as usize;
while current_migration < self.migrations.len() {
@@ -97,11 +69,6 @@ impl<'m> MigrationRunner<'m> {
if migration.starts_with("-- SKIP") {
info!("Skipping migration id={}", migration_id!(current_migration));
// Even though we are skipping the migration, updating the
// migration ID should help keep the logic easy to follow when
// reasoning about the state of a cluster.
self.update_migration_id(migration_id!(current_migration))?;
} else {
info!(
"Running migration id={}:\n{}\n",
@@ -120,6 +87,7 @@ impl<'m> MigrationRunner<'m> {
)
})?;
// Migration IDs start at 1
self.update_migration_id(migration_id!(current_migration))?;
self.client

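For context, the two update styles interleaved in the hunk above correspond to the `postgres` crate's extended query protocol (bind parameters) and its simple query protocol (value formatted into the SQL text). A minimal sketch showing both, assuming that crate; it is not the project's actual migration runner:

```rust
use postgres::Client;

fn set_migration_id(client: &mut Client, migration_id: i64) -> Result<(), postgres::Error> {
    // Extended query protocol: the value is sent as a bind parameter.
    client.query(
        "UPDATE neon_migration.migration_id SET id = $1",
        &[&migration_id],
    )?;

    // Simple query protocol: the value is interpolated into the statement text.
    let setval = format!("UPDATE neon_migration.migration_id SET id={migration_id}");
    client.simple_query(&setval)?;

    Ok(())
}
```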
View File

@@ -1,9 +0,0 @@
DO $$
DECLARE
bypassrls boolean;
BEGIN
SELECT rolbypassrls INTO bypassrls FROM pg_roles WHERE rolname = 'neon_superuser';
IF NOT bypassrls THEN
RAISE EXCEPTION 'neon_superuser cannot bypass RLS';
END IF;
END $$;

View File

@@ -1,25 +0,0 @@
DO $$
DECLARE
role record;
BEGIN
FOR role IN
SELECT rolname AS name, rolinherit AS inherit
FROM pg_roles
WHERE pg_has_role(rolname, 'neon_superuser', 'member')
LOOP
IF NOT role.inherit THEN
RAISE EXCEPTION '% cannot inherit', quote_ident(role.name);
END IF;
END LOOP;
FOR role IN
SELECT rolname AS name, rolbypassrls AS bypassrls
FROM pg_roles
WHERE NOT pg_has_role(rolname, 'neon_superuser', 'member')
AND NOT starts_with(rolname, 'pg_')
LOOP
IF role.bypassrls THEN
RAISE EXCEPTION '% can bypass RLS', quote_ident(role.name);
END IF;
END LOOP;
END $$;

View File

@@ -1,10 +0,0 @@
DO $$
BEGIN
IF (SELECT current_setting('server_version_num')::numeric < 160000) THEN
RETURN;
END IF;
IF NOT (SELECT pg_has_role('neon_superuser', 'pg_create_subscription', 'member')) THEN
RAISE EXCEPTION 'neon_superuser cannot execute pg_create_subscription';
END IF;
END $$;

View File

@@ -1,19 +0,0 @@
DO $$
DECLARE
monitor record;
BEGIN
SELECT pg_has_role('neon_superuser', 'pg_monitor', 'member') AS member,
admin_option AS admin
INTO monitor
FROM pg_auth_members
WHERE roleid = 'pg_monitor'::regrole
AND member = 'pg_monitor'::regrole;
IF NOT monitor.member THEN
RAISE EXCEPTION 'neon_superuser is not a member of pg_monitor';
END IF;
IF NOT monitor.admin THEN
RAISE EXCEPTION 'neon_superuser cannot grant pg_monitor';
END IF;
END $$;

View File

@@ -1,2 +0,0 @@
-- This test was never written because at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written because at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written because at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written because at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,2 +0,0 @@
-- This test was never written because at the time migration tests were added
-- the accompanying migration was already skipped.

View File

@@ -1,13 +0,0 @@
DO $$
DECLARE
can_execute boolean;
BEGIN
SELECT bool_and(has_function_privilege('neon_superuser', oid, 'execute'))
INTO can_execute
FROM pg_proc
WHERE proname IN ('pg_export_snapshot', 'pg_log_standby_snapshot')
AND pronamespace = 'pg_catalog'::regnamespace;
IF NOT can_execute THEN
RAISE EXCEPTION 'neon_superuser cannot execute both pg_export_snapshot and pg_log_standby_snapshot';
END IF;
END $$;

View File

@@ -1,13 +0,0 @@
DO $$
DECLARE
can_execute boolean;
BEGIN
SELECT has_function_privilege('neon_superuser', oid, 'execute')
INTO can_execute
FROM pg_proc
WHERE proname = 'pg_show_replication_origin_status'
AND pronamespace = 'pg_catalog'::regnamespace;
IF NOT can_execute THEN
RAISE EXCEPTION 'neon_superuser cannot execute pg_show_replication_origin_status';
END IF;
END $$;

View File

@@ -585,7 +585,6 @@ impl Endpoint {
features: self.features.clone(),
swap_size_bytes: None,
disk_quota_bytes: None,
disable_lfc_resizing: None,
cluster: Cluster {
cluster_id: None, // project ID: not used
name: None, // project name: not used

View File

@@ -67,15 +67,6 @@ pub struct ComputeSpec {
#[serde(default)]
pub disk_quota_bytes: Option<u64>,
/// Disables the vm-monitor behavior that resizes LFC on upscale/downscale, instead relying on
/// the initial size of LFC.
///
/// This is intended for use when the LFC size is being overridden from the default but
/// autoscaling is still enabled, and we don't want the vm-monitor to interfere with the custom
/// LFC sizing.
#[serde(default)]
pub disable_lfc_resizing: Option<bool>,
/// Expected cluster state at the end of transition process.
pub cluster: Cluster,
pub delta_operations: Option<Vec<DeltaOp>>,

View File

@@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use tracing::*;
/// Declare a failpoint that can use to `pause` failpoint action.
/// Declare a failpoint that can use the `pause` failpoint action.
/// We don't want to block the executor thread, hence, spawn_blocking + await.
#[macro_export]
macro_rules! pausable_failpoint {
@@ -181,7 +181,7 @@ pub async fn failpoints_handler(
) -> Result<Response<Body>, ApiError> {
if !fail::has_failpoints() {
return Err(ApiError::BadRequest(anyhow::anyhow!(
"Cannot manage failpoints because neon was compiled without failpoints support"
"Cannot manage failpoints because storage was compiled without failpoints support"
)));
}

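The context above describes `pausable_failpoint!`: the `pause` failpoint action blocks, so the evaluation is pushed onto the blocking pool and awaited instead of stalling the async executor. A minimal sketch of that pattern, assuming the `fail` and `tokio` crates; it is deliberately named differently because it is not the crate's actual definition:

```rust
// Evaluate a failpoint that may `pause` without blocking the async executor: run the
// blocking fail_point! check on tokio's blocking pool and await its completion.
// Must be expanded inside an async context.
#[macro_export]
macro_rules! pausable_failpoint_sketch {
    ($name:literal) => {
        if cfg!(feature = "testing") {
            tokio::task::spawn_blocking(|| fail::fail_point!($name))
                .await
                .expect("spawn_blocking failed");
        }
    };
}
```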
View File

@@ -102,39 +102,23 @@ User can pass several optional headers that will affect resulting json.
2. `Neon-Array-Mode: true`. Return postgres rows as arrays instead of objects. That is a more compact representation and also helps in some edge
cases where it is hard to use rows represented as objects (e.g. when several fields have the same name).
## Test proxy locally
Proxy determines the project name from the subdomain: a request to `round-rice-566201.somedomain.tld` will be routed to the project named `round-rice-566201`. Unfortunately, `/etc/hosts` does not support domain wildcards, so we can use `*.localtest.me`, which resolves to `127.0.0.1`.
## Using SNI-based routing on localhost
Now proxy determines the project name from the subdomain: a request to `round-rice-566201.somedomain.tld` will be routed to the project named `round-rice-566201`. Unfortunately, `/etc/hosts` does not support domain wildcards, so I usually use `*.localtest.me`, which resolves to `127.0.0.1`. Now we can create a self-signed certificate and play with the proxy:
Let's create self-signed certificate by running:
```sh
openssl req -new -x509 -days 365 -nodes -text -out server.crt -keyout server.key -subj "/CN=*.localtest.me"
```
Then we need to build proxy with 'testing' feature and run, e.g.:
start proxy
```sh
RUST_LOG=proxy cargo run -p proxy --bin proxy --features testing -- --auth-backend postgres --auth-endpoint 'postgresql://proxy:password@endpoint.localtest.me:5432/postgres' --is-private-access-proxy true -c server.crt -k server.key
./target/debug/proxy -c server.crt -k server.key
```
We will also need a postgres instance. Assuming that we have set up Docker, we can start one as follows:
and connect to it
```sh
docker run \
--detach \
--name proxy-postgres \
--env POSTGRES_PASSWORD=proxy-postgres \
--publish 5432:5432 \
postgres:17-bookworm
PGSSLROOTCERT=./server.crt psql 'postgres://my-cluster-42.localtest.me:1234?sslmode=verify-full'
```
The next step is setting up the auth table and schema, as well as creating a role (without the JWT table):
```sh
docker exec -it proxy-postgres psql -U postgres -c "CREATE SCHEMA IF NOT EXISTS neon_control_plane"
docker exec -it proxy-postgres psql -U postgres -c "CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
docker exec -it proxy-postgres psql -U postgres -c "CREATE ROLE proxy WITH SUPERUSER LOGIN PASSWORD 'password';"
```
Now, from the client, you can start a new session:
```sh
PGSSLROOTCERT=./server.crt psql "postgresql://proxy:password@endpoint.localtest.me:4432/postgres?sslmode=verify-full"
```

View File

@@ -16,7 +16,6 @@ use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
};
use proxy::conn::TokioTcpAcceptor;
use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use proxy::http::health_server::AppMetrics;
@@ -37,6 +36,7 @@ project_build_tag!(BUILD_TAG);
use clap::Parser;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -166,8 +166,8 @@ async fn main() -> anyhow::Result<()> {
}
};
let metrics_listener = TokioTcpAcceptor::bind(args.metrics).await?;
let http_listener = TokioTcpAcceptor::bind(args.http).await?;
let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
let http_listener = TcpListener::bind(args.http).await?;
let shutdown = CancellationToken::new();
// todo: should scale with CU

View File

@@ -10,7 +10,6 @@ use clap::Arg;
use futures::future::Either;
use futures::TryFutureExt;
use itertools::Itertools;
use proxy::conn::{Acceptor, TokioTcpAcceptor};
use proxy::context::RequestContext;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo;
@@ -123,7 +122,7 @@ async fn main() -> anyhow::Result<()> {
// Start listening for incoming client connections
let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
info!("Starting sni router on {proxy_address}");
let proxy_listener = TokioTcpAcceptor::bind(proxy_address).await?;
let proxy_listener = TcpListener::bind(proxy_address).await?;
let cancellation_token = CancellationToken::new();
@@ -153,13 +152,17 @@ async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
acceptor: TokioTcpAcceptor,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(acceptor.accept(), &cancellation_token).await
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
@@ -169,6 +172,10 @@ async fn task_main(
connections.spawn(
async move {
socket
.set_nodelay(true)
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
let ctx = RequestContext::new(
session_id,
@@ -190,7 +197,7 @@ async fn task_main(
}
connections.close();
drop(acceptor);
drop(listener);
connections.wait().await;

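The sni_router change above (and the analogous proxy listeners below) replaces the custom acceptor with a plain `tokio::net::TcpListener`, setting keepalive once on the listening socket and `TCP_NODELAY` per accepted connection. A minimal sketch of that accept loop, assuming the `tokio` and `socket2` crates; the handler is elided:

```rust
use tokio::net::TcpListener;

async fn accept_loop(addr: &str) -> std::io::Result<()> {
    let listener = TcpListener::bind(addr).await?;
    // Keepalive set on the server socket is inherited by all accepted client sockets.
    socket2::SockRef::from(&listener).set_keepalive(true)?;

    loop {
        let (socket, peer_addr) = listener.accept().await?;
        // Disable Nagle's algorithm per connection to avoid added latency.
        socket.set_nodelay(true)?;
        println!("serving {peer_addr}");
        // hand `socket` off to a per-connection task here
        drop(socket);
    }
}
```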
View File

@@ -12,7 +12,6 @@ use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
};
use proxy::conn::TokioTcpAcceptor;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
@@ -28,6 +27,7 @@ use proxy::serverless::GlobalConnPoolOptions;
use proxy::tls::client_config::compute_client_config_with_root_certs;
use proxy::{auth, control_plane, http, serverless, usage_metrics};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -353,17 +353,17 @@ async fn main() -> anyhow::Result<()> {
// Check that we can bind to address before further initialization
let http_address: SocketAddr = args.http.parse()?;
info!("Starting http on {http_address}");
let http_listener = TokioTcpAcceptor::bind(http_address).await?;
let http_listener = TcpListener::bind(http_address).await?.into_std()?;
let mgmt_address: SocketAddr = args.mgmt.parse()?;
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TokioTcpAcceptor::bind(mgmt_address).await?;
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
let proxy_listener = if !args.is_auth_broker {
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
Some(TokioTcpAcceptor::bind(proxy_address).await?)
Some(TcpListener::bind(proxy_address).await?)
} else {
None
};
@@ -373,7 +373,7 @@ async fn main() -> anyhow::Result<()> {
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TokioTcpAcceptor::bind(serverless_address).await?)
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {

View File

@@ -193,15 +193,11 @@ impl ConnCfg {
let connect_once = |host, port| {
debug!("trying to connect to compute node at {host}:{port}");
connect_with_timeout(host, port).and_then(|stream| async {
let socket_addr = stream.peer_addr()?;
let socket = socket2::SockRef::from(&stream);
// Disable Nagle's algorithm to not introduce latency between
// client and compute.
socket.set_nodelay(true)?;
connect_with_timeout(host, port).and_then(|socket| async {
let socket_addr = socket.peer_addr()?;
// This prevents load balancer from severing the connection.
socket.set_keepalive(true)?;
Ok((socket_addr, stream))
socket2::SockRef::from(&socket).set_keepalive(true)?;
Ok((socket_addr, socket))
})
};

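The outbound counterpart shown above configures the compute-node connection rather than the listener: Nagle's algorithm is disabled for latency, and keepalive is enabled so load balancers do not sever idle connections. A minimal sketch under the same `tokio` + `socket2` assumptions, not the proxy's actual `connect_once` closure:

```rust
use std::net::SocketAddr;

use tokio::net::TcpStream;

async fn connect_compute(host: &str, port: u16) -> std::io::Result<(SocketAddr, TcpStream)> {
    let stream = TcpStream::connect((host, port)).await?;
    let peer = stream.peer_addr()?;
    // Avoid added latency between client and compute.
    stream.set_nodelay(true)?;
    // Keep the connection alive through load balancers.
    socket2::SockRef::from(&stream).set_keepalive(true)?;
    Ok((peer, stream))
}
```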
View File

@@ -1,221 +0,0 @@
use std::future::{poll_fn, Future};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
pub trait Acceptor {
type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static;
type Error: std::error::Error + Send + Sync + 'static;
#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let _ = cx;
Poll::Ready(Ok(()))
}
fn accept(
&self,
) -> impl Future<Output = Result<(Self::Connection, SocketAddr), Self::Error>> + Send;
}
pub trait Connector {
type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static;
type Error: std::error::Error + Send + Sync + 'static;
#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let _ = cx;
Poll::Ready(Ok(()))
}
fn connect(
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<Self::Connection, Self::Error>> + Send;
}
pub struct TokioTcpAcceptor {
listener: TcpListener,
tcp_nodelay: Option<bool>,
tcp_keepalive: Option<bool>,
}
impl TokioTcpAcceptor {
pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
let listener = TcpListener::bind(addr).await?;
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
Ok(Self {
listener,
tcp_nodelay: Some(true),
tcp_keepalive: None,
})
}
pub fn into_std(self) -> io::Result<std::net::TcpListener> {
self.listener.into_std()
}
}
impl Acceptor for TokioTcpAcceptor {
type Connection = TcpStream;
type Error = io::Error;
fn accept(&self) -> impl Future<Output = Result<(Self::Connection, SocketAddr), Self::Error>> {
async move {
let (stream, addr) = self.listener.accept().await?;
let socket = socket2::SockRef::from(&stream);
if let Some(nodelay) = self.tcp_nodelay {
socket.set_nodelay(nodelay)?;
}
if let Some(keepalive) = self.tcp_keepalive {
socket.set_keepalive(keepalive)?;
}
Ok((stream, addr))
}
}
}
pub struct TokioTcpConnector;
impl Connector for TokioTcpConnector {
type Connection = TcpStream;
type Error = io::Error;
fn connect(
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<Self::Connection, Self::Error>> {
async move {
let socket = TcpStream::connect(addr).await?;
socket.set_nodelay(true)?;
Ok(socket)
}
}
}
pub trait Stream: AsyncRead + AsyncWrite + Send + Unpin + 'static {}
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> Stream for T {}
pub trait AsyncRead {
fn readable(&self) -> impl Future<Output = io::Result<()>> + Send
where
Self: Send + Sync,
{
poll_fn(move |cx| self.poll_read_ready(cx))
}
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>>;
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [io::IoSliceMut<'_>],
) -> Poll<io::Result<usize>>;
}
pub trait AsyncWrite {
fn writable(&self) -> impl Future<Output = io::Result<()>> + Send
where
Self: Send + Sync,
{
poll_fn(move |cx| self.poll_write_ready(cx))
}
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>>;
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>;
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}
impl AsyncRead for tokio::net::TcpStream {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::net::TcpStream::poll_read_ready(self, cx)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match tokio::net::TcpStream::try_read(Pin::new(&mut *self).get_mut(), buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_read_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [io::IoSliceMut<'_>],
) -> Poll<io::Result<usize>> {
match tokio::net::TcpStream::try_read_vectored(Pin::new(&mut *self).get_mut(), bufs) {
Ok(n) => Poll::Ready(Ok(n)),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
}
impl AsyncWrite for tokio::net::TcpStream {
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::net::TcpStream::poll_write_ready(self, cx)
}
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
<Self as tokio::io::AsyncWrite>::poll_write(self, cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
<Self as tokio::io::AsyncWrite>::poll_write_vectored(self, cx, bufs)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
<Self as tokio::io::AsyncWrite>::poll_flush(self, cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
<Self as tokio::io::AsyncWrite>::poll_shutdown(self, cx)
}
}

View File

@@ -8,7 +8,6 @@ use tracing::{debug, error, info, Instrument};
use crate::auth::backend::ConsoleRedirectBackend;
use crate::cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
@@ -23,7 +22,7 @@ use crate::proxy::{
pub async fn task_main(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
acceptor: TokioTcpAcceptor,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
@@ -31,11 +30,15 @@ pub async fn task_main(
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(acceptor.accept(), &cancellation_token).await
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
@@ -128,7 +131,7 @@ pub async fn task_main(
connections.close();
cancellations.close();
drop(acceptor);
drop(listener);
// Drain connections
connections.wait().await;

View File

@@ -4,11 +4,10 @@ use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use tokio::net::TcpStream;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::control_plane::messages::{DatabaseInfo, KickSession};
use crate::waiters::{self, Waiter, Waiters};
@@ -27,15 +26,19 @@ pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), wai
/// Management API listener task.
/// It spawns management response handlers needed for the console redirect auth flow.
pub async fn task_main(acceptor: TokioTcpAcceptor) -> anyhow::Result<Infallible> {
pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
scopeguard::defer! {
info!("mgmt has shut down");
}
loop {
let (socket, peer_addr) = acceptor.accept().await?;
let (socket, peer_addr) = listener.accept().await?;
info!("accepted connection from {peer_addr}");
socket
.set_nodelay(true)
.context("failed to set client socket option")?;
let span = info_span!("mgmt", peer = %peer_addr);
tokio::task::spawn(

View File

@@ -1,4 +1,5 @@
use std::convert::Infallible;
use std::net::TcpListener;
use std::sync::{Arc, Mutex};
use anyhow::{anyhow, bail};
@@ -13,7 +14,6 @@ use utils::http::error::ApiError;
use utils::http::json::json_response;
use utils::http::{RouterBuilder, RouterService};
use crate::conn::TokioTcpAcceptor;
use crate::ext::{LockExt, TaskExt};
use crate::jemalloc;
@@ -36,7 +36,7 @@ fn make_router(metrics: AppMetrics) -> RouterBuilder<hyper0::Body, ApiError> {
}
pub async fn task_main(
http_acceptor: TokioTcpAcceptor,
http_listener: TcpListener,
metrics: AppMetrics,
) -> anyhow::Result<Infallible> {
scopeguard::defer! {
@@ -45,7 +45,7 @@ pub async fn task_main(
let service = || RouterService::new(make_router(metrics).build()?);
hyper0::Server::from_tcp(http_acceptor.into_std()?)?
hyper0::Server::from_tcp(http_listener)?
.serve(service().map_err(|e| anyhow!(e))?)
.await?;

View File

@@ -78,7 +78,6 @@ pub mod cancellation;
pub mod compute;
pub mod compute_ctl;
pub mod config;
pub mod conn;
pub mod console_redirect_proxy;
pub mod context;
pub mod control_plane;

View File

@@ -25,7 +25,6 @@ use self::connect_compute::{connect_to_compute, TcpMechanism};
use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
@@ -56,7 +55,7 @@ pub async fn run_until_cancelled<F: std::future::Future>(
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
acceptor: TokioTcpAcceptor,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -65,11 +64,15 @@ pub async fn task_main(
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(acceptor.accept(), &cancellation_token).await
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
@@ -165,7 +168,7 @@ pub async fn task_main(
connections.close();
cancellations.close();
drop(acceptor);
drop(listener);
// Drain connections
connections.wait().await;

View File

@@ -40,27 +40,6 @@ pub(crate) enum Notification {
AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate,
},
#[serde(
rename = "/block_public_or_vpc_access_updated",
deserialize_with = "deserialize_json_string"
)]
BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_org",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
},
#[serde(
rename = "/allowed_vpc_endpoints_updated_for_projects",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
},
#[serde(
rename = "/password_updated",
deserialize_with = "deserialize_json_string"
@@ -73,24 +52,6 @@ pub(crate) enum Notification {
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdated {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
// TODO: change type once the implementation is more fully fledged.
// See e.g. https://github.com/neondatabase/neon/pull/10073.
account_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
project_ids: Vec<ProjectIdInt>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
project_id: ProjectIdInt,
@@ -204,11 +165,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
}
}
Notification::AllowedIpsUpdate { .. }
| Notification::PasswordUpdate { .. }
| Notification::BlockPublicOrVpcAccessUpdated { .. }
| Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
| Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
invalidate_cache(self.cache.clone(), msg.clone());
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
@@ -221,8 +178,6 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
.redis_events_count
.inc(RedisEventsCount::PasswordUpdate);
}
// TODO: add additional metrics for the other event types.
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
@@ -249,15 +204,6 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
password_update.role_name,
),
Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
Notification::BlockPublicOrVpcAccessUpdated { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
// https://github.com/neondatabase/neon/pull/10073
}
}
}

View File

@@ -35,7 +35,7 @@ use rand::rngs::StdRng;
use rand::SeedableRng;
use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
@@ -45,7 +45,6 @@ use utils::http::error::ApiError;
use crate::cancellation::CancellationHandlerMain;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
@@ -60,7 +59,7 @@ pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
ws_acceptor: TokioTcpAcceptor,
ws_listener: TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -135,7 +134,7 @@ pub async fn task_main(
connections.close(); // allows `connections.wait` to complete
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(res) = run_until_cancelled(ws_acceptor.accept(), &cancellation_token).await {
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
tracing::error!("could not set nodelay: {e}");

View File

@@ -43,13 +43,13 @@ scopeguard.workspace = true
strum.workspace = true
strum_macros.workspace = true
diesel = { version = "2.2.6", features = [
diesel = { version = "2.1.4", features = [
"serde_json",
"postgres",
"r2d2",
"chrono",
] }
diesel_migrations = { version = "2.2.0" }
diesel_migrations = { version = "2.1.0" }
r2d2 = { version = "0.8.10" }
utils = { path = "../libs/utils/" }

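The diesel version change above is what ties the storage controller to libpq: diesel's `postgres` backend links against it (through `pq-sys`) at build time, while application code is identical whether the bundled or the system library ends up linked. A minimal, hedged connection sketch (the URL is illustrative):

```rust
use diesel::pg::PgConnection;
use diesel::prelude::*;

// Establish a connection through diesel's libpq-backed `postgres` backend; which libpq
// gets linked (bundled vs. system) is decided at build/link time, not here.
fn connect(database_url: &str) -> diesel::result::ConnectionResult<PgConnection> {
    PgConnection::establish(database_url)
}
```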
View File

@@ -3572,11 +3572,6 @@ impl Service {
.iter()
.any(|i| i.generation.is_none() || i.generation_pageserver.is_none())
{
let shard_generations = generations
.into_iter()
.map(|i| (i.tenant_shard_id, (i.generation, i.generation_pageserver)))
.collect::<HashMap<_, _>>();
// One or more shards has not been attached to a pageserver. Check if this is because it's configured
// to be detached (409: caller should give up), or because it's meant to be attached but isn't yet (503: caller should retry)
let locked = self.inner.read().unwrap();
@@ -3587,28 +3582,6 @@ impl Service {
PlacementPolicy::Attached(_) => {
// This shard is meant to be attached: the caller is not wrong to try and
// use this function, but we can't service the request right now.
let Some(generation) = shard_generations.get(shard_id) else {
// This can only happen if there is a split brain controller modifying the database. This should
// never happen when testing, and if it happens in production we can only log the issue.
debug_assert!(false);
tracing::error!("Shard {shard_id} not found in generation state! Is another rogue controller running?");
continue;
};
let (generation, generation_pageserver) = generation;
if let Some(generation) = generation {
if generation_pageserver.is_none() {
// This is legitimate only in a very narrow window where the shard was only just configured into
// Attached mode after being created in Secondary or Detached mode, and it has had its generation
// set but not yet had a Reconciler run (reconciler is the only thing that sets generation_pageserver).
tracing::warn!("Shard {shard_id} generation is set ({generation:?}) but generation_pageserver is None, reconciler not run yet?");
}
} else {
// This should never happen: a shard with no generation is only permitted when it was created in some state
// other than PlacementPolicy::Attached (and generation is always written to DB before setting Attached in memory)
debug_assert!(false);
tracing::error!("Shard {shard_id} generation is None, but it is in PlacementPolicy::Attached mode!");
continue;
}
}
PlacementPolicy::Secondary | PlacementPolicy::Detached => {
return Err(ApiError::Conflict(format!(

View File

@@ -8,7 +8,6 @@ pytest_plugins = (
"fixtures.compute_reconfigure",
"fixtures.storage_controller_proxy",
"fixtures.paths",
"fixtures.compute_migrations",
"fixtures.neon_fixtures",
"fixtures.benchmark_fixture",
"fixtures.pg_stats",

View File

@@ -1,34 +0,0 @@
from __future__ import annotations
import os
from typing import TYPE_CHECKING
import pytest
from fixtures.paths import BASE_DIR
if TYPE_CHECKING:
from collections.abc import Iterator
from pathlib import Path
COMPUTE_MIGRATIONS_DIR = BASE_DIR / "compute_tools" / "src" / "migrations"
COMPUTE_MIGRATIONS_TEST_DIR = COMPUTE_MIGRATIONS_DIR / "tests"
COMPUTE_MIGRATIONS = sorted(next(os.walk(COMPUTE_MIGRATIONS_DIR))[2])
NUM_COMPUTE_MIGRATIONS = len(COMPUTE_MIGRATIONS)
@pytest.fixture(scope="session")
def compute_migrations_dir() -> Iterator[Path]:
"""
Retrieve the path to the compute migrations directory.
"""
yield COMPUTE_MIGRATIONS_DIR
@pytest.fixture(scope="session")
def compute_migrations_test_dir() -> Iterator[Path]:
"""
Retrieve the path to the compute migrations test directory.
"""
yield COMPUTE_MIGRATIONS_TEST_DIR

View File

@@ -55,17 +55,3 @@ class EndpointHttpClient(requests.Session):
res = self.get(f"http://localhost:{self.port}/metrics")
res.raise_for_status()
return res.text
def configure_failpoints(self, *args: tuple[str, str]) -> None:
body: list[dict[str, str]] = []
for fp in args:
body.append(
{
"name": fp[0],
"action": fp[1],
}
)
res = self.post(f"http://localhost:{self.port}/failpoints", json=body)
res.raise_for_status()

View File

@@ -522,15 +522,14 @@ class NeonLocalCli(AbstractNeonCli):
safekeepers: list[int] | None = None,
remote_ext_config: str | None = None,
pageserver_id: int | None = None,
allow_multiple: bool = False,
allow_multiple=False,
basebackup_request_tries: int | None = None,
env: dict[str, str] | None = None,
) -> subprocess.CompletedProcess[str]:
args = [
"endpoint",
"start",
]
extra_env_vars = env or {}
extra_env_vars = {}
if basebackup_request_tries is not None:
extra_env_vars["NEON_COMPUTE_TESTING_BASEBACKUP_TRIES"] = str(basebackup_request_tries)
if remote_ext_config is not None:

View File

@@ -54,7 +54,6 @@ from fixtures.common_types import (
TimelineArchivalState,
TimelineId,
)
from fixtures.compute_migrations import NUM_COMPUTE_MIGRATIONS
from fixtures.endpoint.http import EndpointHttpClient
from fixtures.h2server import H2Server
from fixtures.log_helper import log
@@ -3856,7 +3855,6 @@ class Endpoint(PgProtocol, LogUtils):
safekeepers: list[int] | None = None,
allow_multiple: bool = False,
basebackup_request_tries: int | None = None,
env: dict[str, str] | None = None,
) -> Self:
"""
Start the Postgres instance.
@@ -3877,7 +3875,6 @@ class Endpoint(PgProtocol, LogUtils):
pageserver_id=pageserver_id,
allow_multiple=allow_multiple,
basebackup_request_tries=basebackup_request_tries,
env=env,
)
self._running.release(1)
self.log_config_value("shared_buffers")
@@ -3991,17 +3988,14 @@ class Endpoint(PgProtocol, LogUtils):
log.info("Updating compute spec to: %s", json.dumps(data_dict, indent=4))
json.dump(data_dict, file, indent=4)
def wait_for_migrations(self, wait_for: int = NUM_COMPUTE_MIGRATIONS) -> None:
"""
Wait for all compute migrations to be run. Remember that migrations only
run if "pg_skip_catalog_updates" is set to false in the compute spec.
"""
# Please note: Migrations only run if pg_skip_catalog_updates is false
def wait_for_migrations(self, num_migrations: int = 11):
with self.cursor() as cur:
def check_migrations_done():
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id: int = cur.fetchall()[0][0]
assert migration_id >= wait_for
assert migration_id >= num_migrations
wait_until(check_migrations_done)

View File

@@ -21,8 +21,8 @@ if TYPE_CHECKING:
BASE_DIR = Path(__file__).parents[2]
DEFAULT_OUTPUT_DIR: str = "test_output"
COMPUTE_CONFIG_DIR = BASE_DIR / "compute" / "etc"
DEFAULT_OUTPUT_DIR: str = "test_output"
def get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str | None = None) -> Path:

View File

@@ -1,90 +0,0 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, cast
import pytest
from fixtures.compute_migrations import COMPUTE_MIGRATIONS, NUM_COMPUTE_MIGRATIONS
if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnv
def test_compute_migrations_retry(neon_simple_env: NeonEnv, compute_migrations_dir: Path):
"""
Test that compute_ctl can recover from migration failures next time it
starts, and that the persisted migration ID is correct in such cases.
"""
env = neon_simple_env
endpoint = env.endpoints.create("main")
endpoint.respec(skip_pg_catalog_updates=False)
for i in range(1, NUM_COMPUTE_MIGRATIONS + 1):
endpoint.start(env={"FAILPOINTS": f"compute-migration=return({i})"})
# Make sure that the migrations ran
endpoint.wait_for_migrations(wait_for=i - 1)
# Confirm that we correctly recorded that in the
# neon_migration.migration_id table
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cast("int", cur.fetchall()[0][0])
assert migration_id == i - 1
endpoint.stop()
endpoint.start()
# Now wait for the rest of the migrations
endpoint.wait_for_migrations()
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cast("int", cur.fetchall()[0][0])
assert migration_id == NUM_COMPUTE_MIGRATIONS
for i, m in enumerate(COMPUTE_MIGRATIONS, start=1):
migration_query = (compute_migrations_dir / m).read_text(encoding="utf-8")
if not migration_query.startswith("-- SKIP"):
pattern = rf"Skipping migration id={i}"
else:
pattern = rf"Running migration id={i}"
endpoint.log_contains(pattern)
@pytest.mark.parametrize(
"migration",
(pytest.param((i, m), id=str(i)) for i, m in enumerate(COMPUTE_MIGRATIONS, start=1)),
)
def test_compute_migrations_e2e(
neon_simple_env: NeonEnv,
compute_migrations_dir: Path,
compute_migrations_test_dir: Path,
migration: tuple[int, str],
):
"""
Test that the migrations perform as advertised.
"""
env = neon_simple_env
migration_id = migration[0]
migration_filename = migration[1]
migration_query = (compute_migrations_dir / migration_filename).read_text(encoding="utf-8")
if migration_query.startswith("-- SKIP"):
pytest.skip("The migration is marked as SKIP")
endpoint = env.endpoints.create("main")
endpoint.respec(skip_pg_catalog_updates=False)
# Stop applying migrations after the one we want to test, so that we can
# test the state of the cluster at the given migration ID
endpoint.start(env={"FAILPOINTS": f"compute-migration=return({migration_id + 1})"})
endpoint.wait_for_migrations(wait_for=migration_id)
check_query = (compute_migrations_test_dir / migration_filename).read_text(encoding="utf-8")
endpoint.safe_psql(check_query)

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnv
def test_migrations(neon_simple_env: NeonEnv):
env = neon_simple_env
endpoint = env.endpoints.create("main")
endpoint.respec(skip_pg_catalog_updates=False)
endpoint.start()
num_migrations = 11
endpoint.wait_for_migrations(num_migrations=num_migrations)
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cur.fetchall()
assert migration_id[0][0] == num_migrations
endpoint.stop()
endpoint.start()
# We don't have a good way of knowing that the migrations code path finished executing
# in compute_ctl in the case that no migrations are being run
time.sleep(1)
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cur.fetchall()
assert migration_id[0][0] == num_migrations

View File

@@ -266,9 +266,7 @@ def test_scrubber_physical_gc_ancestors(neon_env_builder: NeonEnvBuilder, shard_
for shard in shards:
ps = env.get_tenant_pageserver(shard)
assert ps is not None
ps.http_client().timeline_compact(
shard, timeline_id, force_image_layer_creation=True, wait_until_uploaded=True
)
ps.http_client().timeline_compact(shard, timeline_id, force_image_layer_creation=True)
ps.http_client().timeline_gc(shard, timeline_id, 0)
# We will use a min_age_secs=1 threshold for deletion, let it pass

View File

@@ -398,7 +398,6 @@ def test_timeline_archival_chaos(neon_env_builder: NeonEnvBuilder):
# Offloading is off by default at time of writing: remove this line when it's on by default
neon_env_builder.pageserver_config_override = "timeline_offloading = true"
neon_env_builder.storage_controller_config = {"heartbeat_interval": "100msec"}
neon_env_builder.enable_pageserver_remote_storage(s3_storage())
# We will exercise migrations, so need multiple pageservers