mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 01:12:56 +00:00
Merge branch 'main' into ephemeralsad/graceful-draining
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -1445,6 +1445,7 @@ dependencies = [
|
||||
"regex",
|
||||
"reqwest",
|
||||
"safekeeper_api",
|
||||
"safekeeper_client",
|
||||
"scopeguard",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -2054,6 +2055,7 @@ dependencies = [
|
||||
"axum-extra",
|
||||
"camino",
|
||||
"camino-tempfile",
|
||||
"clap",
|
||||
"futures",
|
||||
"http-body-util",
|
||||
"itertools 0.10.5",
|
||||
|
||||
13
Dockerfile
13
Dockerfile
@@ -110,6 +110,19 @@ RUN set -e \
|
||||
# System postgres for use with client libraries (e.g. in storage controller)
|
||||
postgresql-15 \
|
||||
openssl \
|
||||
unzip \
|
||||
curl \
|
||||
&& ARCH=$(uname -m) \
|
||||
&& if [ "$ARCH" = "x86_64" ]; then \
|
||||
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"; \
|
||||
elif [ "$ARCH" = "aarch64" ]; then \
|
||||
curl "https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip" -o "awscliv2.zip"; \
|
||||
else \
|
||||
echo "Unsupported architecture: $ARCH" && exit 1; \
|
||||
fi \
|
||||
&& unzip awscliv2.zip \
|
||||
&& ./aws/install \
|
||||
&& rm -rf aws awscliv2.zip \
|
||||
&& rm -f /etc/apt/apt.conf.d/80-retries \
|
||||
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
|
||||
&& useradd -d /data neon \
|
||||
|
||||
@@ -36,6 +36,7 @@ pageserver_api.workspace = true
|
||||
pageserver_client.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
safekeeper_api.workspace = true
|
||||
safekeeper_client.workspace = true
|
||||
postgres_connection.workspace = true
|
||||
storage_broker.workspace = true
|
||||
http-utils.workspace = true
|
||||
|
||||
@@ -45,7 +45,7 @@ use pageserver_api::models::{
|
||||
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId};
|
||||
use postgres_backend::AuthType;
|
||||
use postgres_connection::parse_host_port;
|
||||
use safekeeper_api::membership::SafekeeperGeneration;
|
||||
use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId};
|
||||
use safekeeper_api::{
|
||||
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
|
||||
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
|
||||
@@ -1255,6 +1255,45 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re
|
||||
pageserver
|
||||
.timeline_import(tenant_id, timeline_id, base, pg_wal, args.pg_version)
|
||||
.await?;
|
||||
if env.storage_controller.timelines_onto_safekeepers {
|
||||
println!("Creating timeline on safekeeper ...");
|
||||
let timeline_info = pageserver
|
||||
.timeline_info(
|
||||
TenantShardId::unsharded(tenant_id),
|
||||
timeline_id,
|
||||
pageserver_client::mgmt_api::ForceAwaitLogicalSize::No,
|
||||
)
|
||||
.await?;
|
||||
let default_sk = SafekeeperNode::from_env(env, env.safekeepers.first().unwrap());
|
||||
let default_host = default_sk
|
||||
.conf
|
||||
.listen_addr
|
||||
.clone()
|
||||
.unwrap_or_else(|| "localhost".to_string());
|
||||
let mconf = safekeeper_api::membership::Configuration {
|
||||
generation: SafekeeperGeneration::new(1),
|
||||
members: safekeeper_api::membership::MemberSet {
|
||||
m: vec![SafekeeperId {
|
||||
host: default_host,
|
||||
id: default_sk.conf.id,
|
||||
pg_port: default_sk.conf.pg_port,
|
||||
}],
|
||||
},
|
||||
new_members: None,
|
||||
};
|
||||
let pg_version = args.pg_version * 10000;
|
||||
let req = safekeeper_api::models::TimelineCreateRequest {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
mconf,
|
||||
pg_version,
|
||||
system_id: None,
|
||||
wal_seg_size: None,
|
||||
start_lsn: timeline_info.last_record_lsn,
|
||||
commit_lsn: None,
|
||||
};
|
||||
default_sk.create_timeline(&req).await?;
|
||||
}
|
||||
env.register_branch_mapping(branch_name.to_string(), tenant_id, timeline_id)?;
|
||||
println!("Done");
|
||||
}
|
||||
|
||||
@@ -635,4 +635,16 @@ impl PageServerNode {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
pub async fn timeline_info(
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
timeline_id: TimelineId,
|
||||
force_await_logical_size: mgmt_api::ForceAwaitLogicalSize,
|
||||
) -> anyhow::Result<TimelineInfo> {
|
||||
let timeline_info = self
|
||||
.http_client
|
||||
.timeline_info(tenant_shard_id, timeline_id, force_await_logical_size)
|
||||
.await?;
|
||||
Ok(timeline_info)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
//! .neon/safekeepers/<safekeeper id>
|
||||
//! ```
|
||||
use std::error::Error as _;
|
||||
use std::future::Future;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
@@ -14,9 +13,9 @@ use std::{io, result};
|
||||
|
||||
use anyhow::Context;
|
||||
use camino::Utf8PathBuf;
|
||||
use http_utils::error::HttpErrorBody;
|
||||
use postgres_connection::PgConnectionConfig;
|
||||
use reqwest::{IntoUrl, Method};
|
||||
use safekeeper_api::models::TimelineCreateRequest;
|
||||
use safekeeper_client::mgmt_api;
|
||||
use thiserror::Error;
|
||||
use utils::auth::{Claims, Scope};
|
||||
use utils::id::NodeId;
|
||||
@@ -35,25 +34,14 @@ pub enum SafekeeperHttpError {
|
||||
|
||||
type Result<T> = result::Result<T, SafekeeperHttpError>;
|
||||
|
||||
pub(crate) trait ResponseErrorMessageExt: Sized {
|
||||
fn error_from_body(self) -> impl Future<Output = Result<Self>> + Send;
|
||||
}
|
||||
|
||||
impl ResponseErrorMessageExt for reqwest::Response {
|
||||
async fn error_from_body(self) -> Result<Self> {
|
||||
let status = self.status();
|
||||
if !(status.is_client_error() || status.is_server_error()) {
|
||||
return Ok(self);
|
||||
}
|
||||
|
||||
// reqwest does not export its error construction utility functions, so let's craft the message ourselves
|
||||
let url = self.url().to_owned();
|
||||
Err(SafekeeperHttpError::Response(
|
||||
match self.json::<HttpErrorBody>().await {
|
||||
Ok(err_body) => format!("Error: {}", err_body.msg),
|
||||
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
|
||||
},
|
||||
))
|
||||
fn err_from_client_err(err: mgmt_api::Error) -> SafekeeperHttpError {
|
||||
use mgmt_api::Error::*;
|
||||
match err {
|
||||
ApiError(_, str) => SafekeeperHttpError::Response(str),
|
||||
Cancelled => SafekeeperHttpError::Response("Cancelled".to_owned()),
|
||||
ReceiveBody(err) => SafekeeperHttpError::Transport(err),
|
||||
ReceiveErrorBody(err) => SafekeeperHttpError::Response(err),
|
||||
Timeout(str) => SafekeeperHttpError::Response(format!("timeout: {str}")),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,9 +58,8 @@ pub struct SafekeeperNode {
|
||||
|
||||
pub pg_connection_config: PgConnectionConfig,
|
||||
pub env: LocalEnv,
|
||||
pub http_client: reqwest::Client,
|
||||
pub http_client: mgmt_api::Client,
|
||||
pub listen_addr: String,
|
||||
pub http_base_url: String,
|
||||
}
|
||||
|
||||
impl SafekeeperNode {
|
||||
@@ -82,13 +69,14 @@ impl SafekeeperNode {
|
||||
} else {
|
||||
"127.0.0.1".to_string()
|
||||
};
|
||||
let jwt = None;
|
||||
let http_base_url = format!("http://{}:{}", listen_addr, conf.http_port);
|
||||
SafekeeperNode {
|
||||
id: conf.id,
|
||||
conf: conf.clone(),
|
||||
pg_connection_config: Self::safekeeper_connection_config(&listen_addr, conf.pg_port),
|
||||
env: env.clone(),
|
||||
http_client: env.create_http_client(),
|
||||
http_base_url: format!("http://{}:{}/v1", listen_addr, conf.http_port),
|
||||
http_client: mgmt_api::Client::new(env.create_http_client(), http_base_url, jwt),
|
||||
listen_addr,
|
||||
}
|
||||
}
|
||||
@@ -278,20 +266,19 @@ impl SafekeeperNode {
|
||||
)
|
||||
}
|
||||
|
||||
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> reqwest::RequestBuilder {
|
||||
// TODO: authentication
|
||||
//if self.env.auth_type == AuthType::NeonJWT {
|
||||
// builder = builder.bearer_auth(&self.env.safekeeper_auth_token)
|
||||
//}
|
||||
self.http_client.request(method, url)
|
||||
pub async fn check_status(&self) -> Result<()> {
|
||||
self.http_client
|
||||
.status()
|
||||
.await
|
||||
.map_err(err_from_client_err)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn check_status(&self) -> Result<()> {
|
||||
self.http_request(Method::GET, format!("{}/{}", self.http_base_url, "status"))
|
||||
.send()
|
||||
.await?
|
||||
.error_from_body()
|
||||
.await?;
|
||||
pub async fn create_timeline(&self, req: &TimelineCreateRequest) -> Result<()> {
|
||||
self.http_client
|
||||
.create_timeline(req)
|
||||
.await
|
||||
.map_err(err_from_client_err)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,6 +61,7 @@ enum Command {
|
||||
#[arg(long)]
|
||||
scheduling: Option<NodeSchedulingPolicy>,
|
||||
},
|
||||
// Set a node status as deleted.
|
||||
NodeDelete {
|
||||
#[arg(long)]
|
||||
node_id: NodeId,
|
||||
@@ -69,6 +70,11 @@ enum Command {
|
||||
#[arg(long)]
|
||||
force: bool,
|
||||
},
|
||||
/// Delete a tombstone of node from the storage controller.
|
||||
NodeDeleteTombstone {
|
||||
#[arg(long)]
|
||||
node_id: NodeId,
|
||||
},
|
||||
/// Modify a tenant's policies in the storage controller
|
||||
TenantPolicy {
|
||||
#[arg(long)]
|
||||
@@ -86,6 +92,8 @@ enum Command {
|
||||
},
|
||||
/// List nodes known to the storage controller
|
||||
Nodes {},
|
||||
/// List soft deleted nodes known to the storage controller
|
||||
NodeTombstones {},
|
||||
/// List tenants known to the storage controller
|
||||
Tenants {
|
||||
/// If this field is set, it will list the tenants on a specific node
|
||||
@@ -938,6 +946,39 @@ async fn main() -> anyhow::Result<()> {
|
||||
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
|
||||
.await?;
|
||||
}
|
||||
Command::NodeDeleteTombstone { node_id } => {
|
||||
storcon_client
|
||||
.dispatch::<(), ()>(
|
||||
Method::DELETE,
|
||||
format!("debug/v1/tombstone/{node_id}"),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Command::NodeTombstones {} => {
|
||||
let mut resp = storcon_client
|
||||
.dispatch::<(), Vec<NodeDescribeResponse>>(
|
||||
Method::GET,
|
||||
"debug/v1/tombstone".to_string(),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
resp.sort_by(|a, b| a.listen_http_addr.cmp(&b.listen_http_addr));
|
||||
|
||||
let mut table = comfy_table::Table::new();
|
||||
table.set_header(["Id", "Hostname", "AZ", "Scheduling", "Availability"]);
|
||||
for node in resp {
|
||||
table.add_row([
|
||||
format!("{}", node.id),
|
||||
node.listen_http_addr,
|
||||
node.availability_zone_id,
|
||||
format!("{:?}", node.scheduling),
|
||||
format!("{:?}", node.availability),
|
||||
]);
|
||||
}
|
||||
println!("{table}");
|
||||
}
|
||||
Command::TenantSetTimeBasedEviction {
|
||||
tenant_id,
|
||||
period,
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
#!/bin/bash
|
||||
#!/bin/sh
|
||||
set -ex
|
||||
cd "$(dirname "$0")"
|
||||
if [[ ${PG_VERSION} = v17 ]]; then
|
||||
sed -i '/computed_columns/d' regress/core/tests.mk
|
||||
fi
|
||||
patch -p1 <postgis-no-upgrade-test.patch
|
||||
trap 'echo Cleaning up; patch -R -p1 <postgis-no-upgrade-test.patch' EXIT
|
||||
patch -p1 <"postgis-common-${PG_VERSION}.patch"
|
||||
trap 'echo Cleaning up; patch -R -p1 <postgis-common-${PG_VERSION}.patch' EXIT
|
||||
make installcheck-base
|
||||
@@ -1,3 +1,19 @@
|
||||
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
|
||||
index 3abd7bc..64a9254 100644
|
||||
--- a/regress/core/tests.mk
|
||||
+++ b/regress/core/tests.mk
|
||||
@@ -144,11 +144,6 @@ TESTS_SLOW = \
|
||||
$(top_srcdir)/regress/core/concave_hull_hard \
|
||||
$(top_srcdir)/regress/core/knn_recheck
|
||||
|
||||
-ifeq ($(shell expr "$(POSTGIS_PGSQL_VERSION)" ">=" 120),1)
|
||||
- TESTS += \
|
||||
- $(top_srcdir)/regress/core/computed_columns
|
||||
-endif
|
||||
-
|
||||
ifeq ($(shell expr "$(POSTGIS_GEOS_VERSION)" ">=" 30700),1)
|
||||
# GEOS-3.7 adds:
|
||||
# ST_FrechetDistance
|
||||
diff --git a/regress/runtest.mk b/regress/runtest.mk
|
||||
index c051f03..010e493 100644
|
||||
--- a/regress/runtest.mk
|
||||
35
docker-compose/ext-src/postgis-src/postgis-common-v17.patch
Normal file
35
docker-compose/ext-src/postgis-src/postgis-common-v17.patch
Normal file
@@ -0,0 +1,35 @@
|
||||
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
|
||||
index 9e05244..90987df 100644
|
||||
--- a/regress/core/tests.mk
|
||||
+++ b/regress/core/tests.mk
|
||||
@@ -143,8 +143,7 @@ TESTS += \
|
||||
$(top_srcdir)/regress/core/oriented_envelope \
|
||||
$(top_srcdir)/regress/core/point_coordinates \
|
||||
$(top_srcdir)/regress/core/out_geojson \
|
||||
- $(top_srcdir)/regress/core/wrapx \
|
||||
- $(top_srcdir)/regress/core/computed_columns
|
||||
+ $(top_srcdir)/regress/core/wrapx
|
||||
|
||||
# Slow slow tests
|
||||
TESTS_SLOW = \
|
||||
diff --git a/regress/runtest.mk b/regress/runtest.mk
|
||||
index 4b95b7e..449d5a2 100644
|
||||
--- a/regress/runtest.mk
|
||||
+++ b/regress/runtest.mk
|
||||
@@ -24,16 +24,6 @@ check-regress:
|
||||
|
||||
@POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(RUNTESTFLAGS_INTERNAL) $(TESTS)
|
||||
|
||||
- @if echo "$(RUNTESTFLAGS)" | grep -vq -- --upgrade; then \
|
||||
- echo "Running upgrade test as RUNTESTFLAGS did not contain that"; \
|
||||
- POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl \
|
||||
- --upgrade \
|
||||
- $(RUNTESTFLAGS) \
|
||||
- $(RUNTESTFLAGS_INTERNAL) \
|
||||
- $(TESTS); \
|
||||
- else \
|
||||
- echo "Skipping upgrade test as RUNTESTFLAGS already requested upgrades"; \
|
||||
- fi
|
||||
|
||||
check-long:
|
||||
$(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(TESTS) $(TESTS_SLOW)
|
||||
@@ -125,7 +125,7 @@ index 7a36b65..ad78fc7 100644
|
||||
DROP SCHEMA tm CASCADE;
|
||||
+
|
||||
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
|
||||
index 3abd7bc..94903c3 100644
|
||||
index 64a9254..94903c3 100644
|
||||
--- a/regress/core/tests.mk
|
||||
+++ b/regress/core/tests.mk
|
||||
@@ -23,7 +23,6 @@ current_dir := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||
@@ -160,18 +160,6 @@ index 3abd7bc..94903c3 100644
|
||||
$(top_srcdir)/regress/core/wkb \
|
||||
$(top_srcdir)/regress/core/wkt \
|
||||
$(top_srcdir)/regress/core/wmsservers \
|
||||
@@ -144,11 +140,6 @@ TESTS_SLOW = \
|
||||
$(top_srcdir)/regress/core/concave_hull_hard \
|
||||
$(top_srcdir)/regress/core/knn_recheck
|
||||
|
||||
-ifeq ($(shell expr "$(POSTGIS_PGSQL_VERSION)" ">=" 120),1)
|
||||
- TESTS += \
|
||||
- $(top_srcdir)/regress/core/computed_columns
|
||||
-endif
|
||||
-
|
||||
ifeq ($(shell expr "$(POSTGIS_GEOS_VERSION)" ">=" 30700),1)
|
||||
# GEOS-3.7 adds:
|
||||
# ST_FrechetDistance
|
||||
diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk
|
||||
index 1fc77ac..c3cb9de 100644
|
||||
--- a/regress/loader/tests.mk
|
||||
|
||||
@@ -125,7 +125,7 @@ index 7a36b65..ad78fc7 100644
|
||||
DROP SCHEMA tm CASCADE;
|
||||
+
|
||||
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
|
||||
index 9e05244..a63a3e1 100644
|
||||
index 90987df..74fe3f1 100644
|
||||
--- a/regress/core/tests.mk
|
||||
+++ b/regress/core/tests.mk
|
||||
@@ -16,14 +16,13 @@ POSTGIS_PGSQL_VERSION=170
|
||||
@@ -168,16 +168,6 @@ index 9e05244..a63a3e1 100644
|
||||
$(top_srcdir)/regress/core/wkb \
|
||||
$(top_srcdir)/regress/core/wkt \
|
||||
$(top_srcdir)/regress/core/wmsservers \
|
||||
@@ -143,8 +139,7 @@ TESTS += \
|
||||
$(top_srcdir)/regress/core/oriented_envelope \
|
||||
$(top_srcdir)/regress/core/point_coordinates \
|
||||
$(top_srcdir)/regress/core/out_geojson \
|
||||
- $(top_srcdir)/regress/core/wrapx \
|
||||
- $(top_srcdir)/regress/core/computed_columns
|
||||
+ $(top_srcdir)/regress/core/wrapx
|
||||
|
||||
# Slow slow tests
|
||||
TESTS_SLOW = \
|
||||
diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk
|
||||
index ac4f8ad..4bad4fc 100644
|
||||
--- a/regress/loader/tests.mk
|
||||
|
||||
@@ -10,8 +10,8 @@ psql -d contrib_regression -c "ALTER DATABASE contrib_regression SET TimeZone='U
|
||||
-c "CREATE EXTENSION postgis_tiger_geocoder CASCADE" \
|
||||
-c "CREATE EXTENSION postgis_raster SCHEMA public" \
|
||||
-c "CREATE EXTENSION postgis_sfcgal SCHEMA public"
|
||||
patch -p1 <postgis-no-upgrade-test.patch
|
||||
patch -p1 <"postgis-common-${PG_VERSION}.patch"
|
||||
patch -p1 <"postgis-regular-${PG_VERSION}.patch"
|
||||
psql -d contrib_regression -f raster_outdb_template.sql
|
||||
trap 'patch -R -p1 <postgis-no-upgrade-test.patch && patch -R -p1 <"postgis-regular-${PG_VERSION}.patch"' EXIT
|
||||
trap 'patch -R -p1 <postgis-regular-${PG_VERSION}.patch && patch -R -p1 <"postgis-common-${PG_VERSION}.patch"' EXIT
|
||||
POSTGIS_REGRESS_DB=contrib_regression RUNTESTFLAGS=--nocreate make installcheck-base
|
||||
@@ -8,6 +8,7 @@ anyhow.workspace = true
|
||||
axum-extra.workspace = true
|
||||
axum.workspace = true
|
||||
camino.workspace = true
|
||||
clap.workspace = true
|
||||
futures.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
prometheus.workspace = true
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
//! for large computes.
|
||||
mod app;
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use tracing::info;
|
||||
use utils::logging;
|
||||
|
||||
@@ -12,9 +14,26 @@ const fn max_upload_file_limit() -> usize {
|
||||
100 * 1024 * 1024
|
||||
}
|
||||
|
||||
const fn listen() -> SocketAddr {
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(exclusive = true)]
|
||||
config_file: Option<String>,
|
||||
#[arg(long, default_value = "false", requires = "config")]
|
||||
/// to allow testing k8s helm chart where we don't have s3 credentials
|
||||
no_s3_check_on_startup: bool,
|
||||
#[arg(long, value_name = "FILE")]
|
||||
/// inline config mode for k8s helm chart
|
||||
config: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
struct Config {
|
||||
#[serde(default = "listen")]
|
||||
listen: std::net::SocketAddr,
|
||||
pemfile: camino::Utf8PathBuf,
|
||||
#[serde(flatten)]
|
||||
@@ -31,13 +50,18 @@ async fn main() -> anyhow::Result<()> {
|
||||
logging::Output::Stdout,
|
||||
)?;
|
||||
|
||||
let config: String = std::env::args().skip(1).take(1).collect();
|
||||
if config.is_empty() {
|
||||
anyhow::bail!("Usage: endpoint_storage config.json")
|
||||
}
|
||||
info!("Reading config from {config}");
|
||||
let config = std::fs::read_to_string(config.clone())?;
|
||||
let config: Config = serde_json::from_str(&config).context("parsing config")?;
|
||||
let args = Args::parse();
|
||||
let config: Config = if let Some(config_path) = args.config_file {
|
||||
info!("Reading config from {config_path}");
|
||||
let config = std::fs::read_to_string(config_path)?;
|
||||
serde_json::from_str(&config).context("parsing config")?
|
||||
} else if let Some(config) = args.config {
|
||||
info!("Reading inline config");
|
||||
serde_json::from_str(&config).context("parsing config")?
|
||||
} else {
|
||||
anyhow::bail!("Supply either config file path or --config=inline-config");
|
||||
};
|
||||
|
||||
info!("Reading pemfile from {}", config.pemfile.clone());
|
||||
let pemfile = std::fs::read(config.pemfile.clone())?;
|
||||
info!("Loading public key from {}", config.pemfile.clone());
|
||||
@@ -48,7 +72,9 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let storage = remote_storage::GenericRemoteStorage::from_config(&config.storage_config).await?;
|
||||
let cancel = tokio_util::sync::CancellationToken::new();
|
||||
app::check_storage_permissions(&storage, cancel.clone()).await?;
|
||||
if !args.no_s3_check_on_startup {
|
||||
app::check_storage_permissions(&storage, cancel.clone()).await?;
|
||||
}
|
||||
|
||||
let proxy = std::sync::Arc::new(endpoint_storage::Storage {
|
||||
auth,
|
||||
|
||||
@@ -344,6 +344,35 @@ impl Default for ShardSchedulingPolicy {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
|
||||
pub enum NodeLifecycle {
|
||||
Active,
|
||||
Deleted,
|
||||
}
|
||||
|
||||
impl FromStr for NodeLifecycle {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"active" => Ok(Self::Active),
|
||||
"deleted" => Ok(Self::Deleted),
|
||||
_ => Err(anyhow::anyhow!("Unknown node lifecycle '{s}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NodeLifecycle> for String {
|
||||
fn from(value: NodeLifecycle) -> String {
|
||||
use NodeLifecycle::*;
|
||||
match value {
|
||||
Active => "active",
|
||||
Deleted => "deleted",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
|
||||
pub enum NodeSchedulingPolicy {
|
||||
Active,
|
||||
|
||||
@@ -10,7 +10,7 @@ use crate::{Error, cancel_query_raw, connect_socket};
|
||||
pub(crate) async fn cancel_query<T>(
|
||||
config: Option<SocketConfig>,
|
||||
ssl_mode: SslMode,
|
||||
mut tls: T,
|
||||
tls: T,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
) -> Result<(), Error>
|
||||
|
||||
@@ -17,7 +17,6 @@ use crate::{Client, Connection, Error};
|
||||
|
||||
/// TLS configuration.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum SslMode {
|
||||
/// Do not use TLS.
|
||||
Disable,
|
||||
@@ -231,7 +230,7 @@ impl Config {
|
||||
/// Requires the `runtime` Cargo feature (enabled by default).
|
||||
pub async fn connect<T>(
|
||||
&self,
|
||||
tls: T,
|
||||
tls: &T,
|
||||
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
|
||||
where
|
||||
T: MakeTlsConnect<TcpStream>,
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::tls::{MakeTlsConnect, TlsConnect};
|
||||
use crate::{Client, Config, Connection, Error, RawConnection};
|
||||
|
||||
pub async fn connect<T>(
|
||||
mut tls: T,
|
||||
tls: &T,
|
||||
config: &Config,
|
||||
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
|
||||
where
|
||||
|
||||
@@ -47,7 +47,7 @@ pub trait MakeTlsConnect<S> {
|
||||
/// Creates a new `TlsConnect`or.
|
||||
///
|
||||
/// The domain name is provided for certificate verification and SNI.
|
||||
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
|
||||
fn make_tls_connect(&self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
|
||||
}
|
||||
|
||||
/// An asynchronous function wrapping a stream in a TLS session.
|
||||
@@ -85,7 +85,7 @@ impl<S> MakeTlsConnect<S> for NoTls {
|
||||
type TlsConnect = NoTls;
|
||||
type Error = NoTlsError;
|
||||
|
||||
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
|
||||
fn make_tls_connect(&self, _: &str) -> Result<NoTls, NoTlsError> {
|
||||
Ok(NoTls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use utils::pageserver_feedback::PageserverFeedback;
|
||||
use crate::membership::Configuration;
|
||||
use crate::{ServerInfo, Term};
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SafekeeperStatus {
|
||||
pub id: NodeId,
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ use pageserver::deletion_queue::DeletionQueue;
|
||||
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
|
||||
use pageserver::feature_resolver::FeatureResolver;
|
||||
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
|
||||
use pageserver::page_service::GrpcPageServiceHandler;
|
||||
use pageserver::task_mgr::{
|
||||
BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME,
|
||||
};
|
||||
@@ -814,7 +815,7 @@ fn start_pageserver(
|
||||
// necessary?
|
||||
let mut page_service_grpc = None;
|
||||
if let Some(grpc_listener) = grpc_listener {
|
||||
page_service_grpc = Some(page_service::spawn_grpc(
|
||||
page_service_grpc = Some(GrpcPageServiceHandler::spawn(
|
||||
tenant_manager.clone(),
|
||||
grpc_auth,
|
||||
otel_guard.as_ref().map(|g| g.dispatch.clone()),
|
||||
|
||||
@@ -169,99 +169,6 @@ pub fn spawn(
|
||||
Listener { cancel, task }
|
||||
}
|
||||
|
||||
/// Spawns a gRPC server for the page service.
|
||||
///
|
||||
/// TODO: move this onto GrpcPageServiceHandler::spawn().
|
||||
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
|
||||
/// need to reimplement the TCP+TLS accept loop ourselves.
|
||||
pub fn spawn_grpc(
|
||||
tenant_manager: Arc<TenantManager>,
|
||||
auth: Option<Arc<SwappableJwtAuth>>,
|
||||
perf_trace_dispatch: Option<Dispatch>,
|
||||
get_vectored_concurrent_io: GetVectoredConcurrentIo,
|
||||
listener: std::net::TcpListener,
|
||||
) -> anyhow::Result<CancellableTask> {
|
||||
let cancel = CancellationToken::new();
|
||||
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
|
||||
.download_behavior(DownloadBehavior::Download)
|
||||
.perf_span_dispatch(perf_trace_dispatch)
|
||||
.detached_child();
|
||||
let gate = Gate::default();
|
||||
|
||||
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
|
||||
// port early during startup.
|
||||
let incoming = {
|
||||
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
|
||||
listener.set_nonblocking(true)?;
|
||||
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?)
|
||||
.with_nodelay(Some(GRPC_TCP_NODELAY))
|
||||
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
|
||||
};
|
||||
|
||||
// Set up the gRPC server.
|
||||
//
|
||||
// TODO: consider tuning window sizes.
|
||||
let mut server = tonic::transport::Server::builder()
|
||||
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
|
||||
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
|
||||
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
|
||||
|
||||
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
|
||||
//
|
||||
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
|
||||
//
|
||||
// * Layers: allow async code, can run code after the service response. However, only has access
|
||||
// to the raw HTTP request/response, not the gRPC types.
|
||||
let page_service_handler = GrpcPageServiceHandler {
|
||||
tenant_manager,
|
||||
ctx,
|
||||
gate_guard: gate.enter().expect("gate was just created"),
|
||||
get_vectored_concurrent_io,
|
||||
};
|
||||
|
||||
let observability_layer = ObservabilityLayer;
|
||||
let mut tenant_interceptor = TenantMetadataInterceptor;
|
||||
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
|
||||
|
||||
let page_service = tower::ServiceBuilder::new()
|
||||
// Create tracing span and record request start time.
|
||||
.layer(observability_layer)
|
||||
// Intercept gRPC requests.
|
||||
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
|
||||
// Extract tenant metadata.
|
||||
req = tenant_interceptor.call(req)?;
|
||||
// Authenticate tenant JWT token.
|
||||
req = auth_interceptor.call(req)?;
|
||||
Ok(req)
|
||||
}))
|
||||
.service(proto::PageServiceServer::new(page_service_handler));
|
||||
let server = server.add_service(page_service);
|
||||
|
||||
// Reflection service for use with e.g. grpcurl.
|
||||
let reflection_service = tonic_reflection::server::Builder::configure()
|
||||
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
|
||||
.build_v1()?;
|
||||
let server = server.add_service(reflection_service);
|
||||
|
||||
// Spawn server task.
|
||||
let task_cancel = cancel.clone();
|
||||
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
|
||||
"grpc listener",
|
||||
async move {
|
||||
let result = server
|
||||
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
|
||||
.await;
|
||||
if result.is_ok() {
|
||||
// TODO: revisit shutdown logic once page service is implemented.
|
||||
gate.close().await;
|
||||
}
|
||||
result
|
||||
},
|
||||
));
|
||||
|
||||
Ok(CancellableTask { task, cancel })
|
||||
}
|
||||
|
||||
impl Listener {
|
||||
pub async fn stop_accepting(self) -> Connections {
|
||||
self.cancel.cancel();
|
||||
@@ -3366,6 +3273,101 @@ pub struct GrpcPageServiceHandler {
|
||||
}
|
||||
|
||||
impl GrpcPageServiceHandler {
|
||||
/// Spawns a gRPC server for the page service.
|
||||
///
|
||||
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
|
||||
/// need to reimplement the TCP+TLS accept loop ourselves.
|
||||
pub fn spawn(
|
||||
tenant_manager: Arc<TenantManager>,
|
||||
auth: Option<Arc<SwappableJwtAuth>>,
|
||||
perf_trace_dispatch: Option<Dispatch>,
|
||||
get_vectored_concurrent_io: GetVectoredConcurrentIo,
|
||||
listener: std::net::TcpListener,
|
||||
) -> anyhow::Result<CancellableTask> {
|
||||
let cancel = CancellationToken::new();
|
||||
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
|
||||
.download_behavior(DownloadBehavior::Download)
|
||||
.perf_span_dispatch(perf_trace_dispatch)
|
||||
.detached_child();
|
||||
let gate = Gate::default();
|
||||
|
||||
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
|
||||
// port early during startup.
|
||||
let incoming = {
|
||||
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
|
||||
listener.set_nonblocking(true)?;
|
||||
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(
|
||||
listener,
|
||||
)?)
|
||||
.with_nodelay(Some(GRPC_TCP_NODELAY))
|
||||
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
|
||||
};
|
||||
|
||||
// Set up the gRPC server.
|
||||
//
|
||||
// TODO: consider tuning window sizes.
|
||||
let mut server = tonic::transport::Server::builder()
|
||||
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
|
||||
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
|
||||
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
|
||||
|
||||
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
|
||||
//
|
||||
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
|
||||
//
|
||||
// * Layers: allow async code, can run code after the service response. However, only has access
|
||||
// to the raw HTTP request/response, not the gRPC types.
|
||||
let page_service_handler = GrpcPageServiceHandler {
|
||||
tenant_manager,
|
||||
ctx,
|
||||
gate_guard: gate.enter().expect("gate was just created"),
|
||||
get_vectored_concurrent_io,
|
||||
};
|
||||
|
||||
let observability_layer = ObservabilityLayer;
|
||||
let mut tenant_interceptor = TenantMetadataInterceptor;
|
||||
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
|
||||
|
||||
let page_service = tower::ServiceBuilder::new()
|
||||
// Create tracing span and record request start time.
|
||||
.layer(observability_layer)
|
||||
// Intercept gRPC requests.
|
||||
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
|
||||
// Extract tenant metadata.
|
||||
req = tenant_interceptor.call(req)?;
|
||||
// Authenticate tenant JWT token.
|
||||
req = auth_interceptor.call(req)?;
|
||||
Ok(req)
|
||||
}))
|
||||
// Run the page service.
|
||||
.service(proto::PageServiceServer::new(page_service_handler));
|
||||
let server = server.add_service(page_service);
|
||||
|
||||
// Reflection service for use with e.g. grpcurl.
|
||||
let reflection_service = tonic_reflection::server::Builder::configure()
|
||||
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
|
||||
.build_v1()?;
|
||||
let server = server.add_service(reflection_service);
|
||||
|
||||
// Spawn server task.
|
||||
let task_cancel = cancel.clone();
|
||||
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
|
||||
"grpc listener",
|
||||
async move {
|
||||
let result = server
|
||||
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
|
||||
.await;
|
||||
if result.is_ok() {
|
||||
// TODO: revisit shutdown logic once page service is implemented.
|
||||
gate.close().await;
|
||||
}
|
||||
result
|
||||
},
|
||||
));
|
||||
|
||||
Ok(CancellableTask { task, cancel })
|
||||
}
|
||||
|
||||
/// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of
|
||||
/// relations and their sizes, as well as SLRU segments and similar data.
|
||||
#[allow(clippy::result_large_err)]
|
||||
|
||||
@@ -1671,7 +1671,12 @@ impl TenantManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5: Shut down the parent shard, and erase it from disk
|
||||
// Phase 5: Shut down the parent shard. We leave it on disk in case the split fails and we
|
||||
// have to roll back to the parent shard, avoiding a cold start. It will be cleaned up once
|
||||
// the storage controller commits the split, or if all else fails, on the next restart.
|
||||
//
|
||||
// TODO: We don't flush the ephemeral layer here, because the split is likely to succeed and
|
||||
// catching up the parent should be reasonably quick. Consider using FreezeAndFlush instead.
|
||||
let (_guard, progress) = completion::channel();
|
||||
match parent.shutdown(progress, ShutdownMode::Hard).await {
|
||||
Ok(()) => {}
|
||||
@@ -1679,11 +1684,6 @@ impl TenantManager {
|
||||
other.wait().await;
|
||||
}
|
||||
}
|
||||
let local_tenant_directory = self.conf.tenant_path(&tenant_shard_id);
|
||||
let tmp_path = safe_rename_tenant_dir(&local_tenant_directory)
|
||||
.await
|
||||
.with_context(|| format!("local tenant directory {local_tenant_directory:?} rename"))?;
|
||||
self.background_purges.spawn(tmp_path);
|
||||
|
||||
fail::fail_point!("shard-split-pre-finish", |_| Err(anyhow::anyhow!(
|
||||
"failpoint"
|
||||
@@ -1846,42 +1846,70 @@ impl TenantManager {
|
||||
shutdown_all_tenants0(self.tenants).await
|
||||
}
|
||||
|
||||
/// Detaches a tenant, and removes its local files asynchronously.
|
||||
///
|
||||
/// File removal is idempotent: even if the tenant has already been removed, this will still
|
||||
/// remove any local files. This is used during shard splits, where we leave the parent shard's
|
||||
/// files around in case we have to roll back the split.
|
||||
pub(crate) async fn detach_tenant(
|
||||
&self,
|
||||
conf: &'static PageServerConf,
|
||||
tenant_shard_id: TenantShardId,
|
||||
deletion_queue_client: &DeletionQueueClient,
|
||||
) -> Result<(), TenantStateError> {
|
||||
let tmp_path = self
|
||||
if let Some(tmp_path) = self
|
||||
.detach_tenant0(conf, tenant_shard_id, deletion_queue_client)
|
||||
.await?;
|
||||
self.background_purges.spawn(tmp_path);
|
||||
.await?
|
||||
{
|
||||
self.background_purges.spawn(tmp_path);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Detaches a tenant. This renames the tenant directory to a temporary path and returns it,
|
||||
/// allowing the caller to delete it asynchronously. Returns None if the dir is already removed.
|
||||
async fn detach_tenant0(
|
||||
&self,
|
||||
conf: &'static PageServerConf,
|
||||
tenant_shard_id: TenantShardId,
|
||||
deletion_queue_client: &DeletionQueueClient,
|
||||
) -> Result<Utf8PathBuf, TenantStateError> {
|
||||
) -> Result<Option<Utf8PathBuf>, TenantStateError> {
|
||||
let tenant_dir_rename_operation = |tenant_id_to_clean: TenantShardId| async move {
|
||||
let local_tenant_directory = conf.tenant_path(&tenant_id_to_clean);
|
||||
if !tokio::fs::try_exists(&local_tenant_directory).await? {
|
||||
// If the tenant directory doesn't exist, it's already cleaned up.
|
||||
return Ok(None);
|
||||
}
|
||||
safe_rename_tenant_dir(&local_tenant_directory)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("local tenant directory {local_tenant_directory:?} rename")
|
||||
})
|
||||
.map(Some)
|
||||
};
|
||||
|
||||
let removal_result = remove_tenant_from_memory(
|
||||
let mut removal_result = remove_tenant_from_memory(
|
||||
self.tenants,
|
||||
tenant_shard_id,
|
||||
tenant_dir_rename_operation(tenant_shard_id),
|
||||
)
|
||||
.await;
|
||||
|
||||
// If the tenant was not found, it was likely already removed. Attempt to remove the tenant
|
||||
// directory on disk anyway. For example, during shard splits, we shut down and remove the
|
||||
// parent shard, but leave its directory on disk in case we have to roll back the split.
|
||||
//
|
||||
// TODO: it would be better to leave the parent shard attached until the split is committed.
|
||||
// This will be needed by the gRPC page service too, such that a compute can continue to
|
||||
// read from the parent shard until it's notified about the new child shards. See:
|
||||
// <https://github.com/neondatabase/neon/issues/11728>.
|
||||
if let Err(TenantStateError::SlotError(TenantSlotError::NotFound(_))) = removal_result {
|
||||
removal_result = tenant_dir_rename_operation(tenant_shard_id)
|
||||
.await
|
||||
.map_err(TenantStateError::Other);
|
||||
}
|
||||
|
||||
// Flush pending deletions, so that they have a good chance of passing validation
|
||||
// before this tenant is potentially re-attached elsewhere.
|
||||
deletion_queue_client.flush_advisory();
|
||||
|
||||
@@ -1055,8 +1055,8 @@ pub(crate) enum WaitLsnWaiter<'a> {
|
||||
/// Argument to [`Timeline::shutdown`].
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub(crate) enum ShutdownMode {
|
||||
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk and then
|
||||
/// also to remote storage. This method can easily take multiple seconds for a busy timeline.
|
||||
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk. This method can
|
||||
/// take multiple seconds for a busy timeline.
|
||||
///
|
||||
/// While we are flushing, we continue to accept read I/O for LSNs ingested before
|
||||
/// the call to [`Timeline::shutdown`].
|
||||
|
||||
@@ -18,11 +18,6 @@ pub(super) async fn authenticate(
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
let scram_keys = match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => {
|
||||
debug!("auth endpoint chooses MD5");
|
||||
return Err(auth::AuthError::MalformedPassword("MD5 not supported"));
|
||||
}
|
||||
AuthSecret::Scram(secret) => {
|
||||
debug!("auth endpoint chooses SCRAM");
|
||||
|
||||
|
||||
@@ -6,10 +6,9 @@ use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, info_span};
|
||||
|
||||
use super::ComputeCredentialKeys;
|
||||
use crate::auth::IpPattern;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::cache::Cached;
|
||||
use crate::compute::AuthInfo;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::cplane_proxy_v1;
|
||||
@@ -98,15 +97,11 @@ impl ConsoleRedirectBackend {
|
||||
ctx: &RequestContext,
|
||||
auth_config: &'static AuthenticationConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<(
|
||||
ConsoleRedirectNodeInfo,
|
||||
ComputeUserInfo,
|
||||
Option<Vec<IpPattern>>,
|
||||
)> {
|
||||
) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> {
|
||||
authenticate(ctx, auth_config, &self.console_uri, client)
|
||||
.await
|
||||
.map(|(node_info, user_info, ip_allowlist)| {
|
||||
(ConsoleRedirectNodeInfo(node_info), user_info, ip_allowlist)
|
||||
.map(|(node_info, auth_info, user_info)| {
|
||||
(ConsoleRedirectNodeInfo(node_info), auth_info, user_info)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -121,10 +116,6 @@ impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
|
||||
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
|
||||
Ok(Cached::new_uncached(self.0.clone()))
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys {
|
||||
&ComputeCredentialKeys::None
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
@@ -132,7 +123,7 @@ async fn authenticate(
|
||||
auth_config: &'static AuthenticationConfig,
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<(NodeInfo, ComputeUserInfo, Option<Vec<IpPattern>>)> {
|
||||
) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
|
||||
|
||||
// registering waiter can fail if we get unlucky with rng.
|
||||
@@ -192,10 +183,24 @@ async fn authenticate(
|
||||
|
||||
client.write_message(BeMessage::NoticeResponse("Connecting to database."));
|
||||
|
||||
// This config should be self-contained, because we won't
|
||||
// take username or dbname from client's startup message.
|
||||
let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port);
|
||||
config.dbname(&db_info.dbname).user(&db_info.user);
|
||||
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
|
||||
// while direct connections do not. Once we migrate to pg_sni_proxy
|
||||
// everywhere, we can remove this.
|
||||
let ssl_mode = if db_info.host.contains("--") {
|
||||
// we need TLS connection with SNI info to properly route it
|
||||
SslMode::Require
|
||||
} else {
|
||||
SslMode::Disable
|
||||
};
|
||||
|
||||
let conn_info = compute::ConnectInfo {
|
||||
host: db_info.host.into(),
|
||||
port: db_info.port,
|
||||
ssl_mode,
|
||||
host_addr: None,
|
||||
};
|
||||
let auth_info =
|
||||
AuthInfo::for_console_redirect(&db_info.dbname, &db_info.user, db_info.password.as_deref());
|
||||
|
||||
let user: RoleName = db_info.user.into();
|
||||
let user_info = ComputeUserInfo {
|
||||
@@ -209,26 +214,12 @@ async fn authenticate(
|
||||
ctx.set_project(db_info.aux.clone());
|
||||
info!("woken up a compute node");
|
||||
|
||||
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
|
||||
// while direct connections do not. Once we migrate to pg_sni_proxy
|
||||
// everywhere, we can remove this.
|
||||
if db_info.host.contains("--") {
|
||||
// we need TLS connection with SNI info to properly route it
|
||||
config.ssl_mode(SslMode::Require);
|
||||
} else {
|
||||
config.ssl_mode(SslMode::Disable);
|
||||
}
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password.as_ref());
|
||||
}
|
||||
|
||||
Ok((
|
||||
NodeInfo {
|
||||
config,
|
||||
conn_info,
|
||||
aux: db_info.aux,
|
||||
},
|
||||
auth_info,
|
||||
user_info,
|
||||
db_info.allowed_ips,
|
||||
))
|
||||
}
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use postgres_client::config::SslMode;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
use super::jwt::{AuthRule, FetchAuthRules};
|
||||
use crate::auth::backend::jwt::FetchAuthRulesError;
|
||||
use crate::compute::ConnCfg;
|
||||
use crate::compute::ConnectInfo;
|
||||
use crate::compute_ctl::ComputeCtlApi;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::NodeInfo;
|
||||
@@ -29,7 +30,12 @@ impl LocalBackend {
|
||||
api: http::Endpoint::new(compute_ctl, http::new_client()),
|
||||
},
|
||||
node_info: NodeInfo {
|
||||
config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()),
|
||||
conn_info: ConnectInfo {
|
||||
host_addr: Some(postgres_addr.ip()),
|
||||
host: postgres_addr.ip().to_string().into(),
|
||||
port: postgres_addr.port(),
|
||||
ssl_mode: SslMode::Disable,
|
||||
},
|
||||
// TODO(conrad): make this better reflect compute info rather than endpoint info.
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),
|
||||
|
||||
@@ -168,8 +168,6 @@ impl ComputeUserInfo {
|
||||
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
pub(crate) enum ComputeCredentialKeys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Password(Vec<u8>),
|
||||
AuthKeys(AuthKeys),
|
||||
JwtPayload(Vec<u8>),
|
||||
None,
|
||||
@@ -419,13 +417,6 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
|
||||
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys {
|
||||
match self {
|
||||
Self::ControlPlane(_, creds) => &creds.keys,
|
||||
Self::Local(_) => &ComputeCredentialKeys::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -169,13 +169,6 @@ pub(crate) async fn validate_password_and_exchange(
|
||||
secret: AuthSecret,
|
||||
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => {
|
||||
// test only
|
||||
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
|
||||
password.to_owned(),
|
||||
)))
|
||||
}
|
||||
// perform scram authentication as both client and server to validate the keys
|
||||
AuthSecret::Scram(scram_secret) => {
|
||||
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
|
||||
|
||||
8
proxy/src/cache/project_info.rs
vendored
8
proxy/src/cache/project_info.rs
vendored
@@ -18,6 +18,7 @@ use crate::types::{EndpointId, RoleName};
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait ProjectInfoCache {
|
||||
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
|
||||
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
|
||||
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
|
||||
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
|
||||
@@ -100,6 +101,13 @@ pub struct ProjectInfoCacheImpl {
|
||||
|
||||
#[async_trait]
|
||||
impl ProjectInfoCache for ProjectInfoCacheImpl {
|
||||
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
|
||||
info!("invalidating endpoint access for `{endpoint_id}`");
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
|
||||
endpoint_info.invalidate_endpoint();
|
||||
}
|
||||
}
|
||||
|
||||
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
|
||||
info!("invalidating endpoint access for project `{project_id}`");
|
||||
let endpoints = self
|
||||
|
||||
@@ -24,7 +24,6 @@ use crate::pqproto::CancelKeyData;
|
||||
use crate::rate_limiter::LeakyBucketRateLimiter;
|
||||
use crate::redis::keys::KeyPrefix;
|
||||
use crate::redis::kv_ops::RedisKVClient;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
type IpSubnetKey = IpNet;
|
||||
|
||||
@@ -497,10 +496,8 @@ impl CancelClosure {
|
||||
) -> Result<(), CancelError> {
|
||||
let socket = TcpStream::connect(self.socket_addr).await?;
|
||||
|
||||
let mut mk_tls =
|
||||
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
compute_config,
|
||||
&self.hostname,
|
||||
)
|
||||
.map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;
|
||||
|
||||
@@ -1,21 +1,24 @@
|
||||
mod tls;
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use postgres_client::config::{AuthKeys, SslMode};
|
||||
use postgres_client::maybe_tls_stream::MaybeTlsStream;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{CancelToken, RawConnection};
|
||||
use postgres_client::{CancelToken, NoTls, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::parse_endpoint_param;
|
||||
use crate::cancellation::CancelClosure;
|
||||
use crate::compute::tls::TlsError;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ApiLockError;
|
||||
@@ -25,7 +28,6 @@ use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumDbConnectionsGuard};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::neon_option;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::types::Host;
|
||||
|
||||
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
@@ -38,10 +40,7 @@ pub(crate) enum ConnectionError {
|
||||
Postgres(#[from] postgres_client::Error),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
CouldNotConnect(#[from] io::Error),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
TlsError(#[from] InvalidDnsNameError),
|
||||
TlsError(#[from] TlsError),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
WakeComputeError(#[from] WakeComputeError),
|
||||
@@ -73,7 +72,7 @@ impl UserFacingError for ConnectionError {
|
||||
ConnectionError::TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
}
|
||||
_ => COULD_NOT_CONNECT.to_owned(),
|
||||
ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -85,7 +84,6 @@ impl ReportableError for ConnectionError {
|
||||
crate::error::ErrorKind::Postgres
|
||||
}
|
||||
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
|
||||
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
@@ -96,34 +94,85 @@ impl ReportableError for ConnectionError {
|
||||
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
|
||||
pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
|
||||
|
||||
/// A config for establishing a connection to compute node.
|
||||
/// Eventually, `postgres_client` will be replaced with something better.
|
||||
/// Newtype allows us to implement methods on top of it.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ConnCfg(Box<postgres_client::Config>);
|
||||
pub enum Auth {
|
||||
/// Only used during console-redirect.
|
||||
Password(Vec<u8>),
|
||||
/// Used by sql-over-http, ws, tcp.
|
||||
Scram(Box<ScramKeys>),
|
||||
}
|
||||
|
||||
/// A config for authenticating to the compute node.
|
||||
pub(crate) struct AuthInfo {
|
||||
/// None for local-proxy, as we use trust-based localhost auth.
|
||||
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
|
||||
/// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
|
||||
auth: Option<Auth>,
|
||||
server_params: StartupMessageParams,
|
||||
|
||||
/// Console redirect sets user and database, we shouldn't re-use those from the params.
|
||||
skip_db_user: bool,
|
||||
}
|
||||
|
||||
/// Contains only the data needed to establish a secure connection to compute.
|
||||
#[derive(Clone)]
|
||||
pub struct ConnectInfo {
|
||||
pub host_addr: Option<IpAddr>,
|
||||
pub host: Host,
|
||||
pub port: u16,
|
||||
pub ssl_mode: SslMode,
|
||||
}
|
||||
|
||||
/// Creation and initialization routines.
|
||||
impl ConnCfg {
|
||||
pub(crate) fn new(host: String, port: u16) -> Self {
|
||||
Self(Box::new(postgres_client::Config::new(host, port)))
|
||||
}
|
||||
|
||||
/// Reuse password or auth keys from the other config.
|
||||
pub(crate) fn reuse_password(&mut self, other: Self) {
|
||||
if let Some(password) = other.get_password() {
|
||||
self.password(password);
|
||||
}
|
||||
|
||||
if let Some(keys) = other.get_auth_keys() {
|
||||
self.auth_keys(keys);
|
||||
impl AuthInfo {
|
||||
pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
|
||||
let mut server_params = StartupMessageParams::default();
|
||||
server_params.insert("database", db);
|
||||
server_params.insert("user", user);
|
||||
Self {
|
||||
auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
|
||||
server_params,
|
||||
skip_db_user: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_host(&self) -> Host {
|
||||
match self.0.get_host() {
|
||||
postgres_client::config::Host::Tcp(s) => s.into(),
|
||||
pub(crate) fn with_auth_keys(keys: &ComputeCredentialKeys) -> Self {
|
||||
Self {
|
||||
auth: match keys {
|
||||
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
|
||||
Some(Auth::Scram(Box::new(*auth_keys)))
|
||||
}
|
||||
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
|
||||
},
|
||||
server_params: StartupMessageParams::default(),
|
||||
skip_db_user: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectInfo {
|
||||
pub fn to_postgres_client_config(&self) -> postgres_client::Config {
|
||||
let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
|
||||
config.ssl_mode(self.ssl_mode);
|
||||
if let Some(host_addr) = self.host_addr {
|
||||
config.set_host_addr(host_addr);
|
||||
}
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthInfo {
|
||||
fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
|
||||
match &self.auth {
|
||||
Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
|
||||
Some(Auth::Password(pw)) => config.password(pw),
|
||||
None => &mut config,
|
||||
};
|
||||
for (k, v) in self.server_params.iter() {
|
||||
config.set_param(k, v);
|
||||
}
|
||||
config
|
||||
}
|
||||
|
||||
/// Apply startup message params to the connection config.
|
||||
pub(crate) fn set_startup_params(
|
||||
@@ -132,27 +181,26 @@ impl ConnCfg {
|
||||
arbitrary_params: bool,
|
||||
) {
|
||||
if !arbitrary_params {
|
||||
self.set_param("client_encoding", "UTF8");
|
||||
self.server_params.insert("client_encoding", "UTF8");
|
||||
}
|
||||
for (k, v) in params.iter() {
|
||||
match k {
|
||||
// Only set `user` if it's not present in the config.
|
||||
// Console redirect auth flow takes username from the console's response.
|
||||
"user" if self.user_is_set() => {}
|
||||
"database" if self.db_is_set() => {}
|
||||
"user" | "database" if self.skip_db_user => {}
|
||||
"options" => {
|
||||
if let Some(options) = filtered_options(v) {
|
||||
self.set_param(k, &options);
|
||||
self.server_params.insert(k, &options);
|
||||
}
|
||||
}
|
||||
"user" | "database" | "application_name" | "replication" => {
|
||||
self.set_param(k, v);
|
||||
self.server_params.insert(k, v);
|
||||
}
|
||||
|
||||
// if we allow arbitrary params, then we forward them through.
|
||||
// this is a flag for a period of backwards compatibility
|
||||
k if arbitrary_params => {
|
||||
self.set_param(k, v);
|
||||
self.server_params.insert(k, v);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
@@ -160,25 +208,13 @@ impl ConnCfg {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for ConnCfg {
|
||||
type Target = postgres_client::Config;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// For now, let's make it easier to setup the config.
|
||||
impl std::ops::DerefMut for ConnCfg {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnCfg {
|
||||
/// Establish a raw TCP connection to the compute node.
|
||||
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
|
||||
use postgres_client::config::Host;
|
||||
impl ConnectInfo {
|
||||
/// Establish a raw TCP+TLS connection to the compute node.
|
||||
async fn connect_raw(
|
||||
&self,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
|
||||
let timeout = config.timeout;
|
||||
|
||||
// wrap TcpStream::connect with timeout
|
||||
let connect_with_timeout = |addrs| {
|
||||
@@ -208,34 +244,32 @@ impl ConnCfg {
|
||||
// We can't reuse connection establishing logic from `postgres_client` here,
|
||||
// because it has no means for extracting the underlying socket which we
|
||||
// require for our business.
|
||||
let port = self.0.get_port();
|
||||
let host = self.0.get_host();
|
||||
let port = self.port;
|
||||
let host = &*self.host;
|
||||
|
||||
let host = match host {
|
||||
Host::Tcp(host) => host.as_str(),
|
||||
};
|
||||
|
||||
let addrs = match self.0.get_host_addr() {
|
||||
let addrs = match self.host_addr {
|
||||
Some(addr) => vec![SocketAddr::new(addr, port)],
|
||||
None => lookup_host((host, port)).await?.collect(),
|
||||
};
|
||||
|
||||
match connect_once(&*addrs).await {
|
||||
Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
|
||||
Ok((sockaddr, stream)) => Ok((
|
||||
sockaddr,
|
||||
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
|
||||
)),
|
||||
Err(err) => {
|
||||
warn!("couldn't connect to compute node at {host}:{port}: {err}");
|
||||
Err(err)
|
||||
Err(TlsError::Connection(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
|
||||
pub(crate) struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub(crate) stream:
|
||||
postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
/// PostgreSQL connection parameters.
|
||||
pub(crate) params: std::collections::HashMap<String, String>,
|
||||
/// Query cancellation token.
|
||||
@@ -248,28 +282,23 @@ pub(crate) struct PostgresConnection {
|
||||
_guage: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl ConnCfg {
|
||||
impl ConnectInfo {
|
||||
/// Connect to a corresponding compute node.
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
aux: MetricsAuxInfo,
|
||||
auth: &AuthInfo,
|
||||
config: &ComputeConfig,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
|
||||
drop(pause);
|
||||
let mut tmp_config = auth.enrich(self.to_postgres_client_config());
|
||||
// we setup SSL early in `ConnectInfo::connect_raw`.
|
||||
tmp_config.ssl_mode(SslMode::Disable);
|
||||
|
||||
let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
host,
|
||||
)?;
|
||||
|
||||
// connect_raw() will not use TLS if sslmode is "disable"
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let connection = self.0.connect_raw(stream, tls).await?;
|
||||
let (socket_addr, stream) = self.connect_raw(config).await?;
|
||||
let connection = tmp_config.connect_raw(stream, NoTls).await?;
|
||||
drop(pause);
|
||||
|
||||
let RawConnection {
|
||||
@@ -282,13 +311,14 @@ impl ConnCfg {
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(process_id));
|
||||
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
|
||||
let stream = stream.into_inner();
|
||||
let MaybeTlsStream::Raw(stream) = stream.into_inner();
|
||||
|
||||
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
|
||||
info!(
|
||||
cold_start_info = ctx.cold_start_info().as_str(),
|
||||
"connected to compute node at {host} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
|
||||
self.0.get_ssl_mode(),
|
||||
"connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
|
||||
self.host,
|
||||
self.ssl_mode,
|
||||
ctx.get_proxy_latency(),
|
||||
ctx.get_testodrome_id().unwrap_or_default(),
|
||||
);
|
||||
@@ -299,11 +329,11 @@ impl ConnCfg {
|
||||
socket_addr,
|
||||
CancelToken {
|
||||
socket_config: None,
|
||||
ssl_mode: self.0.get_ssl_mode(),
|
||||
ssl_mode: self.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
host.to_string(),
|
||||
self.host.to_string(),
|
||||
user_info,
|
||||
);
|
||||
|
||||
63
proxy/src/compute/tls.rs
Normal file
63
proxy/src/compute/tls.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use futures::FutureExt;
|
||||
use postgres_client::config::SslMode;
|
||||
use postgres_client::maybe_tls_stream::MaybeTlsStream;
|
||||
use postgres_client::tls::{MakeTlsConnect, TlsConnect};
|
||||
use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::pqproto::request_tls;
|
||||
use crate::proxy::retry::CouldRetry;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TlsError {
|
||||
#[error(transparent)]
|
||||
Dns(#[from] InvalidDnsNameError),
|
||||
#[error(transparent)]
|
||||
Connection(#[from] std::io::Error),
|
||||
#[error("TLS required but not provided")]
|
||||
Required,
|
||||
}
|
||||
|
||||
impl CouldRetry for TlsError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
TlsError::Dns(_) => false,
|
||||
TlsError::Connection(err) => err.could_retry(),
|
||||
// perhaps compute didn't realise it supports TLS?
|
||||
TlsError::Required => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect_tls<S, T>(
|
||||
mut stream: S,
|
||||
mode: SslMode,
|
||||
tls: &T,
|
||||
host: &str,
|
||||
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
T: MakeTlsConnect<
|
||||
S,
|
||||
Error = InvalidDnsNameError,
|
||||
TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
|
||||
>,
|
||||
{
|
||||
match mode {
|
||||
SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
|
||||
SslMode::Prefer | SslMode::Require => {}
|
||||
}
|
||||
|
||||
if !request_tls(&mut stream).await? {
|
||||
if SslMode::Require == mode {
|
||||
return Err(TlsError::Required);
|
||||
}
|
||||
|
||||
return Ok(MaybeTlsStream::Raw(stream));
|
||||
}
|
||||
|
||||
Ok(MaybeTlsStream::Tls(
|
||||
tls.make_tls_connect(host)?.connect(stream).boxed().await?,
|
||||
))
|
||||
}
|
||||
@@ -210,20 +210,20 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let (node_info, user_info, _ip_allowlist) = match backend
|
||||
let (node_info, mut auth_info, user_info) = match backend
|
||||
.authenticate(ctx, &config.authentication_config, &mut stream)
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
auth_info.set_startup_params(¶ms, true);
|
||||
|
||||
let node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info,
|
||||
params_compat: true,
|
||||
params: ¶ms,
|
||||
auth: auth_info,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
&node_info,
|
||||
|
||||
@@ -261,24 +261,18 @@ impl NeonControlPlaneClient {
|
||||
Some(_) => SslMode::Require,
|
||||
None => SslMode::Disable,
|
||||
};
|
||||
let host_name = match body.server_name {
|
||||
Some(host) => host,
|
||||
None => host.to_owned(),
|
||||
let host = match body.server_name {
|
||||
Some(host) => host.into(),
|
||||
None => host.into(),
|
||||
};
|
||||
|
||||
// Don't set anything but host and port! This config will be cached.
|
||||
// We'll set username and such later using the startup message.
|
||||
// TODO: add more type safety (in progress).
|
||||
let mut config = compute::ConnCfg::new(host_name, port);
|
||||
|
||||
if let Some(addr) = host_addr {
|
||||
config.set_host_addr(addr);
|
||||
}
|
||||
|
||||
config.ssl_mode(ssl_mode);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
conn_info: compute::ConnectInfo {
|
||||
host_addr,
|
||||
host,
|
||||
port,
|
||||
ssl_mode,
|
||||
},
|
||||
aux: body.aux,
|
||||
};
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::TryFutureExt;
|
||||
use postgres_client::config::SslMode;
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::Client;
|
||||
use tracing::{Instrument, error, info, info_span, warn};
|
||||
@@ -14,6 +15,7 @@ use crate::auth::IpPattern;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::cache::Cached;
|
||||
use crate::compute::ConnectInfo;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::errors::{
|
||||
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
|
||||
@@ -24,9 +26,9 @@ use crate::control_plane::{
|
||||
RoleAccessControl,
|
||||
};
|
||||
use crate::intern::RoleNameInt;
|
||||
use crate::scram;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
|
||||
use crate::url::ApiUrl;
|
||||
use crate::{compute, scram};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum MockApiError {
|
||||
@@ -87,8 +89,7 @@ impl MockControlPlane {
|
||||
.await?
|
||||
{
|
||||
info!("got a secret: {entry}"); // safe since it's not a prod scenario
|
||||
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
|
||||
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
|
||||
scram::ServerSecret::parse(&entry).map(AuthSecret::Scram)
|
||||
} else {
|
||||
warn!("user '{role}' does not exist");
|
||||
None
|
||||
@@ -170,25 +171,23 @@ impl MockControlPlane {
|
||||
|
||||
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||
let port = self.endpoint.port().unwrap_or(5432);
|
||||
let mut config = match self.endpoint.host_str() {
|
||||
None => {
|
||||
let mut config = compute::ConnCfg::new("localhost".to_string(), port);
|
||||
config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST));
|
||||
config
|
||||
}
|
||||
Some(host) => {
|
||||
let mut config = compute::ConnCfg::new(host.to_string(), port);
|
||||
if let Ok(addr) = IpAddr::from_str(host) {
|
||||
config.set_host_addr(addr);
|
||||
}
|
||||
config
|
||||
}
|
||||
let conn_info = match self.endpoint.host_str() {
|
||||
None => ConnectInfo {
|
||||
host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
|
||||
host: "localhost".into(),
|
||||
port,
|
||||
ssl_mode: SslMode::Disable,
|
||||
},
|
||||
Some(host) => ConnectInfo {
|
||||
host_addr: IpAddr::from_str(host).ok(),
|
||||
host: host.into(),
|
||||
port,
|
||||
ssl_mode: SslMode::Disable,
|
||||
},
|
||||
};
|
||||
|
||||
config.ssl_mode(postgres_client::config::SslMode::Disable);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
conn_info,
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
@@ -266,12 +265,3 @@ impl super::ControlPlaneApi for MockControlPlane {
|
||||
self.do_wake_compute().map_ok(Cached::new_uncached).await
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_md5(input: &str) -> Option<[u8; 16]> {
|
||||
let text = input.strip_prefix("md5")?;
|
||||
|
||||
let mut bytes = [0u8; 16];
|
||||
hex::decode_to_slice(text, &mut bytes).ok()?;
|
||||
|
||||
Some(bytes)
|
||||
}
|
||||
|
||||
@@ -11,8 +11,8 @@ pub(crate) mod errors;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
|
||||
use crate::cache::{Cached, TimedLru};
|
||||
use crate::config::ComputeConfig;
|
||||
@@ -39,10 +39,6 @@ pub mod mgmt;
|
||||
/// Auth secret which is managed by the cloud.
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub(crate) enum AuthSecret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
@@ -63,13 +59,9 @@ pub(crate) struct AuthInfo {
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
/// This is what we get after auth succeeded, but not before!
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct NodeInfo {
|
||||
/// Compute node connection params.
|
||||
/// It's sad that we have to clone this, but this will improve
|
||||
/// once we migrate to a bespoke connection logic.
|
||||
pub(crate) config: compute::ConnCfg,
|
||||
pub(crate) conn_info: compute::ConnectInfo,
|
||||
|
||||
/// Labels for proxy's metrics.
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
@@ -79,26 +71,14 @@ impl NodeInfo {
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
auth: &compute::AuthInfo,
|
||||
config: &ComputeConfig,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.config
|
||||
.connect(ctx, self.aux.clone(), config, user_info)
|
||||
self.conn_info
|
||||
.connect(ctx, self.aux.clone(), auth, config, user_info)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) fn reuse_settings(&mut self, other: Self) {
|
||||
self.config.reuse_password(other.config);
|
||||
}
|
||||
|
||||
pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
|
||||
match keys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default)]
|
||||
|
||||
@@ -610,11 +610,11 @@ pub enum RedisEventsCount {
|
||||
BranchCreated,
|
||||
ProjectCreated,
|
||||
CancelSession,
|
||||
PasswordUpdate,
|
||||
AllowedIpsUpdate,
|
||||
AllowedVpcEndpointIdsUpdateForProjects,
|
||||
AllowedVpcEndpointIdsUpdateForAllProjectsInOrg,
|
||||
BlockPublicOrVpcAccessUpdate,
|
||||
InvalidateRole,
|
||||
InvalidateEndpoint,
|
||||
InvalidateProject,
|
||||
InvalidateProjects,
|
||||
InvalidateOrg,
|
||||
}
|
||||
|
||||
pub struct ThreadPoolWorkers(usize);
|
||||
|
||||
@@ -2,8 +2,8 @@ use async_trait::async_trait;
|
||||
use tokio::time;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::compute::{self, COULD_NOT_CONNECT, PostgresConnection};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
|
||||
use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
@@ -13,7 +13,6 @@ use crate::error::ReportableError;
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
|
||||
};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
|
||||
use crate::proxy::wake_compute::wake_compute;
|
||||
use crate::types::Host;
|
||||
@@ -48,8 +47,6 @@ pub(crate) trait ConnectMechanism {
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError>;
|
||||
|
||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -58,24 +55,17 @@ pub(crate) trait ComputeConnectBackend {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
|
||||
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys;
|
||||
}
|
||||
|
||||
pub(crate) struct TcpMechanism<'a> {
|
||||
pub(crate) params_compat: bool,
|
||||
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
pub(crate) params: &'a StartupMessageParams,
|
||||
|
||||
pub(crate) struct TcpMechanism {
|
||||
pub(crate) auth: AuthInfo,
|
||||
/// connect_to_compute concurrency lock
|
||||
pub(crate) locks: &'static ApiLocks<Host>,
|
||||
|
||||
pub(crate) user_info: ComputeUserInfo,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for TcpMechanism<'_> {
|
||||
impl ConnectMechanism for TcpMechanism {
|
||||
type Connection = PostgresConnection;
|
||||
type ConnectError = compute::ConnectionError;
|
||||
type Error = compute::ConnectionError;
|
||||
@@ -90,13 +80,12 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
let host = node_info.config.get_host();
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
permit.release_result(node_info.connect(ctx, config, self.user_info.clone()).await)
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||
config.set_startup_params(self.params, self.params_compat);
|
||||
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
|
||||
permit.release_result(
|
||||
node_info
|
||||
.connect(ctx, &self.auth, config, self.user_info.clone())
|
||||
.await,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,12 +103,9 @@ where
|
||||
M::Error: From<WakeComputeError>,
|
||||
{
|
||||
let mut num_retries = 0;
|
||||
let mut node_info =
|
||||
let node_info =
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
|
||||
node_info.set_keys(user_info.get_keys());
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
// try once
|
||||
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
|
||||
Ok(res) => {
|
||||
@@ -155,14 +141,9 @@ where
|
||||
} else {
|
||||
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||
debug!("compute node's state has likely changed; requesting a wake-up");
|
||||
let old_node_info = invalidate_cache(node_info);
|
||||
invalidate_cache(node_info);
|
||||
// TODO: increment num_retries?
|
||||
let mut node_info =
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
node_info.reuse_settings(old_node_info);
|
||||
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
node_info
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
|
||||
};
|
||||
|
||||
// now that we have a new node, try connect to it repeatedly.
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::io::{self, Cursor};
|
||||
use bytes::{Buf, BufMut};
|
||||
use itertools::Itertools;
|
||||
use rand::distributions::{Distribution, Standard};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
|
||||
|
||||
pub type ErrorCode = [u8; 5];
|
||||
@@ -53,6 +53,28 @@ impl fmt::Debug for ProtocolVersion {
|
||||
}
|
||||
}
|
||||
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
|
||||
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
|
||||
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
|
||||
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
|
||||
|
||||
/// This first reads the startup message header, is 8 bytes.
|
||||
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
|
||||
///
|
||||
/// The length value is inclusive of the header. For example,
|
||||
/// an empty message will always have length 8.
|
||||
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
|
||||
#[repr(C)]
|
||||
struct StartupHeader {
|
||||
len: big_endian::U32,
|
||||
version: ProtocolVersion,
|
||||
}
|
||||
|
||||
/// read the type from the stream using zerocopy.
|
||||
///
|
||||
/// not cancel safe.
|
||||
@@ -66,32 +88,38 @@ macro_rules! read {
|
||||
}};
|
||||
}
|
||||
|
||||
/// Returns true if TLS is supported.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn request_tls<S>(stream: &mut S) -> io::Result<bool>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let payload = StartupHeader {
|
||||
len: 8.into(),
|
||||
version: NEGOTIATE_SSL_CODE,
|
||||
};
|
||||
stream.write_all(payload.as_bytes()).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
// we expect back either `S` or `N` as a single byte.
|
||||
let mut res = *b"0";
|
||||
stream.read_exact(&mut res).await?;
|
||||
|
||||
debug_assert!(
|
||||
res == *b"S" || res == *b"N",
|
||||
"unexpected SSL negotiation response: {}",
|
||||
char::from(res[0]),
|
||||
);
|
||||
|
||||
// S for SSL.
|
||||
Ok(res == *b"S")
|
||||
}
|
||||
|
||||
pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
|
||||
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
|
||||
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
|
||||
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
|
||||
|
||||
/// This first reads the startup message header, is 8 bytes.
|
||||
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
|
||||
///
|
||||
/// The length value is inclusive of the header. For example,
|
||||
/// an empty message will always have length 8.
|
||||
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
|
||||
#[repr(C)]
|
||||
struct StartupHeader {
|
||||
len: big_endian::U32,
|
||||
version: ProtocolVersion,
|
||||
}
|
||||
|
||||
let header = read!(stream => StartupHeader);
|
||||
|
||||
// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
|
||||
@@ -564,9 +592,8 @@ mod tests {
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use zerocopy::IntoBytes;
|
||||
|
||||
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
|
||||
|
||||
use super::ProtocolVersion;
|
||||
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_large_startup() {
|
||||
|
||||
@@ -358,21 +358,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
}
|
||||
};
|
||||
|
||||
let compute_user_info = match &user_info {
|
||||
auth::Backend::ControlPlane(_, info) => &info.info,
|
||||
let creds = match &user_info {
|
||||
auth::Backend::ControlPlane(_, creds) => creds,
|
||||
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
|
||||
};
|
||||
let params_compat = compute_user_info
|
||||
.options
|
||||
.get(NeonOptions::PARAMS_COMPAT)
|
||||
.is_some();
|
||||
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
|
||||
let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys);
|
||||
auth_info.set_startup_params(¶ms, params_compat);
|
||||
|
||||
let res = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info: compute_user_info.clone(),
|
||||
params_compat,
|
||||
params: ¶ms,
|
||||
user_info: creds.info.clone(),
|
||||
auth: auth_info,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
&user_info,
|
||||
|
||||
@@ -100,9 +100,9 @@ impl CouldRetry for compute::ConnectionError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
compute::ConnectionError::Postgres(err) => err.could_retry(),
|
||||
compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
|
||||
compute::ConnectionError::TlsError(err) => err.could_retry(),
|
||||
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
|
||||
_ => false,
|
||||
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::time::Duration;
|
||||
use anyhow::{Context, bail};
|
||||
use async_trait::async_trait;
|
||||
use http::StatusCode;
|
||||
use postgres_client::config::SslMode;
|
||||
use postgres_client::config::{AuthKeys, ScramKeys, SslMode};
|
||||
use postgres_client::tls::{MakeTlsConnect, NoTls};
|
||||
use retry::{ShouldRetryWakeCompute, retry_after};
|
||||
use rstest::rstest;
|
||||
@@ -29,7 +29,6 @@ use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::pglb::connect_compute::ConnectMechanism;
|
||||
use crate::tls::client_config::compute_client_config_with_certs;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::tls::server_config::CertResolver;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId};
|
||||
use crate::{sasl, scram};
|
||||
@@ -72,13 +71,14 @@ struct ClientConfig<'a> {
|
||||
hostname: &'a str,
|
||||
}
|
||||
|
||||
type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
|
||||
type TlsConnect<S> = <ComputeConfig as MakeTlsConnect<S>>::TlsConnect;
|
||||
|
||||
impl ClientConfig<'_> {
|
||||
fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
|
||||
let mut mk = MakeRustlsConnect::new(self.config);
|
||||
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
|
||||
Ok(tls)
|
||||
Ok(crate::tls::postgres_rustls::make_tls_connect(
|
||||
&self.config,
|
||||
self.hostname,
|
||||
)?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -497,8 +497,6 @@ impl ConnectMechanism for TestConnectMechanism {
|
||||
x => panic!("expecting action {x:?}, connect is called instead"),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
impl TestControlPlaneClient for TestConnectMechanism {
|
||||
@@ -557,7 +555,12 @@ impl TestControlPlaneClient for TestConnectMechanism {
|
||||
|
||||
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||
let node = NodeInfo {
|
||||
config: compute::ConnCfg::new("test".to_owned(), 5432),
|
||||
conn_info: compute::ConnectInfo {
|
||||
host: "test".into(),
|
||||
port: 5432,
|
||||
ssl_mode: SslMode::Disable,
|
||||
host_addr: None,
|
||||
},
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
@@ -581,7 +584,10 @@ fn helper_create_connect_info(
|
||||
user: "user".into(),
|
||||
options: NeonOptions::parse_options_raw(""),
|
||||
},
|
||||
keys: ComputeCredentialKeys::Password("password".into()),
|
||||
keys: ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(ScramKeys {
|
||||
client_key: [0; 32],
|
||||
server_key: [0; 32],
|
||||
})),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -3,12 +3,12 @@ use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
use redis::aio::PubSub;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::Deserialize;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::cache::project_info::ProjectInfoCache;
|
||||
use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
|
||||
|
||||
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||
@@ -27,42 +27,37 @@ struct NotificationHeader<'a> {
|
||||
topic: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[serde(tag = "topic", content = "data")]
|
||||
pub(crate) enum Notification {
|
||||
enum Notification {
|
||||
#[serde(
|
||||
rename = "/allowed_ips_updated",
|
||||
rename = "/account_settings_update",
|
||||
alias = "/allowed_vpc_endpoints_updated_for_org",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
AllowedIpsUpdate {
|
||||
allowed_ips_update: AllowedIpsUpdate,
|
||||
},
|
||||
AccountSettingsUpdate(InvalidateAccount),
|
||||
|
||||
#[serde(
|
||||
rename = "/block_public_or_vpc_access_updated",
|
||||
rename = "/endpoint_settings_update",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
BlockPublicOrVpcAccessUpdated {
|
||||
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
|
||||
},
|
||||
EndpointSettingsUpdate(InvalidateEndpoint),
|
||||
|
||||
#[serde(
|
||||
rename = "/allowed_vpc_endpoints_updated_for_org",
|
||||
rename = "/project_settings_update",
|
||||
alias = "/allowed_ips_updated",
|
||||
alias = "/block_public_or_vpc_access_updated",
|
||||
alias = "/allowed_vpc_endpoints_updated_for_projects",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
AllowedVpcEndpointsUpdatedForOrg {
|
||||
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
|
||||
},
|
||||
ProjectSettingsUpdate(InvalidateProject),
|
||||
|
||||
#[serde(
|
||||
rename = "/allowed_vpc_endpoints_updated_for_projects",
|
||||
rename = "/role_setting_update",
|
||||
alias = "/password_updated",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
AllowedVpcEndpointsUpdatedForProjects {
|
||||
allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
|
||||
},
|
||||
#[serde(
|
||||
rename = "/password_updated",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
PasswordUpdate { password_update: PasswordUpdate },
|
||||
RoleSettingUpdate(InvalidateRole),
|
||||
|
||||
#[serde(
|
||||
other,
|
||||
@@ -72,28 +67,56 @@ pub(crate) enum Notification {
|
||||
UnknownTopic,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct AllowedIpsUpdate {
|
||||
project_id: ProjectIdInt,
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum InvalidateEndpoint {
|
||||
EndpointId(EndpointIdInt),
|
||||
EndpointIds(Vec<EndpointIdInt>),
|
||||
}
|
||||
impl std::ops::Deref for InvalidateEndpoint {
|
||||
type Target = [EndpointIdInt];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
Self::EndpointId(id) => std::slice::from_ref(id),
|
||||
Self::EndpointIds(ids) => ids,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct BlockPublicOrVpcAccessUpdated {
|
||||
project_id: ProjectIdInt,
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum InvalidateProject {
|
||||
ProjectId(ProjectIdInt),
|
||||
ProjectIds(Vec<ProjectIdInt>),
|
||||
}
|
||||
impl std::ops::Deref for InvalidateProject {
|
||||
type Target = [ProjectIdInt];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
Self::ProjectId(id) => std::slice::from_ref(id),
|
||||
Self::ProjectIds(ids) => ids,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
|
||||
account_id: AccountIdInt,
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum InvalidateAccount {
|
||||
AccountId(AccountIdInt),
|
||||
AccountIds(Vec<AccountIdInt>),
|
||||
}
|
||||
impl std::ops::Deref for InvalidateAccount {
|
||||
type Target = [AccountIdInt];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
Self::AccountId(id) => std::slice::from_ref(id),
|
||||
Self::AccountIds(ids) => ids,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
|
||||
project_ids: Vec<ProjectIdInt>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct PasswordUpdate {
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
struct InvalidateRole {
|
||||
project_id: ProjectIdInt,
|
||||
role_name: RoleNameInt,
|
||||
}
|
||||
@@ -177,41 +200,29 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
|
||||
tracing::debug!(?msg, "received a message");
|
||||
match msg {
|
||||
Notification::AllowedIpsUpdate { .. }
|
||||
| Notification::PasswordUpdate { .. }
|
||||
| Notification::BlockPublicOrVpcAccessUpdated { .. }
|
||||
| Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
|
||||
| Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
|
||||
Notification::RoleSettingUpdate { .. }
|
||||
| Notification::EndpointSettingsUpdate { .. }
|
||||
| Notification::ProjectSettingsUpdate { .. }
|
||||
| Notification::AccountSettingsUpdate { .. } => {
|
||||
invalidate_cache(self.cache.clone(), msg.clone());
|
||||
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::AllowedIpsUpdate);
|
||||
} else if matches!(msg, Notification::PasswordUpdate { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::PasswordUpdate);
|
||||
} else if matches!(
|
||||
msg,
|
||||
Notification::AllowedVpcEndpointsUpdatedForProjects { .. }
|
||||
) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects);
|
||||
} else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg);
|
||||
} else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate);
|
||||
|
||||
let m = &Metrics::get().proxy.redis_events_count;
|
||||
match msg {
|
||||
Notification::RoleSettingUpdate { .. } => {
|
||||
m.inc(RedisEventsCount::InvalidateRole);
|
||||
}
|
||||
Notification::EndpointSettingsUpdate { .. } => {
|
||||
m.inc(RedisEventsCount::InvalidateEndpoint);
|
||||
}
|
||||
Notification::ProjectSettingsUpdate { .. } => {
|
||||
m.inc(RedisEventsCount::InvalidateProject);
|
||||
}
|
||||
Notification::AccountSettingsUpdate { .. } => {
|
||||
m.inc(RedisEventsCount::InvalidateOrg);
|
||||
}
|
||||
Notification::UnknownTopic => {}
|
||||
}
|
||||
|
||||
// TODO: add additional metrics for the other event types.
|
||||
|
||||
// It might happen that the invalid entry is on the way to be cached.
|
||||
@@ -233,30 +244,23 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
|
||||
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||
match msg {
|
||||
Notification::AllowedIpsUpdate {
|
||||
allowed_ips_update: AllowedIpsUpdate { project_id },
|
||||
}
|
||||
| Notification::BlockPublicOrVpcAccessUpdated {
|
||||
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id },
|
||||
} => cache.invalidate_endpoint_access_for_project(project_id),
|
||||
Notification::AllowedVpcEndpointsUpdatedForOrg {
|
||||
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id },
|
||||
} => cache.invalidate_endpoint_access_for_org(account_id),
|
||||
Notification::AllowedVpcEndpointsUpdatedForProjects {
|
||||
allowed_vpc_endpoints_updated_for_projects:
|
||||
AllowedVpcEndpointsUpdatedForProjects { project_ids },
|
||||
} => {
|
||||
for project in project_ids {
|
||||
cache.invalidate_endpoint_access_for_project(project);
|
||||
}
|
||||
}
|
||||
Notification::PasswordUpdate {
|
||||
password_update:
|
||||
PasswordUpdate {
|
||||
project_id,
|
||||
role_name,
|
||||
},
|
||||
} => cache.invalidate_role_secret_for_project(project_id, role_name),
|
||||
Notification::EndpointSettingsUpdate(ids) => ids
|
||||
.iter()
|
||||
.for_each(|&id| cache.invalidate_endpoint_access(id)),
|
||||
|
||||
Notification::AccountSettingsUpdate(ids) => ids
|
||||
.iter()
|
||||
.for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
|
||||
|
||||
Notification::ProjectSettingsUpdate(ids) => ids
|
||||
.iter()
|
||||
.for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
|
||||
|
||||
Notification::RoleSettingUpdate(InvalidateRole {
|
||||
project_id,
|
||||
role_name,
|
||||
}) => cache.invalidate_role_secret_for_project(project_id, role_name),
|
||||
|
||||
Notification::UnknownTopic => unreachable!(),
|
||||
}
|
||||
}
|
||||
@@ -353,11 +357,32 @@ mod tests {
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(
|
||||
result,
|
||||
Notification::AllowedIpsUpdate {
|
||||
allowed_ips_update: AllowedIpsUpdate {
|
||||
project_id: (&project_id).into()
|
||||
}
|
||||
}
|
||||
Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_multiple_projects() -> anyhow::Result<()> {
|
||||
let project_id1: ProjectId = "new_project1".into();
|
||||
let project_id2: ProjectId = "new_project2".into();
|
||||
let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}");
|
||||
let text = json!({
|
||||
"type": "message",
|
||||
"topic": "/allowed_vpc_endpoints_updated_for_projects",
|
||||
"data": data,
|
||||
"extre_fields": "something"
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(
|
||||
result,
|
||||
Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
|
||||
(&project_id1).into(),
|
||||
(&project_id2).into()
|
||||
]))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
@@ -379,12 +404,10 @@ mod tests {
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(
|
||||
result,
|
||||
Notification::PasswordUpdate {
|
||||
password_update: PasswordUpdate {
|
||||
project_id: (&project_id).into(),
|
||||
role_name: (&role_name).into(),
|
||||
}
|
||||
}
|
||||
Notification::RoleSettingUpdate(InvalidateRole {
|
||||
project_id: (&project_id).into(),
|
||||
role_name: (&role_name).into(),
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -23,7 +23,6 @@ use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnP
|
||||
use crate::auth::backend::local::StaticAuthRules;
|
||||
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::auth::{self, AuthError};
|
||||
use crate::compute;
|
||||
use crate::compute_ctl::{
|
||||
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
|
||||
};
|
||||
@@ -305,12 +304,13 @@ impl PoolingBackend {
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
|
||||
|
||||
let mut node_info = local_backend.node_info.clone();
|
||||
|
||||
let (key, jwk) = create_random_jwk();
|
||||
|
||||
let config = node_info
|
||||
.config
|
||||
let mut config = local_backend
|
||||
.node_info
|
||||
.conn_info
|
||||
.to_postgres_client_config();
|
||||
config
|
||||
.user(&conn_info.user_info.user)
|
||||
.dbname(&conn_info.dbname)
|
||||
.set_param(
|
||||
@@ -322,7 +322,7 @@ impl PoolingBackend {
|
||||
);
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = config.connect(postgres_client::NoTls).await?;
|
||||
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
|
||||
drop(pause);
|
||||
|
||||
let pid = client.get_process_id();
|
||||
@@ -336,7 +336,7 @@ impl PoolingBackend {
|
||||
connection,
|
||||
key,
|
||||
conn_id,
|
||||
node_info.aux.clone(),
|
||||
local_backend.node_info.aux.clone(),
|
||||
);
|
||||
|
||||
{
|
||||
@@ -512,19 +512,16 @@ impl ConnectMechanism for TokioMechanism {
|
||||
node_info: &CachedNodeInfo,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host = node_info.config.get_host();
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
|
||||
|
||||
let mut config = (*node_info.config).clone();
|
||||
let mut config = node_info.conn_info.to_postgres_client_config();
|
||||
let config = config
|
||||
.user(&self.conn_info.user_info.user)
|
||||
.dbname(&self.conn_info.dbname)
|
||||
.connect_timeout(compute_config.timeout);
|
||||
|
||||
let mk_tls =
|
||||
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let res = config.connect(mk_tls).await;
|
||||
let res = config.connect(compute_config).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
@@ -548,8 +545,6 @@ impl ConnectMechanism for TokioMechanism {
|
||||
node_info.aux.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
struct HyperMechanism {
|
||||
@@ -573,20 +568,20 @@ impl ConnectMechanism for HyperMechanism {
|
||||
node_info: &CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host_addr = node_info.config.get_host_addr();
|
||||
let host = node_info.config.get_host();
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
let host_addr = node_info.conn_info.host_addr;
|
||||
let host = &node_info.conn_info.host;
|
||||
let permit = self.locks.get_permit(host).await?;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
|
||||
let tls = if node_info.config.get_ssl_mode() == SslMode::Disable {
|
||||
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
|
||||
None
|
||||
} else {
|
||||
Some(&config.tls)
|
||||
};
|
||||
|
||||
let port = node_info.config.get_port();
|
||||
let res = connect_http2(host_addr, &host, port, config.timeout, tls).await;
|
||||
let port = node_info.conn_info.port;
|
||||
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
@@ -609,8 +604,6 @@ impl ConnectMechanism for HyperMechanism {
|
||||
node_info.aux.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
async fn connect_http2(
|
||||
|
||||
@@ -23,12 +23,12 @@ use super::conn_pool_lib::{
|
||||
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool,
|
||||
GlobalConnPool,
|
||||
};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
type TlsStream = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::Stream;
|
||||
type TlsStream = <ComputeConfig as MakeTlsConnect<TcpStream>>::Stream;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ConnInfoWithAuth {
|
||||
|
||||
@@ -2,10 +2,11 @@ use std::convert::TryFrom;
|
||||
use std::sync::Arc;
|
||||
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use rustls::ClientConfig;
|
||||
use rustls::pki_types::ServerName;
|
||||
use rustls::pki_types::{InvalidDnsNameError, ServerName};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::config::ComputeConfig;
|
||||
|
||||
mod private {
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
@@ -123,36 +124,27 @@ mod private {
|
||||
}
|
||||
}
|
||||
|
||||
/// A `MakeTlsConnect` implementation using `rustls`.
|
||||
///
|
||||
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
|
||||
#[derive(Clone)]
|
||||
pub struct MakeRustlsConnect {
|
||||
pub config: Arc<ClientConfig>,
|
||||
}
|
||||
|
||||
impl MakeRustlsConnect {
|
||||
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
|
||||
#[must_use]
|
||||
pub fn new(config: Arc<ClientConfig>) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
|
||||
impl<S> MakeTlsConnect<S> for ComputeConfig
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
type Stream = private::RustlsStream<S>;
|
||||
type TlsConnect = private::RustlsConnect;
|
||||
type Error = rustls::pki_types::InvalidDnsNameError;
|
||||
type Error = InvalidDnsNameError;
|
||||
|
||||
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
|
||||
ServerName::try_from(hostname).map(|dns_name| {
|
||||
private::RustlsConnect(private::RustlsConnectData {
|
||||
hostname: dns_name.to_owned(),
|
||||
connector: Arc::clone(&self.config).into(),
|
||||
})
|
||||
})
|
||||
fn make_tls_connect(&self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
|
||||
make_tls_connect(&self.tls, hostname)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make_tls_connect(
|
||||
tls: &Arc<rustls::ClientConfig>,
|
||||
hostname: &str,
|
||||
) -> Result<private::RustlsConnect, InvalidDnsNameError> {
|
||||
ServerName::try_from(hostname).map(|dns_name| {
|
||||
private::RustlsConnect(private::RustlsConnectData {
|
||||
hostname: dns_name.to_owned(),
|
||||
connector: tls.clone().into(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ use std::error::Error as _;
|
||||
use http_utils::error::HttpErrorBody;
|
||||
use reqwest::{IntoUrl, Method, StatusCode};
|
||||
use safekeeper_api::models::{
|
||||
self, PullTimelineRequest, PullTimelineResponse, SafekeeperUtilization, TimelineCreateRequest,
|
||||
TimelineStatus,
|
||||
self, PullTimelineRequest, PullTimelineResponse, SafekeeperStatus, SafekeeperUtilization,
|
||||
TimelineCreateRequest, TimelineStatus,
|
||||
};
|
||||
use utils::id::{NodeId, TenantId, TimelineId};
|
||||
use utils::logging::SecretString;
|
||||
@@ -183,6 +183,12 @@ impl Client {
|
||||
self.get(&uri).await
|
||||
}
|
||||
|
||||
pub async fn status(&self) -> Result<SafekeeperStatus> {
|
||||
let uri = format!("{}/v1/status", self.mgmt_api_endpoint);
|
||||
let resp = self.get(&uri).await?;
|
||||
resp.json().await.map_err(Error::ReceiveBody)
|
||||
}
|
||||
|
||||
pub async fn utilization(&self) -> Result<SafekeeperUtilization> {
|
||||
let uri = format!("{}/v1/utilization", self.mgmt_api_endpoint);
|
||||
let resp = self.get(&uri).await?;
|
||||
|
||||
@@ -395,6 +395,8 @@ pub enum TimelineError {
|
||||
Cancelled(TenantTimelineId),
|
||||
#[error("Timeline {0} was not found in global map")]
|
||||
NotFound(TenantTimelineId),
|
||||
#[error("Timeline {0} has been deleted")]
|
||||
Deleted(TenantTimelineId),
|
||||
#[error("Timeline {0} creation is in progress")]
|
||||
CreationInProgress(TenantTimelineId),
|
||||
#[error("Timeline {0} exists on disk, but wasn't loaded on startup")]
|
||||
|
||||
@@ -78,7 +78,13 @@ impl GlobalTimelinesState {
|
||||
Some(GlobalMapTimeline::CreationInProgress) => {
|
||||
Err(TimelineError::CreationInProgress(*ttid))
|
||||
}
|
||||
None => Err(TimelineError::NotFound(*ttid)),
|
||||
None => {
|
||||
if self.has_tombstone(ttid) {
|
||||
Err(TimelineError::Deleted(*ttid))
|
||||
} else {
|
||||
Err(TimelineError::NotFound(*ttid))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE nodes DROP COLUMN lifecycle;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE nodes ADD COLUMN lifecycle VARCHAR NOT NULL DEFAULT 'active';
|
||||
@@ -907,6 +907,42 @@ async fn handle_node_delete(req: Request<Body>) -> Result<Response<Body>, ApiErr
|
||||
json_response(StatusCode::OK, state.service.node_delete(node_id).await?)
|
||||
}
|
||||
|
||||
async fn handle_tombstone_list(req: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
check_permissions(&req, Scope::Admin)?;
|
||||
|
||||
let req = match maybe_forward(req).await {
|
||||
ForwardOutcome::Forwarded(res) => {
|
||||
return res;
|
||||
}
|
||||
ForwardOutcome::NotForwarded(req) => req,
|
||||
};
|
||||
|
||||
let state = get_state(&req);
|
||||
let mut nodes = state.service.tombstone_list().await?;
|
||||
nodes.sort_by_key(|n| n.get_id());
|
||||
let api_nodes = nodes.into_iter().map(|n| n.describe()).collect::<Vec<_>>();
|
||||
|
||||
json_response(StatusCode::OK, api_nodes)
|
||||
}
|
||||
|
||||
async fn handle_tombstone_delete(req: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
check_permissions(&req, Scope::Admin)?;
|
||||
|
||||
let req = match maybe_forward(req).await {
|
||||
ForwardOutcome::Forwarded(res) => {
|
||||
return res;
|
||||
}
|
||||
ForwardOutcome::NotForwarded(req) => req,
|
||||
};
|
||||
|
||||
let state = get_state(&req);
|
||||
let node_id: NodeId = parse_request_param(&req, "node_id")?;
|
||||
json_response(
|
||||
StatusCode::OK,
|
||||
state.service.tombstone_delete(node_id).await?,
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_node_configure(req: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
check_permissions(&req, Scope::Admin)?;
|
||||
|
||||
@@ -2063,6 +2099,20 @@ pub fn make_router(
|
||||
.post("/debug/v1/node/:node_id/drop", |r| {
|
||||
named_request_span(r, handle_node_drop, RequestName("debug_v1_node_drop"))
|
||||
})
|
||||
.delete("/debug/v1/tombstone/:node_id", |r| {
|
||||
named_request_span(
|
||||
r,
|
||||
handle_tombstone_delete,
|
||||
RequestName("debug_v1_tombstone_delete"),
|
||||
)
|
||||
})
|
||||
.get("/debug/v1/tombstone", |r| {
|
||||
named_request_span(
|
||||
r,
|
||||
handle_tombstone_list,
|
||||
RequestName("debug_v1_tombstone_list"),
|
||||
)
|
||||
})
|
||||
.post("/debug/v1/tenant/:tenant_id/import", |r| {
|
||||
named_request_span(
|
||||
r,
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
|
||||
use pageserver_api::controller_api::{
|
||||
AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeRegisterRequest,
|
||||
AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeLifecycle, NodeRegisterRequest,
|
||||
NodeSchedulingPolicy, TenantLocateResponseShard,
|
||||
};
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
@@ -29,6 +29,7 @@ pub(crate) struct Node {
|
||||
|
||||
availability: NodeAvailability,
|
||||
scheduling: NodeSchedulingPolicy,
|
||||
lifecycle: NodeLifecycle,
|
||||
|
||||
listen_http_addr: String,
|
||||
listen_http_port: u16,
|
||||
@@ -228,6 +229,7 @@ impl Node {
|
||||
listen_pg_addr,
|
||||
listen_pg_port,
|
||||
scheduling: NodeSchedulingPolicy::Active,
|
||||
lifecycle: NodeLifecycle::Active,
|
||||
availability: NodeAvailability::Offline,
|
||||
availability_zone_id,
|
||||
use_https,
|
||||
@@ -239,6 +241,7 @@ impl Node {
|
||||
NodePersistence {
|
||||
node_id: self.id.0 as i64,
|
||||
scheduling_policy: self.scheduling.into(),
|
||||
lifecycle: self.lifecycle.into(),
|
||||
listen_http_addr: self.listen_http_addr.clone(),
|
||||
listen_http_port: self.listen_http_port as i32,
|
||||
listen_https_port: self.listen_https_port.map(|x| x as i32),
|
||||
@@ -263,6 +266,7 @@ impl Node {
|
||||
availability: NodeAvailability::Offline,
|
||||
scheduling: NodeSchedulingPolicy::from_str(&np.scheduling_policy)
|
||||
.expect("Bad scheduling policy in DB"),
|
||||
lifecycle: NodeLifecycle::from_str(&np.lifecycle).expect("Bad lifecycle in DB"),
|
||||
listen_http_addr: np.listen_http_addr,
|
||||
listen_http_port: np.listen_http_port as u16,
|
||||
listen_https_port: np.listen_https_port.map(|x| x as u16),
|
||||
|
||||
@@ -19,7 +19,7 @@ use futures::FutureExt;
|
||||
use futures::future::BoxFuture;
|
||||
use itertools::Itertools;
|
||||
use pageserver_api::controller_api::{
|
||||
AvailabilityZone, MetadataHealthRecord, NodeSchedulingPolicy, PlacementPolicy,
|
||||
AvailabilityZone, MetadataHealthRecord, NodeLifecycle, NodeSchedulingPolicy, PlacementPolicy,
|
||||
SafekeeperDescribeResponse, ShardSchedulingPolicy, SkSchedulingPolicy,
|
||||
};
|
||||
use pageserver_api::models::{ShardImportStatus, TenantConfig};
|
||||
@@ -102,6 +102,7 @@ pub(crate) enum DatabaseOperation {
|
||||
UpdateNode,
|
||||
DeleteNode,
|
||||
ListNodes,
|
||||
ListTombstones,
|
||||
BeginShardSplit,
|
||||
CompleteShardSplit,
|
||||
AbortShardSplit,
|
||||
@@ -357,6 +358,8 @@ impl Persistence {
|
||||
}
|
||||
|
||||
/// When a node is first registered, persist it before using it for anything
|
||||
/// If the provided node_id already exists, it will be error.
|
||||
/// The common case is when a node marked for deletion wants to register.
|
||||
pub(crate) async fn insert_node(&self, node: &Node) -> DatabaseResult<()> {
|
||||
let np = &node.to_persistent();
|
||||
self.with_measured_conn(DatabaseOperation::InsertNode, move |conn| {
|
||||
@@ -373,19 +376,41 @@ impl Persistence {
|
||||
|
||||
/// At startup, populate the list of nodes which our shards may be placed on
|
||||
pub(crate) async fn list_nodes(&self) -> DatabaseResult<Vec<NodePersistence>> {
|
||||
let nodes: Vec<NodePersistence> = self
|
||||
use crate::schema::nodes::dsl::*;
|
||||
|
||||
let result: Vec<NodePersistence> = self
|
||||
.with_measured_conn(DatabaseOperation::ListNodes, move |conn| {
|
||||
Box::pin(async move {
|
||||
Ok(crate::schema::nodes::table
|
||||
.filter(lifecycle.ne(String::from(NodeLifecycle::Deleted)))
|
||||
.load::<NodePersistence>(conn)
|
||||
.await?)
|
||||
})
|
||||
})
|
||||
.await?;
|
||||
|
||||
tracing::info!("list_nodes: loaded {} nodes", nodes.len());
|
||||
tracing::info!("list_nodes: loaded {} nodes", result.len());
|
||||
|
||||
Ok(nodes)
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) async fn list_tombstones(&self) -> DatabaseResult<Vec<NodePersistence>> {
|
||||
use crate::schema::nodes::dsl::*;
|
||||
|
||||
let result: Vec<NodePersistence> = self
|
||||
.with_measured_conn(DatabaseOperation::ListTombstones, move |conn| {
|
||||
Box::pin(async move {
|
||||
Ok(crate::schema::nodes::table
|
||||
.filter(lifecycle.eq(String::from(NodeLifecycle::Deleted)))
|
||||
.load::<NodePersistence>(conn)
|
||||
.await?)
|
||||
})
|
||||
})
|
||||
.await?;
|
||||
|
||||
tracing::info!("list_tombstones: loaded {} nodes", result.len());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) async fn update_node<V>(
|
||||
@@ -404,6 +429,7 @@ impl Persistence {
|
||||
Box::pin(async move {
|
||||
let updated = diesel::update(nodes)
|
||||
.filter(node_id.eq(input_node_id.0 as i64))
|
||||
.filter(lifecycle.ne(String::from(NodeLifecycle::Deleted)))
|
||||
.set(values)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
@@ -447,6 +473,57 @@ impl Persistence {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Tombstone is a special state where the node is not deleted from the database,
|
||||
/// but it is not available for usage.
|
||||
/// The main reason for it is to prevent the flaky node to register.
|
||||
pub(crate) async fn set_tombstone(&self, del_node_id: NodeId) -> DatabaseResult<()> {
|
||||
use crate::schema::nodes::dsl::*;
|
||||
self.update_node(
|
||||
del_node_id,
|
||||
lifecycle.eq(String::from(NodeLifecycle::Deleted)),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
|
||||
use crate::schema::nodes::dsl::*;
|
||||
self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| {
|
||||
Box::pin(async move {
|
||||
// You can hard delete a node only if it has a tombstone.
|
||||
// So we need to check if the node has lifecycle set to deleted.
|
||||
let node_to_delete = nodes
|
||||
.filter(node_id.eq(del_node_id.0 as i64))
|
||||
.first::<NodePersistence>(conn)
|
||||
.await
|
||||
.optional()?;
|
||||
|
||||
if let Some(np) = node_to_delete {
|
||||
let lc = NodeLifecycle::from_str(&np.lifecycle).map_err(|e| {
|
||||
DatabaseError::Logical(format!(
|
||||
"Node {} has invalid lifecycle: {}",
|
||||
del_node_id, e
|
||||
))
|
||||
})?;
|
||||
|
||||
if lc != NodeLifecycle::Deleted {
|
||||
return Err(DatabaseError::Logical(format!(
|
||||
"Node {} was not soft deleted before, cannot hard delete it",
|
||||
del_node_id
|
||||
)));
|
||||
}
|
||||
|
||||
diesel::delete(nodes)
|
||||
.filter(node_id.eq(del_node_id.0 as i64))
|
||||
.execute(conn)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// At startup, load the high level state for shards, such as their config + policy. This will
|
||||
/// be enriched at runtime with state discovered on pageservers.
|
||||
///
|
||||
@@ -543,21 +620,6 @@ impl Persistence {
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
|
||||
use crate::schema::nodes::dsl::*;
|
||||
self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| {
|
||||
Box::pin(async move {
|
||||
diesel::delete(nodes)
|
||||
.filter(node_id.eq(del_node_id.0 as i64))
|
||||
.execute(conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// When a tenant invokes the /re-attach API, this function is responsible for doing an efficient
|
||||
/// batched increment of the generations of all tenants whose generation_pageserver is equal to
|
||||
/// the node that called /re-attach.
|
||||
@@ -571,6 +633,20 @@ impl Persistence {
|
||||
let updated = self
|
||||
.with_measured_conn(DatabaseOperation::ReAttach, move |conn| {
|
||||
Box::pin(async move {
|
||||
// Check if the node is not marked as deleted
|
||||
let deleted_node: i64 = nodes
|
||||
.filter(node_id.eq(input_node_id.0 as i64))
|
||||
.filter(lifecycle.eq(String::from(NodeLifecycle::Deleted)))
|
||||
.count()
|
||||
.get_result(conn)
|
||||
.await?;
|
||||
if deleted_node > 0 {
|
||||
return Err(DatabaseError::Logical(format!(
|
||||
"Node {} is marked as deleted, re-attach is not allowed",
|
||||
input_node_id
|
||||
)));
|
||||
}
|
||||
|
||||
let rows_updated = diesel::update(tenant_shards)
|
||||
.filter(generation_pageserver.eq(input_node_id.0 as i64))
|
||||
.set(generation.eq(generation + 1))
|
||||
@@ -2048,6 +2124,7 @@ pub(crate) struct NodePersistence {
|
||||
pub(crate) listen_pg_port: i32,
|
||||
pub(crate) availability_zone_id: String,
|
||||
pub(crate) listen_https_port: Option<i32>,
|
||||
pub(crate) lifecycle: String,
|
||||
}
|
||||
|
||||
/// Tenant metadata health status that are stored durably.
|
||||
|
||||
@@ -33,6 +33,7 @@ diesel::table! {
|
||||
listen_pg_port -> Int4,
|
||||
availability_zone_id -> Varchar,
|
||||
listen_https_port -> Nullable<Int4>,
|
||||
lifecycle -> Varchar,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -166,6 +166,7 @@ enum NodeOperations {
|
||||
Register,
|
||||
Configure,
|
||||
Delete,
|
||||
DeleteTombstone,
|
||||
}
|
||||
|
||||
/// The leadership status for the storage controller process.
|
||||
@@ -1107,7 +1108,8 @@ impl Service {
|
||||
observed
|
||||
}
|
||||
|
||||
/// Used during [`Self::startup_reconcile`]: detach a list of unknown-to-us tenants from pageservers.
|
||||
/// Used during [`Self::startup_reconcile`] and shard splits: detach a list of unknown-to-us
|
||||
/// tenants from pageservers.
|
||||
///
|
||||
/// This is safe to run in the background, because if we don't have this TenantShardId in our map of
|
||||
/// tenants, then it is probably something incompletely deleted before: we will not fight with any
|
||||
@@ -6210,7 +6212,11 @@ impl Service {
|
||||
}
|
||||
}
|
||||
|
||||
pausable_failpoint!("shard-split-pre-complete");
|
||||
fail::fail_point!("shard-split-pre-complete", |_| Err(ApiError::Conflict(
|
||||
"failpoint".to_string()
|
||||
)));
|
||||
|
||||
pausable_failpoint!("shard-split-pre-complete-pause");
|
||||
|
||||
// TODO: if the pageserver restarted concurrently with our split API call,
|
||||
// the actual generation of the child shard might differ from the generation
|
||||
@@ -6232,6 +6238,15 @@ impl Service {
|
||||
let (response, child_locations, waiters) =
|
||||
self.tenant_shard_split_commit_inmem(tenant_id, new_shard_count, new_stripe_size);
|
||||
|
||||
// Notify all page servers to detach and clean up the old shards because they will no longer
|
||||
// be needed. This is best-effort: if it fails, it will be cleaned up on a subsequent
|
||||
// Pageserver re-attach/startup.
|
||||
let shards_to_cleanup = targets
|
||||
.iter()
|
||||
.map(|target| (target.parent_id, target.node.get_id()))
|
||||
.collect();
|
||||
self.cleanup_locations(shards_to_cleanup).await;
|
||||
|
||||
// Send compute notifications for all the new shards
|
||||
let mut failed_notifications = Vec::new();
|
||||
for (child_id, child_ps, stripe_size) in child_locations {
|
||||
@@ -6909,7 +6924,7 @@ impl Service {
|
||||
/// detaching or deleting it on pageservers. We do not try and re-schedule any
|
||||
/// tenants that were on this node.
|
||||
pub(crate) async fn node_drop(&self, node_id: NodeId) -> Result<(), ApiError> {
|
||||
self.persistence.delete_node(node_id).await?;
|
||||
self.persistence.set_tombstone(node_id).await?;
|
||||
|
||||
let mut locked = self.inner.write().unwrap();
|
||||
|
||||
@@ -7033,9 +7048,10 @@ impl Service {
|
||||
// That is safe because in Service::spawn we only use generation_pageserver if it refers to a node
|
||||
// that exists.
|
||||
|
||||
// 2. Actually delete the node from the database and from in-memory state
|
||||
// 2. Actually delete the node from in-memory state and set tombstone to the database
|
||||
// for preventing the node to register again.
|
||||
tracing::info!("Deleting node from database");
|
||||
self.persistence.delete_node(node_id).await?;
|
||||
self.persistence.set_tombstone(node_id).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -7054,6 +7070,35 @@ impl Service {
|
||||
Ok(nodes)
|
||||
}
|
||||
|
||||
pub(crate) async fn tombstone_list(&self) -> Result<Vec<Node>, ApiError> {
|
||||
self.persistence
|
||||
.list_tombstones()
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|np| Node::from_persistent(np, false))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(ApiError::InternalServerError)
|
||||
}
|
||||
|
||||
pub(crate) async fn tombstone_delete(&self, node_id: NodeId) -> Result<(), ApiError> {
|
||||
let _node_lock = trace_exclusive_lock(
|
||||
&self.node_op_locks,
|
||||
node_id,
|
||||
NodeOperations::DeleteTombstone,
|
||||
)
|
||||
.await;
|
||||
|
||||
if matches!(self.get_node(node_id).await, Err(ApiError::NotFound(_))) {
|
||||
self.persistence.delete_node(node_id).await?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ApiError::Conflict(format!(
|
||||
"Node {} is in use, consider using tombstone API first",
|
||||
node_id
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn get_node(&self, node_id: NodeId) -> Result<Node, ApiError> {
|
||||
self.inner
|
||||
.read()
|
||||
@@ -7224,7 +7269,25 @@ impl Service {
|
||||
};
|
||||
|
||||
match registration_status {
|
||||
RegistrationStatus::New => self.persistence.insert_node(&new_node).await?,
|
||||
RegistrationStatus::New => {
|
||||
self.persistence.insert_node(&new_node).await.map_err(|e| {
|
||||
if matches!(
|
||||
e,
|
||||
crate::persistence::DatabaseError::Query(
|
||||
diesel::result::Error::DatabaseError(
|
||||
diesel::result::DatabaseErrorKind::UniqueViolation,
|
||||
_,
|
||||
)
|
||||
)
|
||||
) {
|
||||
// The node can be deleted by tombstone API, and not show up in the list of nodes.
|
||||
// If you see this error, check tombstones first.
|
||||
ApiError::Conflict(format!("Node {} is already exists", new_node.get_id()))
|
||||
} else {
|
||||
ApiError::from(e)
|
||||
}
|
||||
})?;
|
||||
}
|
||||
RegistrationStatus::NeedUpdate => {
|
||||
self.persistence
|
||||
.update_node_on_registration(
|
||||
|
||||
@@ -2054,6 +2054,14 @@ class NeonStorageController(MetricsGetter, LogUtils):
|
||||
headers=self.headers(TokenScope.ADMIN),
|
||||
)
|
||||
|
||||
def tombstone_delete(self, node_id):
|
||||
log.info(f"tombstone_delete({node_id})")
|
||||
self.request(
|
||||
"DELETE",
|
||||
f"{self.api}/debug/v1/tombstone/{node_id}",
|
||||
headers=self.headers(TokenScope.ADMIN),
|
||||
)
|
||||
|
||||
def node_drain(self, node_id):
|
||||
log.info(f"node_drain({node_id})")
|
||||
self.request(
|
||||
@@ -2110,6 +2118,14 @@ class NeonStorageController(MetricsGetter, LogUtils):
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def tombstone_list(self):
|
||||
response = self.request(
|
||||
"GET",
|
||||
f"{self.api}/debug/v1/tombstone",
|
||||
headers=self.headers(TokenScope.ADMIN),
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def tenant_shard_dump(self):
|
||||
"""
|
||||
Debug listing API: dumps the internal map of tenant shards
|
||||
|
||||
@@ -87,6 +87,9 @@ def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_build
|
||||
|
||||
# Set up pageserver for import
|
||||
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
|
||||
neon_env_builder.storage_controller_config = {
|
||||
"timelines_onto_safekeepers": True,
|
||||
}
|
||||
env = neon_env_builder.init_start()
|
||||
|
||||
env.pageserver.tenant_create(tenant)
|
||||
|
||||
@@ -30,6 +30,7 @@ def test_safekeeper_delete_timeline(neon_env_builder: NeonEnvBuilder, auth_enabl
|
||||
env.pageserver.allowed_errors.extend(
|
||||
[
|
||||
".*Timeline .* was not found in global map.*",
|
||||
".*Timeline .* has been deleted.*",
|
||||
".*Timeline .* was cancelled and cannot be used anymore.*",
|
||||
]
|
||||
)
|
||||
@@ -198,6 +199,7 @@ def test_safekeeper_delete_timeline_under_load(neon_env_builder: NeonEnvBuilder)
|
||||
env.pageserver.allowed_errors.extend(
|
||||
[
|
||||
".*Timeline.*was cancelled.*",
|
||||
".*Timeline.*has been deleted.*",
|
||||
".*Timeline.*was not found.*",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1836,3 +1836,90 @@ def test_sharding_gc(
|
||||
shard_gc_cutoff_lsn = Lsn(shard_index["metadata_bytes"]["latest_gc_cutoff_lsn"])
|
||||
log.info(f"Shard {shard_number} cutoff LSN: {shard_gc_cutoff_lsn}")
|
||||
assert shard_gc_cutoff_lsn == shard_0_gc_cutoff_lsn
|
||||
|
||||
|
||||
def test_split_ps_delete_old_shard_after_commit(neon_env_builder: NeonEnvBuilder):
|
||||
"""
|
||||
Check that PageServer only deletes old shards after the split is committed such that it doesn't
|
||||
have to download a lot of files during abort.
|
||||
"""
|
||||
DBNAME = "regression"
|
||||
|
||||
init_shard_count = 4
|
||||
neon_env_builder.num_pageservers = init_shard_count
|
||||
stripe_size = 32
|
||||
|
||||
env = neon_env_builder.init_start(
|
||||
initial_tenant_shard_count=init_shard_count, initial_tenant_shard_stripe_size=stripe_size
|
||||
)
|
||||
|
||||
env.storage_controller.allowed_errors.extend(
|
||||
[
|
||||
# All split failures log a warning when they enqueue the abort operation
|
||||
".*Enqueuing background abort.*",
|
||||
# Tolerate any error logs that mention a failpoint
|
||||
".*failpoint.*",
|
||||
]
|
||||
)
|
||||
|
||||
endpoint = env.endpoints.create("main")
|
||||
endpoint.respec(skip_pg_catalog_updates=False)
|
||||
endpoint.start()
|
||||
|
||||
# Write some initial data.
|
||||
endpoint.safe_psql(f"CREATE DATABASE {DBNAME}")
|
||||
endpoint.safe_psql("CREATE TABLE usertable ( YCSB_KEY INT, FIELD0 TEXT);")
|
||||
|
||||
for _ in range(1000):
|
||||
endpoint.safe_psql(
|
||||
"INSERT INTO usertable SELECT random(), repeat('a', 1000);", log_query=False
|
||||
)
|
||||
|
||||
# Record how many bytes we've downloaded before the split.
|
||||
def collect_downloaded_bytes() -> list[float | None]:
|
||||
downloaded_bytes = []
|
||||
for page_server in env.pageservers:
|
||||
metric = page_server.http_client().get_metric_value(
|
||||
"pageserver_remote_ondemand_downloaded_bytes_total"
|
||||
)
|
||||
downloaded_bytes.append(metric)
|
||||
return downloaded_bytes
|
||||
|
||||
downloaded_bytes_before = collect_downloaded_bytes()
|
||||
|
||||
# Attempt to split the tenant, but fail the split before it completes.
|
||||
env.storage_controller.configure_failpoints(("shard-split-pre-complete", "return(1)"))
|
||||
with pytest.raises(StorageControllerApiException):
|
||||
env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=16)
|
||||
|
||||
# Wait until split is aborted.
|
||||
def check_split_is_aborted():
|
||||
tenants = env.storage_controller.tenant_list()
|
||||
assert len(tenants) == 1
|
||||
shards = tenants[0]["shards"]
|
||||
assert len(shards) == 4
|
||||
for shard in shards:
|
||||
assert not shard["is_splitting"]
|
||||
assert not shard["is_reconciling"]
|
||||
|
||||
# Make sure all new shards have been deleted.
|
||||
valid_shards = 0
|
||||
for ps in env.pageservers:
|
||||
for tenant_dir in os.listdir(ps.workdir / "tenants"):
|
||||
try:
|
||||
tenant_shard_id = TenantShardId.parse(tenant_dir)
|
||||
valid_shards += 1
|
||||
assert tenant_shard_id.shard_count == 4
|
||||
except ValueError:
|
||||
log.info(f"{tenant_dir} is not valid tenant shard id")
|
||||
assert valid_shards >= 4
|
||||
|
||||
wait_until(check_split_is_aborted)
|
||||
|
||||
endpoint.safe_psql("SELECT count(*) from usertable;", log_query=False)
|
||||
|
||||
# Make sure we didn't download anything following the aborted split.
|
||||
downloaded_bytes_after = collect_downloaded_bytes()
|
||||
|
||||
assert downloaded_bytes_before == downloaded_bytes_after
|
||||
endpoint.stop_and_destroy()
|
||||
|
||||
@@ -2956,7 +2956,7 @@ def test_storage_controller_leadership_transfer_during_split(
|
||||
env.storage_controller.allowed_errors.extend(
|
||||
[".*Unexpected child shard count.*", ".*Enqueuing background abort.*"]
|
||||
)
|
||||
pause_failpoint = "shard-split-pre-complete"
|
||||
pause_failpoint = "shard-split-pre-complete-pause"
|
||||
env.storage_controller.configure_failpoints((pause_failpoint, "pause"))
|
||||
|
||||
split_fut = executor.submit(
|
||||
@@ -3003,7 +3003,7 @@ def test_storage_controller_leadership_transfer_during_split(
|
||||
env.storage_controller.request(
|
||||
"PUT",
|
||||
f"http://127.0.0.1:{storage_controller_1_port}/debug/v1/failpoints",
|
||||
json=[{"name": "shard-split-pre-complete", "actions": "off"}],
|
||||
json=[{"name": pause_failpoint, "actions": "off"}],
|
||||
headers=env.storage_controller.headers(TokenScope.ADMIN),
|
||||
)
|
||||
|
||||
@@ -3093,6 +3093,58 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB
|
||||
wait_until(reconfigure_node_again)
|
||||
|
||||
|
||||
def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
|
||||
neon_env_builder.num_pageservers = 3
|
||||
|
||||
env = neon_env_builder.init_start()
|
||||
|
||||
def assert_nodes_count(n: int):
|
||||
nodes = env.storage_controller.node_list()
|
||||
assert len(nodes) == n
|
||||
|
||||
# Nodes count must remain the same before deletion
|
||||
assert_nodes_count(3)
|
||||
|
||||
ps = env.pageservers[0]
|
||||
env.storage_controller.node_delete(ps.id)
|
||||
|
||||
# After deletion, the node count must be reduced
|
||||
assert_nodes_count(2)
|
||||
|
||||
# Running pageserver CLI init in a separate thread
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
log.info("Restarting tombstoned pageserver...")
|
||||
ps.stop()
|
||||
ps_start_fut = executor.submit(lambda: ps.start(await_active=False))
|
||||
|
||||
# After deleted pageserver restart, the node count must remain the same
|
||||
assert_nodes_count(2)
|
||||
|
||||
tombstones = env.storage_controller.tombstone_list()
|
||||
assert len(tombstones) == 1 and tombstones[0]["id"] == ps.id
|
||||
|
||||
env.storage_controller.tombstone_delete(ps.id)
|
||||
|
||||
tombstones = env.storage_controller.tombstone_list()
|
||||
assert len(tombstones) == 0
|
||||
|
||||
# Wait for the pageserver start operation to complete.
|
||||
# If it fails with an exception, we try restarting the pageserver since the failure
|
||||
# may be due to the storage controller refusing to register the node.
|
||||
# However, if we get a TimeoutError that means the pageserver is completely hung,
|
||||
# which is an unexpected failure mode that we'll let propagate up.
|
||||
try:
|
||||
ps_start_fut.result(timeout=20)
|
||||
except TimeoutError:
|
||||
raise
|
||||
except Exception:
|
||||
log.info("Restarting deleted pageserver...")
|
||||
ps.restart()
|
||||
|
||||
# Finally, the node can be registered again after tombstone is deleted
|
||||
wait_until(lambda: assert_nodes_count(3))
|
||||
|
||||
|
||||
def test_storage_controller_timeline_crud_race(neon_env_builder: NeonEnvBuilder):
|
||||
"""
|
||||
The storage controller is meant to handle the case where a timeline CRUD operation races
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from fixtures.common_types import Lsn, TenantId, TimelineId
|
||||
from fixtures.common_types import Lsn, TenantId, TimelineArchivalState, TimelineId
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.metrics import (
|
||||
PAGESERVER_GLOBAL_METRICS,
|
||||
@@ -299,6 +299,65 @@ def test_pageserver_metrics_removed_after_detach(neon_env_builder: NeonEnvBuilde
|
||||
assert post_detach_samples == set()
|
||||
|
||||
|
||||
def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuilder):
|
||||
"""Tests that when a timeline is offloaded, the tenant specific metrics are not left behind"""
|
||||
|
||||
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3)
|
||||
|
||||
neon_env_builder.num_safekeepers = 3
|
||||
|
||||
env = neon_env_builder.init_start()
|
||||
tenant_1, _ = env.create_tenant()
|
||||
|
||||
timeline_1 = env.create_timeline("test_metrics_removed_after_offload_1", tenant_id=tenant_1)
|
||||
timeline_2 = env.create_timeline("test_metrics_removed_after_offload_2", tenant_id=tenant_1)
|
||||
|
||||
endpoint_tenant1 = env.endpoints.create_start(
|
||||
"test_metrics_removed_after_offload_1", tenant_id=tenant_1
|
||||
)
|
||||
endpoint_tenant2 = env.endpoints.create_start(
|
||||
"test_metrics_removed_after_offload_2", tenant_id=tenant_1
|
||||
)
|
||||
|
||||
for endpoint in [endpoint_tenant1, endpoint_tenant2]:
|
||||
with closing(endpoint.connect()) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("CREATE TABLE t(key int primary key, value text)")
|
||||
cur.execute("INSERT INTO t SELECT generate_series(1,100000), 'payload'")
|
||||
cur.execute("SELECT sum(key) FROM t")
|
||||
assert cur.fetchone() == (5000050000,)
|
||||
endpoint.stop()
|
||||
|
||||
def get_ps_metric_samples_for_timeline(
|
||||
tenant_id: TenantId, timeline_id: TimelineId
|
||||
) -> list[Sample]:
|
||||
ps_metrics = env.pageserver.http_client().get_metrics()
|
||||
samples = []
|
||||
for metric_name in ps_metrics.metrics:
|
||||
for sample in ps_metrics.query_all(
|
||||
name=metric_name,
|
||||
filter={"tenant_id": str(tenant_id), "timeline_id": str(timeline_id)},
|
||||
):
|
||||
samples.append(sample)
|
||||
return samples
|
||||
|
||||
for timeline in [timeline_1, timeline_2]:
|
||||
pre_offload_samples = set(
|
||||
[x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)]
|
||||
)
|
||||
assert len(pre_offload_samples) > 0, f"expected at least one sample for {timeline}"
|
||||
env.pageserver.http_client().timeline_archival_config(
|
||||
tenant_1,
|
||||
timeline,
|
||||
state=TimelineArchivalState.ARCHIVED,
|
||||
)
|
||||
env.pageserver.http_client().timeline_offload(tenant_1, timeline)
|
||||
post_offload_samples = set(
|
||||
[x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)]
|
||||
)
|
||||
assert post_offload_samples == set()
|
||||
|
||||
|
||||
def test_pageserver_with_empty_tenants(neon_env_builder: NeonEnvBuilder):
|
||||
env = neon_env_builder.init_start()
|
||||
|
||||
|
||||
@@ -433,6 +433,7 @@ def test_wal_backup(neon_env_builder: NeonEnvBuilder):
|
||||
env.pageserver.allowed_errors.extend(
|
||||
[
|
||||
".*Timeline .* was not found in global map.*",
|
||||
".*Timeline .* has been deleted.*",
|
||||
".*Timeline .* was cancelled and cannot be used anymore.*",
|
||||
]
|
||||
)
|
||||
@@ -1934,6 +1935,7 @@ def test_membership_api(neon_env_builder: NeonEnvBuilder):
|
||||
env.pageserver.allowed_errors.extend(
|
||||
[
|
||||
".*Timeline .* was not found in global map.*",
|
||||
".*Timeline .* has been deleted.*",
|
||||
".*Timeline .* was cancelled and cannot be used anymore.*",
|
||||
]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user