mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-29 08:10:38 +00:00
Compare commits
39 Commits
quantumish
...
diko/baseb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a57a85fc0 | ||
|
|
385324ee8a | ||
|
|
8a68d463f6 | ||
|
|
3046c307da | ||
|
|
e83f1d8ba5 | ||
|
|
8917676e86 | ||
|
|
43acabd4c2 | ||
|
|
db24ba95d1 | ||
|
|
1dce65308d | ||
|
|
ad88ec9257 | ||
|
|
60dfdf39c7 | ||
|
|
3d5e2bf685 | ||
|
|
54fdcfdfa8 | ||
|
|
28e882a80f | ||
|
|
24038033bf | ||
|
|
1b935b1958 | ||
|
|
3f16ca2c18 | ||
|
|
67b94c5992 | ||
|
|
e38193c530 | ||
|
|
21949137ed | ||
|
|
02f94edb60 | ||
|
|
58327ef74d | ||
|
|
73be6bb736 | ||
|
|
40d7583906 | ||
|
|
7a68699abb | ||
|
|
f42d44342d | ||
|
|
d759fcb8bd | ||
|
|
76f95f06d8 | ||
|
|
7efd4554ab | ||
|
|
3c7235669a | ||
|
|
6dd84041a1 | ||
|
|
df7e301a54 | ||
|
|
470c7d5e0e | ||
|
|
4d99b6ff4d | ||
|
|
590301df08 | ||
|
|
c511786548 | ||
|
|
fe31baf985 | ||
|
|
b23e75ebfe | ||
|
|
24d7c37e6e |
45
Cargo.lock
generated
45
Cargo.lock
generated
@@ -753,6 +753,7 @@ dependencies = [
|
||||
"axum",
|
||||
"axum-core",
|
||||
"bytes",
|
||||
"form_urlencoded",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"http 1.1.0",
|
||||
@@ -761,6 +762,8 @@ dependencies = [
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"serde",
|
||||
"serde_html_form",
|
||||
"serde_path_to_error",
|
||||
"tower 0.5.2",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@@ -900,12 +903,6 @@ version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.7"
|
||||
@@ -1297,7 +1294,7 @@ dependencies = [
|
||||
"aws-smithy-types",
|
||||
"axum",
|
||||
"axum-extra",
|
||||
"base64 0.13.1",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"camino",
|
||||
"cfg-if",
|
||||
@@ -1423,7 +1420,7 @@ name = "control_plane"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64 0.13.1",
|
||||
"base64 0.22.1",
|
||||
"camino",
|
||||
"clap",
|
||||
"comfy-table",
|
||||
@@ -1445,6 +1442,7 @@ dependencies = [
|
||||
"regex",
|
||||
"reqwest",
|
||||
"safekeeper_api",
|
||||
"safekeeper_client",
|
||||
"scopeguard",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -2054,6 +2052,7 @@ dependencies = [
|
||||
"axum-extra",
|
||||
"camino",
|
||||
"camino-tempfile",
|
||||
"clap",
|
||||
"futures",
|
||||
"http-body-util",
|
||||
"itertools 0.10.5",
|
||||
@@ -4813,7 +4812,7 @@ dependencies = [
|
||||
name = "postgres-protocol2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"base64 0.20.0",
|
||||
"base64 0.22.1",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
@@ -5185,7 +5184,7 @@ dependencies = [
|
||||
"aws-config",
|
||||
"aws-sdk-iam",
|
||||
"aws-sigv4",
|
||||
"base64 0.13.1",
|
||||
"base64 0.22.1",
|
||||
"bstr",
|
||||
"bytes",
|
||||
"camino",
|
||||
@@ -6420,6 +6419,19 @@ dependencies = [
|
||||
"syn 2.0.100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_html_form"
|
||||
version = "0.2.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"indexmap 2.9.0",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.125"
|
||||
@@ -6476,15 +6488,17 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_with"
|
||||
version = "2.3.3"
|
||||
version = "3.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07ff71d2c147a7b57362cead5e22f772cd52f6ab31cfcd9edcd7f6aeb2a0afbe"
|
||||
checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa"
|
||||
dependencies = [
|
||||
"base64 0.13.1",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"hex",
|
||||
"indexmap 1.9.3",
|
||||
"indexmap 2.9.0",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"serde_with_macros",
|
||||
"time",
|
||||
@@ -6492,9 +6506,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_with_macros"
|
||||
version = "2.3.3"
|
||||
version = "3.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "881b6f881b17d13214e5d494c939ebab463d01264ce1811e9d4ac3a882e7695f"
|
||||
checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"proc-macro2",
|
||||
@@ -8565,7 +8579,6 @@ dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"axum-core",
|
||||
"base64 0.13.1",
|
||||
"base64 0.21.7",
|
||||
"base64ct",
|
||||
"bytes",
|
||||
|
||||
@@ -71,8 +71,8 @@ aws-credential-types = "1.2.0"
|
||||
aws-sigv4 = { version = "1.2", features = ["sign-http"] }
|
||||
aws-types = "1.3"
|
||||
axum = { version = "0.8.1", features = ["ws"] }
|
||||
axum-extra = { version = "0.10.0", features = ["typed-header"] }
|
||||
base64 = "0.13.0"
|
||||
axum-extra = { version = "0.10.0", features = ["typed-header", "query"] }
|
||||
base64 = "0.22"
|
||||
bincode = "1.3"
|
||||
bindgen = "0.71"
|
||||
bit_field = "0.10.2"
|
||||
@@ -171,7 +171,7 @@ sentry = { version = "0.37", default-features = false, features = ["backtrace",
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
serde_path_to_error = "0.1"
|
||||
serde_with = { version = "2.0", features = [ "base64" ] }
|
||||
serde_with = { version = "3", features = [ "base64" ] }
|
||||
serde_assert = "0.5.0"
|
||||
sha2 = "0.10.2"
|
||||
signal-hook = "0.3"
|
||||
|
||||
13
Dockerfile
13
Dockerfile
@@ -110,6 +110,19 @@ RUN set -e \
|
||||
# System postgres for use with client libraries (e.g. in storage controller)
|
||||
postgresql-15 \
|
||||
openssl \
|
||||
unzip \
|
||||
curl \
|
||||
&& ARCH=$(uname -m) \
|
||||
&& if [ "$ARCH" = "x86_64" ]; then \
|
||||
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"; \
|
||||
elif [ "$ARCH" = "aarch64" ]; then \
|
||||
curl "https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip" -o "awscliv2.zip"; \
|
||||
else \
|
||||
echo "Unsupported architecture: $ARCH" && exit 1; \
|
||||
fi \
|
||||
&& unzip awscliv2.zip \
|
||||
&& ./aws/install \
|
||||
&& rm -rf aws awscliv2.zip \
|
||||
&& rm -f /etc/apt/apt.conf.d/80-retries \
|
||||
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
|
||||
&& useradd -d /data neon \
|
||||
|
||||
@@ -785,7 +785,7 @@ impl ComputeNode {
|
||||
self.spawn_extension_stats_task();
|
||||
|
||||
if pspec.spec.autoprewarm {
|
||||
self.prewarm_lfc();
|
||||
self.prewarm_lfc(None);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -25,11 +25,16 @@ struct EndpointStoragePair {
|
||||
}
|
||||
|
||||
const KEY: &str = "lfc_state";
|
||||
impl TryFrom<&crate::compute::ParsedSpec> for EndpointStoragePair {
|
||||
type Error = anyhow::Error;
|
||||
fn try_from(pspec: &crate::compute::ParsedSpec) -> Result<Self, Self::Error> {
|
||||
let Some(ref endpoint_id) = pspec.spec.endpoint_id else {
|
||||
bail!("pspec.endpoint_id missing")
|
||||
impl EndpointStoragePair {
|
||||
/// endpoint_id is set to None while prewarming from other endpoint, see replica promotion
|
||||
/// If not None, takes precedence over pspec.spec.endpoint_id
|
||||
fn from_spec_and_endpoint(
|
||||
pspec: &crate::compute::ParsedSpec,
|
||||
endpoint_id: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let endpoint_id = endpoint_id.as_ref().or(pspec.spec.endpoint_id.as_ref());
|
||||
let Some(ref endpoint_id) = endpoint_id else {
|
||||
bail!("pspec.endpoint_id missing, other endpoint_id not provided")
|
||||
};
|
||||
let Some(ref base_uri) = pspec.endpoint_storage_addr else {
|
||||
bail!("pspec.endpoint_storage_addr missing")
|
||||
@@ -84,7 +89,7 @@ impl ComputeNode {
|
||||
}
|
||||
|
||||
/// Returns false if there is a prewarm request ongoing, true otherwise
|
||||
pub fn prewarm_lfc(self: &Arc<Self>) -> bool {
|
||||
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
|
||||
crate::metrics::LFC_PREWARM_REQUESTS.inc();
|
||||
{
|
||||
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
|
||||
@@ -97,7 +102,7 @@ impl ComputeNode {
|
||||
|
||||
let cloned = self.clone();
|
||||
spawn(async move {
|
||||
let Err(err) = cloned.prewarm_impl().await else {
|
||||
let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
|
||||
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
|
||||
return;
|
||||
};
|
||||
@@ -109,13 +114,14 @@ impl ComputeNode {
|
||||
true
|
||||
}
|
||||
|
||||
fn endpoint_storage_pair(&self) -> Result<EndpointStoragePair> {
|
||||
/// from_endpoint: None for endpoint managed by this compute_ctl
|
||||
fn endpoint_storage_pair(&self, from_endpoint: Option<String>) -> Result<EndpointStoragePair> {
|
||||
let state = self.state.lock().unwrap();
|
||||
state.pspec.as_ref().unwrap().try_into()
|
||||
EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
|
||||
}
|
||||
|
||||
async fn prewarm_impl(&self) -> Result<()> {
|
||||
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
|
||||
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
|
||||
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
|
||||
info!(%url, "requesting LFC state from endpoint storage");
|
||||
|
||||
let request = Client::new().get(&url).bearer_auth(token);
|
||||
@@ -173,7 +179,7 @@ impl ComputeNode {
|
||||
}
|
||||
|
||||
async fn offload_lfc_impl(&self) -> Result<()> {
|
||||
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
|
||||
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
|
||||
info!(%url, "requesting LFC state from postgres");
|
||||
|
||||
let mut compressed = Vec::new();
|
||||
|
||||
@@ -2,6 +2,7 @@ use crate::compute_prewarm::LfcPrewarmStateWithProgress;
|
||||
use crate::http::JsonResponse;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::{Json, http::StatusCode};
|
||||
use axum_extra::extract::OptionalQuery;
|
||||
use compute_api::responses::LfcOffloadState;
|
||||
type Compute = axum::extract::State<std::sync::Arc<crate::compute::ComputeNode>>;
|
||||
|
||||
@@ -16,8 +17,16 @@ pub(in crate::http) async fn offload_state(compute: Compute) -> Json<LfcOffloadS
|
||||
Json(compute.lfc_offload_state())
|
||||
}
|
||||
|
||||
pub(in crate::http) async fn prewarm(compute: Compute) -> Response {
|
||||
if compute.prewarm_lfc() {
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct PrewarmQuery {
|
||||
pub from_endpoint: String,
|
||||
}
|
||||
|
||||
pub(in crate::http) async fn prewarm(
|
||||
compute: Compute,
|
||||
OptionalQuery(query): OptionalQuery<PrewarmQuery>,
|
||||
) -> Response {
|
||||
if compute.prewarm_lfc(query.map(|q| q.from_endpoint)) {
|
||||
StatusCode::ACCEPTED.into_response()
|
||||
} else {
|
||||
JsonResponse::error(
|
||||
|
||||
@@ -36,6 +36,7 @@ pageserver_api.workspace = true
|
||||
pageserver_client.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
safekeeper_api.workspace = true
|
||||
safekeeper_client.workspace = true
|
||||
postgres_connection.workspace = true
|
||||
storage_broker.workspace = true
|
||||
http-utils.workspace = true
|
||||
|
||||
@@ -45,7 +45,7 @@ use pageserver_api::models::{
|
||||
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId};
|
||||
use postgres_backend::AuthType;
|
||||
use postgres_connection::parse_host_port;
|
||||
use safekeeper_api::membership::SafekeeperGeneration;
|
||||
use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId};
|
||||
use safekeeper_api::{
|
||||
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
|
||||
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
|
||||
@@ -1255,6 +1255,45 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re
|
||||
pageserver
|
||||
.timeline_import(tenant_id, timeline_id, base, pg_wal, args.pg_version)
|
||||
.await?;
|
||||
if env.storage_controller.timelines_onto_safekeepers {
|
||||
println!("Creating timeline on safekeeper ...");
|
||||
let timeline_info = pageserver
|
||||
.timeline_info(
|
||||
TenantShardId::unsharded(tenant_id),
|
||||
timeline_id,
|
||||
pageserver_client::mgmt_api::ForceAwaitLogicalSize::No,
|
||||
)
|
||||
.await?;
|
||||
let default_sk = SafekeeperNode::from_env(env, env.safekeepers.first().unwrap());
|
||||
let default_host = default_sk
|
||||
.conf
|
||||
.listen_addr
|
||||
.clone()
|
||||
.unwrap_or_else(|| "localhost".to_string());
|
||||
let mconf = safekeeper_api::membership::Configuration {
|
||||
generation: SafekeeperGeneration::new(1),
|
||||
members: safekeeper_api::membership::MemberSet {
|
||||
m: vec![SafekeeperId {
|
||||
host: default_host,
|
||||
id: default_sk.conf.id,
|
||||
pg_port: default_sk.conf.pg_port,
|
||||
}],
|
||||
},
|
||||
new_members: None,
|
||||
};
|
||||
let pg_version = args.pg_version * 10000;
|
||||
let req = safekeeper_api::models::TimelineCreateRequest {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
mconf,
|
||||
pg_version,
|
||||
system_id: None,
|
||||
wal_seg_size: None,
|
||||
start_lsn: timeline_info.last_record_lsn,
|
||||
commit_lsn: None,
|
||||
};
|
||||
default_sk.create_timeline(&req).await?;
|
||||
}
|
||||
env.register_branch_mapping(branch_name.to_string(), tenant_id, timeline_id)?;
|
||||
println!("Done");
|
||||
}
|
||||
|
||||
@@ -45,6 +45,8 @@ use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use base64::Engine;
|
||||
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
|
||||
use compute_api::requests::{
|
||||
COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest,
|
||||
};
|
||||
@@ -164,7 +166,7 @@ impl ComputeControlPlane {
|
||||
public_key_use: Some(PublicKeyUse::Signature),
|
||||
key_operations: Some(vec![KeyOperations::Verify]),
|
||||
key_algorithm: Some(KeyAlgorithm::EdDSA),
|
||||
key_id: Some(base64::encode_config(key_hash, base64::URL_SAFE_NO_PAD)),
|
||||
key_id: Some(BASE64_URL_SAFE_NO_PAD.encode(key_hash)),
|
||||
x509_url: None::<String>,
|
||||
x509_chain: None::<Vec<String>>,
|
||||
x509_sha1_fingerprint: None::<String>,
|
||||
@@ -173,7 +175,7 @@ impl ComputeControlPlane {
|
||||
algorithm: AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters {
|
||||
key_type: OctetKeyPairType::OctetKeyPair,
|
||||
curve: EllipticCurve::Ed25519,
|
||||
x: base64::encode_config(public_key, base64::URL_SAFE_NO_PAD),
|
||||
x: BASE64_URL_SAFE_NO_PAD.encode(public_key),
|
||||
}),
|
||||
}],
|
||||
})
|
||||
|
||||
@@ -635,4 +635,16 @@ impl PageServerNode {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
pub async fn timeline_info(
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
timeline_id: TimelineId,
|
||||
force_await_logical_size: mgmt_api::ForceAwaitLogicalSize,
|
||||
) -> anyhow::Result<TimelineInfo> {
|
||||
let timeline_info = self
|
||||
.http_client
|
||||
.timeline_info(tenant_shard_id, timeline_id, force_await_logical_size)
|
||||
.await?;
|
||||
Ok(timeline_info)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
//! .neon/safekeepers/<safekeeper id>
|
||||
//! ```
|
||||
use std::error::Error as _;
|
||||
use std::future::Future;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
@@ -14,9 +13,9 @@ use std::{io, result};
|
||||
|
||||
use anyhow::Context;
|
||||
use camino::Utf8PathBuf;
|
||||
use http_utils::error::HttpErrorBody;
|
||||
use postgres_connection::PgConnectionConfig;
|
||||
use reqwest::{IntoUrl, Method};
|
||||
use safekeeper_api::models::TimelineCreateRequest;
|
||||
use safekeeper_client::mgmt_api;
|
||||
use thiserror::Error;
|
||||
use utils::auth::{Claims, Scope};
|
||||
use utils::id::NodeId;
|
||||
@@ -35,25 +34,14 @@ pub enum SafekeeperHttpError {
|
||||
|
||||
type Result<T> = result::Result<T, SafekeeperHttpError>;
|
||||
|
||||
pub(crate) trait ResponseErrorMessageExt: Sized {
|
||||
fn error_from_body(self) -> impl Future<Output = Result<Self>> + Send;
|
||||
}
|
||||
|
||||
impl ResponseErrorMessageExt for reqwest::Response {
|
||||
async fn error_from_body(self) -> Result<Self> {
|
||||
let status = self.status();
|
||||
if !(status.is_client_error() || status.is_server_error()) {
|
||||
return Ok(self);
|
||||
}
|
||||
|
||||
// reqwest does not export its error construction utility functions, so let's craft the message ourselves
|
||||
let url = self.url().to_owned();
|
||||
Err(SafekeeperHttpError::Response(
|
||||
match self.json::<HttpErrorBody>().await {
|
||||
Ok(err_body) => format!("Error: {}", err_body.msg),
|
||||
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
|
||||
},
|
||||
))
|
||||
fn err_from_client_err(err: mgmt_api::Error) -> SafekeeperHttpError {
|
||||
use mgmt_api::Error::*;
|
||||
match err {
|
||||
ApiError(_, str) => SafekeeperHttpError::Response(str),
|
||||
Cancelled => SafekeeperHttpError::Response("Cancelled".to_owned()),
|
||||
ReceiveBody(err) => SafekeeperHttpError::Transport(err),
|
||||
ReceiveErrorBody(err) => SafekeeperHttpError::Response(err),
|
||||
Timeout(str) => SafekeeperHttpError::Response(format!("timeout: {str}")),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,9 +58,8 @@ pub struct SafekeeperNode {
|
||||
|
||||
pub pg_connection_config: PgConnectionConfig,
|
||||
pub env: LocalEnv,
|
||||
pub http_client: reqwest::Client,
|
||||
pub http_client: mgmt_api::Client,
|
||||
pub listen_addr: String,
|
||||
pub http_base_url: String,
|
||||
}
|
||||
|
||||
impl SafekeeperNode {
|
||||
@@ -82,13 +69,14 @@ impl SafekeeperNode {
|
||||
} else {
|
||||
"127.0.0.1".to_string()
|
||||
};
|
||||
let jwt = None;
|
||||
let http_base_url = format!("http://{}:{}", listen_addr, conf.http_port);
|
||||
SafekeeperNode {
|
||||
id: conf.id,
|
||||
conf: conf.clone(),
|
||||
pg_connection_config: Self::safekeeper_connection_config(&listen_addr, conf.pg_port),
|
||||
env: env.clone(),
|
||||
http_client: env.create_http_client(),
|
||||
http_base_url: format!("http://{}:{}/v1", listen_addr, conf.http_port),
|
||||
http_client: mgmt_api::Client::new(env.create_http_client(), http_base_url, jwt),
|
||||
listen_addr,
|
||||
}
|
||||
}
|
||||
@@ -278,20 +266,19 @@ impl SafekeeperNode {
|
||||
)
|
||||
}
|
||||
|
||||
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> reqwest::RequestBuilder {
|
||||
// TODO: authentication
|
||||
//if self.env.auth_type == AuthType::NeonJWT {
|
||||
// builder = builder.bearer_auth(&self.env.safekeeper_auth_token)
|
||||
//}
|
||||
self.http_client.request(method, url)
|
||||
pub async fn check_status(&self) -> Result<()> {
|
||||
self.http_client
|
||||
.status()
|
||||
.await
|
||||
.map_err(err_from_client_err)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn check_status(&self) -> Result<()> {
|
||||
self.http_request(Method::GET, format!("{}/{}", self.http_base_url, "status"))
|
||||
.send()
|
||||
.await?
|
||||
.error_from_body()
|
||||
.await?;
|
||||
pub async fn create_timeline(&self, req: &TimelineCreateRequest) -> Result<()> {
|
||||
self.http_client
|
||||
.create_timeline(req)
|
||||
.await
|
||||
.map_err(err_from_client_err)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,10 +61,16 @@ enum Command {
|
||||
#[arg(long)]
|
||||
scheduling: Option<NodeSchedulingPolicy>,
|
||||
},
|
||||
// Set a node status as deleted.
|
||||
NodeDelete {
|
||||
#[arg(long)]
|
||||
node_id: NodeId,
|
||||
},
|
||||
/// Delete a tombstone of node from the storage controller.
|
||||
NodeDeleteTombstone {
|
||||
#[arg(long)]
|
||||
node_id: NodeId,
|
||||
},
|
||||
/// Modify a tenant's policies in the storage controller
|
||||
TenantPolicy {
|
||||
#[arg(long)]
|
||||
@@ -82,6 +88,8 @@ enum Command {
|
||||
},
|
||||
/// List nodes known to the storage controller
|
||||
Nodes {},
|
||||
/// List soft deleted nodes known to the storage controller
|
||||
NodeTombstones {},
|
||||
/// List tenants known to the storage controller
|
||||
Tenants {
|
||||
/// If this field is set, it will list the tenants on a specific node
|
||||
@@ -900,6 +908,39 @@ async fn main() -> anyhow::Result<()> {
|
||||
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
|
||||
.await?;
|
||||
}
|
||||
Command::NodeDeleteTombstone { node_id } => {
|
||||
storcon_client
|
||||
.dispatch::<(), ()>(
|
||||
Method::DELETE,
|
||||
format!("debug/v1/tombstone/{node_id}"),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Command::NodeTombstones {} => {
|
||||
let mut resp = storcon_client
|
||||
.dispatch::<(), Vec<NodeDescribeResponse>>(
|
||||
Method::GET,
|
||||
"debug/v1/tombstone".to_string(),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
resp.sort_by(|a, b| a.listen_http_addr.cmp(&b.listen_http_addr));
|
||||
|
||||
let mut table = comfy_table::Table::new();
|
||||
table.set_header(["Id", "Hostname", "AZ", "Scheduling", "Availability"]);
|
||||
for node in resp {
|
||||
table.add_row([
|
||||
format!("{}", node.id),
|
||||
node.listen_http_addr,
|
||||
node.availability_zone_id,
|
||||
format!("{:?}", node.scheduling),
|
||||
format!("{:?}", node.availability),
|
||||
]);
|
||||
}
|
||||
println!("{table}");
|
||||
}
|
||||
Command::TenantSetTimeBasedEviction {
|
||||
tenant_id,
|
||||
period,
|
||||
|
||||
@@ -8,6 +8,7 @@ anyhow.workspace = true
|
||||
axum-extra.workspace = true
|
||||
axum.workspace = true
|
||||
camino.workspace = true
|
||||
clap.workspace = true
|
||||
futures.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
prometheus.workspace = true
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
//! for large computes.
|
||||
mod app;
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use tracing::info;
|
||||
use utils::logging;
|
||||
|
||||
@@ -12,9 +14,26 @@ const fn max_upload_file_limit() -> usize {
|
||||
100 * 1024 * 1024
|
||||
}
|
||||
|
||||
const fn listen() -> SocketAddr {
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(exclusive = true)]
|
||||
config_file: Option<String>,
|
||||
#[arg(long, default_value = "false", requires = "config")]
|
||||
/// to allow testing k8s helm chart where we don't have s3 credentials
|
||||
no_s3_check_on_startup: bool,
|
||||
#[arg(long, value_name = "FILE")]
|
||||
/// inline config mode for k8s helm chart
|
||||
config: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
struct Config {
|
||||
#[serde(default = "listen")]
|
||||
listen: std::net::SocketAddr,
|
||||
pemfile: camino::Utf8PathBuf,
|
||||
#[serde(flatten)]
|
||||
@@ -31,13 +50,18 @@ async fn main() -> anyhow::Result<()> {
|
||||
logging::Output::Stdout,
|
||||
)?;
|
||||
|
||||
let config: String = std::env::args().skip(1).take(1).collect();
|
||||
if config.is_empty() {
|
||||
anyhow::bail!("Usage: endpoint_storage config.json")
|
||||
}
|
||||
info!("Reading config from {config}");
|
||||
let config = std::fs::read_to_string(config.clone())?;
|
||||
let config: Config = serde_json::from_str(&config).context("parsing config")?;
|
||||
let args = Args::parse();
|
||||
let config: Config = if let Some(config_path) = args.config_file {
|
||||
info!("Reading config from {config_path}");
|
||||
let config = std::fs::read_to_string(config_path)?;
|
||||
serde_json::from_str(&config).context("parsing config")?
|
||||
} else if let Some(config) = args.config {
|
||||
info!("Reading inline config");
|
||||
serde_json::from_str(&config).context("parsing config")?
|
||||
} else {
|
||||
anyhow::bail!("Supply either config file path or --config=inline-config");
|
||||
};
|
||||
|
||||
info!("Reading pemfile from {}", config.pemfile.clone());
|
||||
let pemfile = std::fs::read(config.pemfile.clone())?;
|
||||
info!("Loading public key from {}", config.pemfile.clone());
|
||||
@@ -48,7 +72,9 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let storage = remote_storage::GenericRemoteStorage::from_config(&config.storage_config).await?;
|
||||
let cancel = tokio_util::sync::CancellationToken::new();
|
||||
app::check_storage_permissions(&storage, cancel.clone()).await?;
|
||||
if !args.no_s3_check_on_startup {
|
||||
app::check_storage_permissions(&storage, cancel.clone()).await?;
|
||||
}
|
||||
|
||||
let proxy = std::sync::Arc::new(endpoint_storage::Storage {
|
||||
auth,
|
||||
|
||||
@@ -344,6 +344,35 @@ impl Default for ShardSchedulingPolicy {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
|
||||
pub enum NodeLifecycle {
|
||||
Active,
|
||||
Deleted,
|
||||
}
|
||||
|
||||
impl FromStr for NodeLifecycle {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"active" => Ok(Self::Active),
|
||||
"deleted" => Ok(Self::Deleted),
|
||||
_ => Err(anyhow::anyhow!("Unknown node lifecycle '{s}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NodeLifecycle> for String {
|
||||
fn from(value: NodeLifecycle) -> String {
|
||||
use NodeLifecycle::*;
|
||||
match value {
|
||||
Active => "active",
|
||||
Deleted => "deleted",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
|
||||
pub enum NodeSchedulingPolicy {
|
||||
Active,
|
||||
|
||||
@@ -9,7 +9,7 @@ use utils::id::{NodeId, TimelineId};
|
||||
|
||||
use crate::controller_api::NodeRegisterRequest;
|
||||
use crate::models::{LocationConfigMode, ShardImportStatus};
|
||||
use crate::shard::TenantShardId;
|
||||
use crate::shard::{ShardStripeSize, TenantShardId};
|
||||
|
||||
/// Upcall message sent by the pageserver to the configured `control_plane_api` on
|
||||
/// startup.
|
||||
@@ -36,6 +36,10 @@ pub struct ReAttachResponseTenant {
|
||||
/// Default value only for backward compat: this field should be set
|
||||
#[serde(default = "default_mode")]
|
||||
pub mode: LocationConfigMode,
|
||||
|
||||
// Default value only for backward compat: this field should be set
|
||||
#[serde(default = "ShardStripeSize::default")]
|
||||
pub stripe_size: ShardStripeSize,
|
||||
}
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ReAttachResponse {
|
||||
|
||||
@@ -55,9 +55,16 @@ impl FeatureResolverBackgroundLoop {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let feature_store = FeatureStore::new_with_flags(resp.flags);
|
||||
this.feature_store.store(Arc::new(feature_store));
|
||||
tracing::info!("Feature flag updated");
|
||||
let project_id = this.posthog_client.config.project_id.parse::<u64>().ok();
|
||||
match FeatureStore::new_with_flags(resp.flags, project_id) {
|
||||
Ok(feature_store) => {
|
||||
this.feature_store.store(Arc::new(feature_store));
|
||||
tracing::info!("Feature flag updated");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Cannot process feature flag spec: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::info!("PostHog feature resolver stopped");
|
||||
}
|
||||
|
||||
@@ -39,6 +39,9 @@ pub struct LocalEvaluationResponse {
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct LocalEvaluationFlag {
|
||||
#[allow(dead_code)]
|
||||
id: u64,
|
||||
team_id: u64,
|
||||
key: String,
|
||||
filters: LocalEvaluationFlagFilters,
|
||||
active: bool,
|
||||
@@ -107,17 +110,32 @@ impl FeatureStore {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_flags(flags: Vec<LocalEvaluationFlag>) -> Self {
|
||||
pub fn new_with_flags(
|
||||
flags: Vec<LocalEvaluationFlag>,
|
||||
project_id: Option<u64>,
|
||||
) -> Result<Self, &'static str> {
|
||||
let mut store = Self::new();
|
||||
store.set_flags(flags);
|
||||
store
|
||||
store.set_flags(flags, project_id)?;
|
||||
Ok(store)
|
||||
}
|
||||
|
||||
pub fn set_flags(&mut self, flags: Vec<LocalEvaluationFlag>) {
|
||||
pub fn set_flags(
|
||||
&mut self,
|
||||
flags: Vec<LocalEvaluationFlag>,
|
||||
project_id: Option<u64>,
|
||||
) -> Result<(), &'static str> {
|
||||
self.flags.clear();
|
||||
for flag in flags {
|
||||
if let Some(project_id) = project_id {
|
||||
if flag.team_id != project_id {
|
||||
return Err(
|
||||
"Retrieved a spec with different project id, wrong config? Discarding the feature flags.",
|
||||
);
|
||||
}
|
||||
}
|
||||
self.flags.insert(flag.key.clone(), flag);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate a consistent hash for a user ID (e.g., tenant ID).
|
||||
@@ -534,6 +552,13 @@ impl PostHogClient {
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if the server API key is a feature flag secure API key. This key can only be
|
||||
/// used to fetch the feature flag specs and can only be used on a undocumented API
|
||||
/// endpoint.
|
||||
fn is_feature_flag_secure_api_key(&self) -> bool {
|
||||
self.config.server_api_key.starts_with("phs_")
|
||||
}
|
||||
|
||||
/// Fetch the feature flag specs from the server.
|
||||
///
|
||||
/// This is unfortunately an undocumented API at:
|
||||
@@ -547,10 +572,22 @@ impl PostHogClient {
|
||||
) -> anyhow::Result<LocalEvaluationResponse> {
|
||||
// BASE_URL/api/projects/:project_id/feature_flags/local_evaluation
|
||||
// with bearer token of self.server_api_key
|
||||
let url = format!(
|
||||
"{}/api/projects/{}/feature_flags/local_evaluation",
|
||||
self.config.private_api_url, self.config.project_id
|
||||
);
|
||||
// OR
|
||||
// BASE_URL/api/feature_flag/local_evaluation/
|
||||
// with bearer token of feature flag specific self.server_api_key
|
||||
let url = if self.is_feature_flag_secure_api_key() {
|
||||
// The new feature local evaluation secure API token
|
||||
format!(
|
||||
"{}/api/feature_flag/local_evaluation",
|
||||
self.config.private_api_url
|
||||
)
|
||||
} else {
|
||||
// The old personal API token
|
||||
format!(
|
||||
"{}/api/projects/{}/feature_flags/local_evaluation",
|
||||
self.config.private_api_url, self.config.project_id
|
||||
)
|
||||
};
|
||||
let response = self
|
||||
.client
|
||||
.get(url)
|
||||
@@ -803,7 +840,7 @@ mod tests {
|
||||
fn evaluate_multivariate() {
|
||||
let mut store = FeatureStore::new();
|
||||
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
|
||||
store.set_flags(response.flags);
|
||||
store.set_flags(response.flags, None).unwrap();
|
||||
|
||||
// This lacks the required properties and cannot be evaluated.
|
||||
let variant =
|
||||
@@ -873,7 +910,7 @@ mod tests {
|
||||
|
||||
let mut store = FeatureStore::new();
|
||||
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
|
||||
store.set_flags(response.flags);
|
||||
store.set_flags(response.flags, None).unwrap();
|
||||
|
||||
// This lacks the required properties and cannot be evaluated.
|
||||
let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &HashMap::new());
|
||||
@@ -929,7 +966,7 @@ mod tests {
|
||||
|
||||
let mut store = FeatureStore::new();
|
||||
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
|
||||
store.set_flags(response.flags);
|
||||
store.set_flags(response.flags, None).unwrap();
|
||||
|
||||
// This lacks the required properties and cannot be evaluated.
|
||||
let variant =
|
||||
|
||||
@@ -5,7 +5,7 @@ edition = "2024"
|
||||
license = "MIT/Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
base64 = "0.20"
|
||||
base64.workspace = true
|
||||
byteorder.workspace = true
|
||||
bytes.workspace = true
|
||||
fallible-iterator.workspace = true
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
use std::fmt::Write;
|
||||
use std::{io, iter, mem, str};
|
||||
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use hmac::{Hmac, Mac};
|
||||
use rand::{self, Rng};
|
||||
use sha2::digest::FixedOutput;
|
||||
@@ -226,7 +228,7 @@ impl ScramSha256 {
|
||||
|
||||
let (client_key, server_key) = match password {
|
||||
Credentials::Password(password) => {
|
||||
let salt = match base64::decode(parsed.salt) {
|
||||
let salt = match BASE64_STANDARD.decode(parsed.salt) {
|
||||
Ok(salt) => salt,
|
||||
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
|
||||
};
|
||||
@@ -255,7 +257,7 @@ impl ScramSha256 {
|
||||
let mut cbind_input = vec![];
|
||||
cbind_input.extend(channel_binding.gs2_header().as_bytes());
|
||||
cbind_input.extend(channel_binding.cbind_data());
|
||||
let cbind_input = base64::encode(&cbind_input);
|
||||
let cbind_input = BASE64_STANDARD.encode(&cbind_input);
|
||||
|
||||
self.message.clear();
|
||||
write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
|
||||
@@ -272,7 +274,12 @@ impl ScramSha256 {
|
||||
*proof ^= signature;
|
||||
}
|
||||
|
||||
write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
|
||||
write!(
|
||||
&mut self.message,
|
||||
",p={}",
|
||||
BASE64_STANDARD.encode(client_proof)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
self.state = State::Finish {
|
||||
server_key,
|
||||
@@ -306,7 +313,7 @@ impl ScramSha256 {
|
||||
ServerFinalMessage::Verifier(verifier) => verifier,
|
||||
};
|
||||
|
||||
let verifier = match base64::decode(verifier) {
|
||||
let verifier = match BASE64_STANDARD.decode(verifier) {
|
||||
Ok(verifier) => verifier,
|
||||
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
|
||||
};
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
//! side. This is good because it ensures the cleartext password won't
|
||||
//! end up in logs pg_stat displays, etc.
|
||||
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use hmac::{Hmac, Mac};
|
||||
use rand::RngCore;
|
||||
use sha2::digest::FixedOutput;
|
||||
@@ -83,8 +85,8 @@ pub(crate) async fn scram_sha_256_salt(
|
||||
format!(
|
||||
"SCRAM-SHA-256${}:{}${}:{}",
|
||||
SCRAM_DEFAULT_ITERATIONS,
|
||||
base64::encode(salt),
|
||||
base64::encode(stored_key),
|
||||
base64::encode(server_key)
|
||||
BASE64_STANDARD.encode(salt),
|
||||
BASE64_STANDARD.encode(stored_key),
|
||||
BASE64_STANDARD.encode(server_key)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ use crate::{Error, cancel_query_raw, connect_socket};
|
||||
pub(crate) async fn cancel_query<T>(
|
||||
config: Option<SocketConfig>,
|
||||
ssl_mode: SslMode,
|
||||
mut tls: T,
|
||||
tls: T,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
) -> Result<(), Error>
|
||||
|
||||
@@ -17,7 +17,6 @@ use crate::{Client, Connection, Error};
|
||||
|
||||
/// TLS configuration.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum SslMode {
|
||||
/// Do not use TLS.
|
||||
Disable,
|
||||
@@ -231,7 +230,7 @@ impl Config {
|
||||
/// Requires the `runtime` Cargo feature (enabled by default).
|
||||
pub async fn connect<T>(
|
||||
&self,
|
||||
tls: T,
|
||||
tls: &T,
|
||||
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
|
||||
where
|
||||
T: MakeTlsConnect<TcpStream>,
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::tls::{MakeTlsConnect, TlsConnect};
|
||||
use crate::{Client, Config, Connection, Error, RawConnection};
|
||||
|
||||
pub async fn connect<T>(
|
||||
mut tls: T,
|
||||
tls: &T,
|
||||
config: &Config,
|
||||
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
|
||||
where
|
||||
|
||||
@@ -47,7 +47,7 @@ pub trait MakeTlsConnect<S> {
|
||||
/// Creates a new `TlsConnect`or.
|
||||
///
|
||||
/// The domain name is provided for certificate verification and SNI.
|
||||
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
|
||||
fn make_tls_connect(&self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
|
||||
}
|
||||
|
||||
/// An asynchronous function wrapping a stream in a TLS session.
|
||||
@@ -85,7 +85,7 @@ impl<S> MakeTlsConnect<S> for NoTls {
|
||||
type TlsConnect = NoTls;
|
||||
type Error = NoTlsError;
|
||||
|
||||
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
|
||||
fn make_tls_connect(&self, _: &str) -> Result<NoTls, NoTlsError> {
|
||||
Ok(NoTls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use std::{env, io};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
|
||||
use azure_core::{Continuable, HttpClient, RetryOptions, TransportOptions};
|
||||
use azure_storage::StorageCredentials;
|
||||
@@ -37,6 +37,7 @@ use crate::metrics::{AttemptOutcome, RequestKind, start_measuring_requests};
|
||||
use crate::{
|
||||
ConcurrencyLimiter, Download, DownloadError, DownloadKind, DownloadOpts, Listing, ListingMode,
|
||||
ListingObject, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
|
||||
Version, VersionKind,
|
||||
};
|
||||
|
||||
pub struct AzureBlobStorage {
|
||||
@@ -405,6 +406,39 @@ impl AzureBlobStorage {
|
||||
pub fn container_name(&self) -> &str {
|
||||
&self.container_name
|
||||
}
|
||||
|
||||
async fn list_versions_with_permit(
|
||||
&self,
|
||||
_permit: &tokio::sync::SemaphorePermit<'_>,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
max_keys: Option<NonZeroU32>,
|
||||
cancel: &CancellationToken,
|
||||
) -> Result<crate::VersionListing, DownloadError> {
|
||||
let customize_builder = |mut builder: ListBlobsBuilder| {
|
||||
builder = builder.include_versions(true);
|
||||
// We do not return this info back to `VersionListing` yet.
|
||||
builder = builder.include_deleted(true);
|
||||
builder
|
||||
};
|
||||
let kind = RequestKind::ListVersions;
|
||||
|
||||
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
|
||||
prefix,
|
||||
mode,
|
||||
max_keys,
|
||||
cancel,
|
||||
kind,
|
||||
customize_builder
|
||||
));
|
||||
let mut combined: crate::VersionListing =
|
||||
stream.next().await.expect("At least one item required")?;
|
||||
while let Some(list) = stream.next().await {
|
||||
let list = list?;
|
||||
combined.versions.extend(list.versions.into_iter());
|
||||
}
|
||||
Ok(combined)
|
||||
}
|
||||
}
|
||||
|
||||
trait ListingCollector {
|
||||
@@ -488,27 +522,10 @@ impl RemoteStorage for AzureBlobStorage {
|
||||
max_keys: Option<NonZeroU32>,
|
||||
cancel: &CancellationToken,
|
||||
) -> std::result::Result<crate::VersionListing, DownloadError> {
|
||||
let customize_builder = |mut builder: ListBlobsBuilder| {
|
||||
builder = builder.include_versions(true);
|
||||
builder
|
||||
};
|
||||
let kind = RequestKind::ListVersions;
|
||||
|
||||
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
|
||||
prefix,
|
||||
mode,
|
||||
max_keys,
|
||||
cancel,
|
||||
kind,
|
||||
customize_builder
|
||||
));
|
||||
let mut combined: crate::VersionListing =
|
||||
stream.next().await.expect("At least one item required")?;
|
||||
while let Some(list) = stream.next().await {
|
||||
let list = list?;
|
||||
combined.versions.extend(list.versions.into_iter());
|
||||
}
|
||||
Ok(combined)
|
||||
let permit = self.permit(kind, cancel).await?;
|
||||
self.list_versions_with_permit(&permit, prefix, mode, max_keys, cancel)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn head_object(
|
||||
@@ -803,14 +820,159 @@ impl RemoteStorage for AzureBlobStorage {
|
||||
|
||||
async fn time_travel_recover(
|
||||
&self,
|
||||
_prefix: Option<&RemotePath>,
|
||||
_timestamp: SystemTime,
|
||||
_done_if_after: SystemTime,
|
||||
_cancel: &CancellationToken,
|
||||
prefix: Option<&RemotePath>,
|
||||
timestamp: SystemTime,
|
||||
done_if_after: SystemTime,
|
||||
cancel: &CancellationToken,
|
||||
_complexity_limit: Option<NonZeroU32>,
|
||||
) -> Result<(), TimeTravelError> {
|
||||
// TODO use Azure point in time recovery feature for this
|
||||
// https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
|
||||
Err(TimeTravelError::Unimplemented)
|
||||
let msg = "PLEASE NOTE: Azure Blob storage time-travel recovery may not work as expected "
|
||||
.to_string()
|
||||
+ "for some specific files. If a file gets deleted but then overwritten and we want to recover "
|
||||
+ "to the time during the file was not present, this functionality will recover the file. Only "
|
||||
+ "use the functionality for services that can tolerate this. For example, recovering a state of the "
|
||||
+ "pageserver tenants.";
|
||||
tracing::error!("{}", msg);
|
||||
|
||||
let kind = RequestKind::TimeTravel;
|
||||
let permit = self.permit(kind, cancel).await?;
|
||||
|
||||
let mode = ListingMode::NoDelimiter;
|
||||
let version_listing = self
|
||||
.list_versions_with_permit(&permit, prefix, mode, None, cancel)
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
DownloadError::Other(e) => TimeTravelError::Other(e),
|
||||
DownloadError::Cancelled => TimeTravelError::Cancelled,
|
||||
other => TimeTravelError::Other(other.into()),
|
||||
})?;
|
||||
let versions_and_deletes = version_listing.versions;
|
||||
|
||||
tracing::info!(
|
||||
"Built list for time travel with {} versions and deletions",
|
||||
versions_and_deletes.len()
|
||||
);
|
||||
|
||||
// Work on the list of references instead of the objects directly,
|
||||
// otherwise we get lifetime errors in the sort_by_key call below.
|
||||
let mut versions_and_deletes = versions_and_deletes.iter().collect::<Vec<_>>();
|
||||
|
||||
versions_and_deletes.sort_by_key(|vd| (&vd.key, &vd.last_modified));
|
||||
|
||||
let mut vds_for_key = HashMap::<_, Vec<_>>::new();
|
||||
|
||||
for vd in &versions_and_deletes {
|
||||
let Version { key, .. } = &vd;
|
||||
let version_id = vd.version_id().map(|v| v.0.as_str());
|
||||
if version_id == Some("null") {
|
||||
return Err(TimeTravelError::Other(anyhow!(
|
||||
"Received ListVersions response for key={key} with version_id='null', \
|
||||
indicating either disabled versioning, or legacy objects with null version id values"
|
||||
)));
|
||||
}
|
||||
tracing::trace!("Parsing version key={key} kind={:?}", vd.kind);
|
||||
|
||||
vds_for_key.entry(key).or_default().push(vd);
|
||||
}
|
||||
|
||||
let warn_threshold = 3;
|
||||
let max_retries = 10;
|
||||
let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled);
|
||||
|
||||
for (key, versions) in vds_for_key {
|
||||
let last_vd = versions.last().unwrap();
|
||||
let key = self.relative_path_to_name(key);
|
||||
if last_vd.last_modified > done_if_after {
|
||||
tracing::debug!("Key {key} has version later than done_if_after, skipping");
|
||||
continue;
|
||||
}
|
||||
// the version we want to restore to.
|
||||
let version_to_restore_to =
|
||||
match versions.binary_search_by_key(×tamp, |tpl| tpl.last_modified) {
|
||||
Ok(v) => v,
|
||||
Err(e) => e,
|
||||
};
|
||||
if version_to_restore_to == versions.len() {
|
||||
tracing::debug!("Key {key} has no changes since timestamp, skipping");
|
||||
continue;
|
||||
}
|
||||
let mut do_delete = false;
|
||||
if version_to_restore_to == 0 {
|
||||
// All versions more recent, so the key didn't exist at the specified time point.
|
||||
tracing::debug!(
|
||||
"All {} versions more recent for {key}, deleting",
|
||||
versions.len()
|
||||
);
|
||||
do_delete = true;
|
||||
} else {
|
||||
match &versions[version_to_restore_to - 1] {
|
||||
Version {
|
||||
kind: VersionKind::Version(version_id),
|
||||
..
|
||||
} => {
|
||||
let source_url = format!(
|
||||
"{}/{}?versionid={}",
|
||||
self.client
|
||||
.url()
|
||||
.map_err(|e| TimeTravelError::Other(anyhow!("{e}")))?,
|
||||
key,
|
||||
version_id.0
|
||||
);
|
||||
tracing::debug!(
|
||||
"Promoting old version {} for {key} at {}...",
|
||||
version_id.0,
|
||||
source_url
|
||||
);
|
||||
backoff::retry(
|
||||
|| async {
|
||||
let blob_client = self.client.blob_client(key.clone());
|
||||
let op = blob_client.copy(Url::from_str(&source_url).unwrap());
|
||||
tokio::select! {
|
||||
res = op => res.map_err(|e| TimeTravelError::Other(e.into())),
|
||||
_ = cancel.cancelled() => Err(TimeTravelError::Cancelled),
|
||||
}
|
||||
},
|
||||
is_permanent,
|
||||
warn_threshold,
|
||||
max_retries,
|
||||
"copying object version for time_travel_recover",
|
||||
cancel,
|
||||
)
|
||||
.await
|
||||
.ok_or_else(|| TimeTravelError::Cancelled)
|
||||
.and_then(|x| x)?;
|
||||
tracing::info!(?version_id, %key, "Copied old version in Azure blob storage");
|
||||
}
|
||||
Version {
|
||||
kind: VersionKind::DeletionMarker,
|
||||
..
|
||||
} => {
|
||||
do_delete = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
if do_delete {
|
||||
if matches!(last_vd.kind, VersionKind::DeletionMarker) {
|
||||
// Key has since been deleted (but there was some history), no need to do anything
|
||||
tracing::debug!("Key {key} already deleted, skipping.");
|
||||
} else {
|
||||
tracing::debug!("Deleting {key}...");
|
||||
|
||||
self.delete(&RemotePath::from_string(&key).unwrap(), cancel)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
// delete_oid0 will use TimeoutOrCancel
|
||||
if TimeoutOrCancel::caused_by_cancel(&e) {
|
||||
TimeTravelError::Cancelled
|
||||
} else {
|
||||
TimeTravelError::Other(e)
|
||||
}
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -440,6 +440,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
|
||||
timestamp: SystemTime,
|
||||
done_if_after: SystemTime,
|
||||
cancel: &CancellationToken,
|
||||
complexity_limit: Option<NonZeroU32>,
|
||||
) -> Result<(), TimeTravelError>;
|
||||
}
|
||||
|
||||
@@ -651,22 +652,23 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
|
||||
timestamp: SystemTime,
|
||||
done_if_after: SystemTime,
|
||||
cancel: &CancellationToken,
|
||||
complexity_limit: Option<NonZeroU32>,
|
||||
) -> Result<(), TimeTravelError> {
|
||||
match self {
|
||||
Self::LocalFs(s) => {
|
||||
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
|
||||
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
|
||||
.await
|
||||
}
|
||||
Self::AwsS3(s) => {
|
||||
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
|
||||
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
|
||||
.await
|
||||
}
|
||||
Self::AzureBlob(s) => {
|
||||
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
|
||||
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
|
||||
.await
|
||||
}
|
||||
Self::Unreliable(s) => {
|
||||
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
|
||||
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -610,6 +610,7 @@ impl RemoteStorage for LocalFs {
|
||||
_timestamp: SystemTime,
|
||||
_done_if_after: SystemTime,
|
||||
_cancel: &CancellationToken,
|
||||
_complexity_limit: Option<NonZeroU32>,
|
||||
) -> Result<(), TimeTravelError> {
|
||||
Err(TimeTravelError::Unimplemented)
|
||||
}
|
||||
|
||||
@@ -981,22 +981,16 @@ impl RemoteStorage for S3Bucket {
|
||||
timestamp: SystemTime,
|
||||
done_if_after: SystemTime,
|
||||
cancel: &CancellationToken,
|
||||
complexity_limit: Option<NonZeroU32>,
|
||||
) -> Result<(), TimeTravelError> {
|
||||
let kind = RequestKind::TimeTravel;
|
||||
let permit = self.permit(kind, cancel).await?;
|
||||
|
||||
tracing::trace!("Target time: {timestamp:?}, done_if_after {done_if_after:?}");
|
||||
|
||||
// Limit the number of versions deletions, mostly so that we don't
|
||||
// keep requesting forever if the list is too long, as we'd put the
|
||||
// list in RAM.
|
||||
// Building a list of 100k entries that reaches the limit roughly takes
|
||||
// 40 seconds, and roughly corresponds to tenants of 2 TiB physical size.
|
||||
const COMPLEXITY_LIMIT: Option<NonZeroU32> = NonZeroU32::new(100_000);
|
||||
|
||||
let mode = ListingMode::NoDelimiter;
|
||||
let version_listing = self
|
||||
.list_versions_with_permit(&permit, prefix, mode, COMPLEXITY_LIMIT, cancel)
|
||||
.list_versions_with_permit(&permit, prefix, mode, complexity_limit, cancel)
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
DownloadError::Other(e) => TimeTravelError::Other(e),
|
||||
@@ -1022,6 +1016,7 @@ impl RemoteStorage for S3Bucket {
|
||||
let Version { key, .. } = &vd;
|
||||
let version_id = vd.version_id().map(|v| v.0.as_str());
|
||||
if version_id == Some("null") {
|
||||
// TODO: check the behavior of using the SDK on a non-versioned container
|
||||
return Err(TimeTravelError::Other(anyhow!(
|
||||
"Received ListVersions response for key={key} with version_id='null', \
|
||||
indicating either disabled versioning, or legacy objects with null version id values"
|
||||
|
||||
@@ -240,11 +240,12 @@ impl RemoteStorage for UnreliableWrapper {
|
||||
timestamp: SystemTime,
|
||||
done_if_after: SystemTime,
|
||||
cancel: &CancellationToken,
|
||||
complexity_limit: Option<NonZeroU32>,
|
||||
) -> Result<(), TimeTravelError> {
|
||||
self.attempt(RemoteOp::TimeTravelRecover(prefix.map(|p| p.to_owned())))
|
||||
.map_err(TimeTravelError::Other)?;
|
||||
self.inner
|
||||
.time_travel_recover(prefix, timestamp, done_if_after, cancel)
|
||||
.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
|
||||
// No changes after recovery to t2 (no-op)
|
||||
let t_final = time_point().await;
|
||||
ctx.client
|
||||
.time_travel_recover(None, t2, t_final, &cancel)
|
||||
.time_travel_recover(None, t2, t_final, &cancel, None)
|
||||
.await?;
|
||||
let t2_files_recovered = list_files(&ctx.client, &cancel).await?;
|
||||
println!("after recovery to t2: {t2_files_recovered:?}");
|
||||
@@ -173,7 +173,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
|
||||
// after recovery to t1: path1 is back, path2 has the old content
|
||||
let t_final = time_point().await;
|
||||
ctx.client
|
||||
.time_travel_recover(None, t1, t_final, &cancel)
|
||||
.time_travel_recover(None, t1, t_final, &cancel, None)
|
||||
.await?;
|
||||
let t1_files_recovered = list_files(&ctx.client, &cancel).await?;
|
||||
println!("after recovery to t1: {t1_files_recovered:?}");
|
||||
@@ -189,7 +189,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
|
||||
// after recovery to t0: everything is gone except for path1
|
||||
let t_final = time_point().await;
|
||||
ctx.client
|
||||
.time_travel_recover(None, t0, t_final, &cancel)
|
||||
.time_travel_recover(None, t0, t_final, &cancel, None)
|
||||
.await?;
|
||||
let t0_files_recovered = list_files(&ctx.client, &cancel).await?;
|
||||
println!("after recovery to t0: {t0_files_recovered:?}");
|
||||
|
||||
@@ -13,7 +13,7 @@ use utils::pageserver_feedback::PageserverFeedback;
|
||||
use crate::membership::Configuration;
|
||||
use crate::{ServerInfo, Term};
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SafekeeperStatus {
|
||||
pub id: NodeId,
|
||||
}
|
||||
|
||||
@@ -176,9 +176,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
let config = RemoteStorageConfig::from_toml_str(&cmd.config_toml_str)?;
|
||||
let storage = remote_storage::GenericRemoteStorage::from_config(&config).await;
|
||||
let cancel = CancellationToken::new();
|
||||
// Complexity limit: as we are running this command locally, we should have a lot of memory available, and we do not
|
||||
// need to limit the number of versions we are going to delete.
|
||||
storage
|
||||
.unwrap()
|
||||
.time_travel_recover(Some(&prefix), timestamp, done_if_after, &cancel)
|
||||
.time_travel_recover(Some(&prefix), timestamp, done_if_after, &cancel, None)
|
||||
.await?;
|
||||
}
|
||||
Commands::Key(dkc) => dkc.execute(),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_compression::tokio::write::GzipEncoder;
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use metrics::core::{AtomicU64, GenericCounter};
|
||||
@@ -167,14 +168,17 @@ impl BasebackupCache {
|
||||
.join(Self::entry_filename(tenant_id, timeline_id, lsn))
|
||||
}
|
||||
|
||||
fn tmp_dir(&self) -> Utf8PathBuf {
|
||||
self.data_dir.join("tmp")
|
||||
}
|
||||
|
||||
fn entry_tmp_path(
|
||||
&self,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
lsn: Lsn,
|
||||
) -> Utf8PathBuf {
|
||||
self.data_dir
|
||||
.join("tmp")
|
||||
self.tmp_dir()
|
||||
.join(Self::entry_filename(tenant_id, timeline_id, lsn))
|
||||
}
|
||||
|
||||
@@ -194,15 +198,18 @@ impl BasebackupCache {
|
||||
Some((tenant_id, timeline_id, lsn))
|
||||
}
|
||||
|
||||
async fn cleanup(&self) -> anyhow::Result<()> {
|
||||
// Cleanup tmp directory.
|
||||
let tmp_dir = self.data_dir.join("tmp");
|
||||
let mut tmp_dir = tokio::fs::read_dir(&tmp_dir).await?;
|
||||
while let Some(dir_entry) = tmp_dir.next_entry().await? {
|
||||
if let Err(e) = tokio::fs::remove_file(dir_entry.path()).await {
|
||||
tracing::warn!("Failed to remove basebackup cache tmp file: {:#}", e);
|
||||
}
|
||||
// Recreate the tmp directory to clear all files in it.
|
||||
async fn clean_tmp_dir(&self) -> anyhow::Result<()> {
|
||||
let tmp_dir = self.tmp_dir();
|
||||
if tmp_dir.exists() {
|
||||
tokio::fs::remove_dir_all(&tmp_dir).await?;
|
||||
}
|
||||
tokio::fs::create_dir_all(&tmp_dir).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cleanup(&self) -> anyhow::Result<()> {
|
||||
self.clean_tmp_dir().await?;
|
||||
|
||||
// Remove outdated entries.
|
||||
let entries_old = self.entries.lock().unwrap().clone();
|
||||
@@ -241,16 +248,14 @@ impl BasebackupCache {
|
||||
}
|
||||
|
||||
async fn on_startup(&self) -> anyhow::Result<()> {
|
||||
// Create data_dir and tmp directory if they do not exist.
|
||||
tokio::fs::create_dir_all(&self.data_dir.join("tmp"))
|
||||
// Create data_dir if it does not exist.
|
||||
tokio::fs::create_dir_all(&self.data_dir)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to create basebackup cache data_dir {:?}: {:?}",
|
||||
self.data_dir,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
.context("Failed to create basebackup cache data directory")?;
|
||||
|
||||
self.clean_tmp_dir()
|
||||
.await
|
||||
.context("Failed to clean tmp directory")?;
|
||||
|
||||
// Read existing entries from the data_dir and add them to in-memory state.
|
||||
let mut entries = HashMap::new();
|
||||
@@ -408,6 +413,19 @@ impl BasebackupCache {
|
||||
.tenant_manager
|
||||
.get_attached_tenant_shard(tenant_shard_id)?;
|
||||
|
||||
let feature_flag = tenant
|
||||
.feature_resolver
|
||||
.evaluate_boolean("enable-basebackup-cache", tenant_shard_id.tenant_id);
|
||||
|
||||
if feature_flag.is_err() {
|
||||
tracing::info!(
|
||||
tenant_id = %tenant_shard_id.tenant_id,
|
||||
"Basebackup cache is disabled for tenant by feature flag, skipping basebackup",
|
||||
);
|
||||
self.prepare_skip_count.inc();
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let tenant_state = tenant.current_state();
|
||||
if tenant_state != TenantState::Active {
|
||||
anyhow::bail!(
|
||||
@@ -451,6 +469,11 @@ impl BasebackupCache {
|
||||
}
|
||||
|
||||
// Move the tmp file to the final location atomically.
|
||||
// The tmp file is fsynced, so it's guaranteed that we will not have a partial file
|
||||
// in the main directory.
|
||||
// It's not necessary to fsync the inode after renaming, because the worst case is that
|
||||
// the rename operation will be rolled back on the disk failure, the entry will disappear
|
||||
// from the main directory, and the entry access will cause a cache miss.
|
||||
let entry_path = self.entry_path(tenant_shard_id.tenant_id, timeline_id, req_lsn);
|
||||
tokio::fs::rename(&entry_tmp_path, &entry_path).await?;
|
||||
|
||||
@@ -468,16 +491,17 @@ impl BasebackupCache {
|
||||
}
|
||||
|
||||
/// Prepares a basebackup in a temporary file.
|
||||
/// Guarantees that the tmp file is fsynced before returning.
|
||||
async fn prepare_basebackup_tmp(
|
||||
&self,
|
||||
emptry_tmp_path: &Utf8Path,
|
||||
entry_tmp_path: &Utf8Path,
|
||||
timeline: &Arc<Timeline>,
|
||||
req_lsn: Lsn,
|
||||
) -> anyhow::Result<()> {
|
||||
let ctx = RequestContext::new(TaskKind::BasebackupCache, DownloadBehavior::Download);
|
||||
let ctx = ctx.with_scope_timeline(timeline);
|
||||
|
||||
let file = tokio::fs::File::create(emptry_tmp_path).await?;
|
||||
let file = tokio::fs::File::create(entry_tmp_path).await?;
|
||||
let mut writer = BufWriter::new(file);
|
||||
|
||||
let mut encoder = GzipEncoder::with_quality(
|
||||
|
||||
@@ -23,6 +23,7 @@ use pageserver::deletion_queue::DeletionQueue;
|
||||
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
|
||||
use pageserver::feature_resolver::FeatureResolver;
|
||||
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
|
||||
use pageserver::page_service::GrpcPageServiceHandler;
|
||||
use pageserver::task_mgr::{
|
||||
BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME,
|
||||
};
|
||||
@@ -572,7 +573,8 @@ fn start_pageserver(
|
||||
tokio::sync::mpsc::unbounded_channel();
|
||||
let deletion_queue_client = deletion_queue.new_client();
|
||||
let background_purges = mgr::BackgroundPurges::default();
|
||||
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
|
||||
|
||||
let tenant_manager = mgr::init(
|
||||
conf,
|
||||
background_purges.clone(),
|
||||
TenantSharedResources {
|
||||
@@ -583,10 +585,10 @@ fn start_pageserver(
|
||||
basebackup_prepare_sender,
|
||||
feature_resolver,
|
||||
},
|
||||
order,
|
||||
shutdown_pageserver.clone(),
|
||||
))?;
|
||||
);
|
||||
let tenant_manager = Arc::new(tenant_manager);
|
||||
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(tenant_manager.clone(), order))?;
|
||||
|
||||
let basebackup_cache = BasebackupCache::spawn(
|
||||
BACKGROUND_RUNTIME.handle(),
|
||||
@@ -814,7 +816,7 @@ fn start_pageserver(
|
||||
// necessary?
|
||||
let mut page_service_grpc = None;
|
||||
if let Some(grpc_listener) = grpc_listener {
|
||||
page_service_grpc = Some(page_service::spawn_grpc(
|
||||
page_service_grpc = Some(GrpcPageServiceHandler::spawn(
|
||||
tenant_manager.clone(),
|
||||
grpc_auth,
|
||||
otel_guard.as_ref().map(|g| g.dispatch.clone()),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use pageserver_api::config::NodeMetadata;
|
||||
use posthog_client_lite::{
|
||||
CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
|
||||
PostHogFlagFilterPropertyValue,
|
||||
@@ -86,7 +87,35 @@ impl FeatureResolver {
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: add pageserver URL.
|
||||
// TODO: move this to a background task so that we don't block startup in case of slow disk
|
||||
let metadata_path = conf.metadata_path();
|
||||
match std::fs::read_to_string(&metadata_path) {
|
||||
Ok(metadata_str) => match serde_json::from_str::<NodeMetadata>(&metadata_str) {
|
||||
Ok(metadata) => {
|
||||
properties.insert(
|
||||
"hostname".to_string(),
|
||||
PostHogFlagFilterPropertyValue::String(metadata.http_host),
|
||||
);
|
||||
if let Some(cplane_region) = metadata.other.get("region_id") {
|
||||
if let Some(cplane_region) = cplane_region.as_str() {
|
||||
// This region contains the cell number
|
||||
properties.insert(
|
||||
"neon_region".to_string(),
|
||||
PostHogFlagFilterPropertyValue::String(
|
||||
cplane_region.to_string(),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse metadata.json: {}", e);
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to read metadata.json: {}", e);
|
||||
}
|
||||
}
|
||||
Arc::new(properties)
|
||||
};
|
||||
let fake_tenants = {
|
||||
|
||||
@@ -73,6 +73,7 @@ use crate::tenant::remote_timeline_client::{
|
||||
use crate::tenant::secondary::SecondaryController;
|
||||
use crate::tenant::size::ModelInputs;
|
||||
use crate::tenant::storage_layer::{IoConcurrency, LayerAccessStatsReset, LayerName};
|
||||
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
|
||||
use crate::tenant::timeline::offload::{OffloadError, offload_timeline};
|
||||
use crate::tenant::timeline::{
|
||||
CompactFlags, CompactOptions, CompactRequest, CompactionError, MarkInvisibleRequest, Timeline,
|
||||
@@ -1451,7 +1452,10 @@ async fn timeline_layer_scan_disposable_keys(
|
||||
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download)
|
||||
.with_scope_timeline(&timeline);
|
||||
|
||||
let guard = timeline.layers.read().await;
|
||||
let guard = timeline
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await;
|
||||
let Some(layer) = guard.try_get_from_key(&layer_name.clone().into()) else {
|
||||
return Err(ApiError::NotFound(
|
||||
anyhow::anyhow!("Layer {tenant_shard_id}/{timeline_id}/{layer_name} not found").into(),
|
||||
|
||||
@@ -1053,6 +1053,15 @@ pub(crate) static TENANT_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
|
||||
.expect("Failed to register pageserver_tenant_states_count metric")
|
||||
});
|
||||
|
||||
pub(crate) static TIMELINE_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
|
||||
register_uint_gauge_vec!(
|
||||
"pageserver_timeline_states_count",
|
||||
"Count of timelines per state",
|
||||
&["state"]
|
||||
)
|
||||
.expect("Failed to register pageserver_timeline_states_count metric")
|
||||
});
|
||||
|
||||
/// A set of broken tenants.
|
||||
///
|
||||
/// These are expected to be so rare that a set is fine. Set as in a new timeseries per each broken
|
||||
@@ -3325,6 +3334,8 @@ impl TimelineMetrics {
|
||||
&timeline_id,
|
||||
);
|
||||
|
||||
TIMELINE_STATE_METRIC.with_label_values(&["active"]).inc();
|
||||
|
||||
TimelineMetrics {
|
||||
tenant_id,
|
||||
shard_id,
|
||||
@@ -3479,6 +3490,8 @@ impl TimelineMetrics {
|
||||
return;
|
||||
}
|
||||
|
||||
TIMELINE_STATE_METRIC.with_label_values(&["active"]).dec();
|
||||
|
||||
let tenant_id = &self.tenant_id;
|
||||
let timeline_id = &self.timeline_id;
|
||||
let shard_id = &self.shard_id;
|
||||
|
||||
@@ -169,99 +169,6 @@ pub fn spawn(
|
||||
Listener { cancel, task }
|
||||
}
|
||||
|
||||
/// Spawns a gRPC server for the page service.
|
||||
///
|
||||
/// TODO: move this onto GrpcPageServiceHandler::spawn().
|
||||
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
|
||||
/// need to reimplement the TCP+TLS accept loop ourselves.
|
||||
pub fn spawn_grpc(
|
||||
tenant_manager: Arc<TenantManager>,
|
||||
auth: Option<Arc<SwappableJwtAuth>>,
|
||||
perf_trace_dispatch: Option<Dispatch>,
|
||||
get_vectored_concurrent_io: GetVectoredConcurrentIo,
|
||||
listener: std::net::TcpListener,
|
||||
) -> anyhow::Result<CancellableTask> {
|
||||
let cancel = CancellationToken::new();
|
||||
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
|
||||
.download_behavior(DownloadBehavior::Download)
|
||||
.perf_span_dispatch(perf_trace_dispatch)
|
||||
.detached_child();
|
||||
let gate = Gate::default();
|
||||
|
||||
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
|
||||
// port early during startup.
|
||||
let incoming = {
|
||||
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
|
||||
listener.set_nonblocking(true)?;
|
||||
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?)
|
||||
.with_nodelay(Some(GRPC_TCP_NODELAY))
|
||||
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
|
||||
};
|
||||
|
||||
// Set up the gRPC server.
|
||||
//
|
||||
// TODO: consider tuning window sizes.
|
||||
let mut server = tonic::transport::Server::builder()
|
||||
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
|
||||
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
|
||||
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
|
||||
|
||||
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
|
||||
//
|
||||
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
|
||||
//
|
||||
// * Layers: allow async code, can run code after the service response. However, only has access
|
||||
// to the raw HTTP request/response, not the gRPC types.
|
||||
let page_service_handler = GrpcPageServiceHandler {
|
||||
tenant_manager,
|
||||
ctx,
|
||||
gate_guard: gate.enter().expect("gate was just created"),
|
||||
get_vectored_concurrent_io,
|
||||
};
|
||||
|
||||
let observability_layer = ObservabilityLayer;
|
||||
let mut tenant_interceptor = TenantMetadataInterceptor;
|
||||
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
|
||||
|
||||
let page_service = tower::ServiceBuilder::new()
|
||||
// Create tracing span and record request start time.
|
||||
.layer(observability_layer)
|
||||
// Intercept gRPC requests.
|
||||
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
|
||||
// Extract tenant metadata.
|
||||
req = tenant_interceptor.call(req)?;
|
||||
// Authenticate tenant JWT token.
|
||||
req = auth_interceptor.call(req)?;
|
||||
Ok(req)
|
||||
}))
|
||||
.service(proto::PageServiceServer::new(page_service_handler));
|
||||
let server = server.add_service(page_service);
|
||||
|
||||
// Reflection service for use with e.g. grpcurl.
|
||||
let reflection_service = tonic_reflection::server::Builder::configure()
|
||||
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
|
||||
.build_v1()?;
|
||||
let server = server.add_service(reflection_service);
|
||||
|
||||
// Spawn server task.
|
||||
let task_cancel = cancel.clone();
|
||||
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
|
||||
"grpc listener",
|
||||
async move {
|
||||
let result = server
|
||||
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
|
||||
.await;
|
||||
if result.is_ok() {
|
||||
// TODO: revisit shutdown logic once page service is implemented.
|
||||
gate.close().await;
|
||||
}
|
||||
result
|
||||
},
|
||||
));
|
||||
|
||||
Ok(CancellableTask { task, cancel })
|
||||
}
|
||||
|
||||
impl Listener {
|
||||
pub async fn stop_accepting(self) -> Connections {
|
||||
self.cancel.cancel();
|
||||
@@ -3366,6 +3273,101 @@ pub struct GrpcPageServiceHandler {
|
||||
}
|
||||
|
||||
impl GrpcPageServiceHandler {
|
||||
/// Spawns a gRPC server for the page service.
|
||||
///
|
||||
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
|
||||
/// need to reimplement the TCP+TLS accept loop ourselves.
|
||||
pub fn spawn(
|
||||
tenant_manager: Arc<TenantManager>,
|
||||
auth: Option<Arc<SwappableJwtAuth>>,
|
||||
perf_trace_dispatch: Option<Dispatch>,
|
||||
get_vectored_concurrent_io: GetVectoredConcurrentIo,
|
||||
listener: std::net::TcpListener,
|
||||
) -> anyhow::Result<CancellableTask> {
|
||||
let cancel = CancellationToken::new();
|
||||
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
|
||||
.download_behavior(DownloadBehavior::Download)
|
||||
.perf_span_dispatch(perf_trace_dispatch)
|
||||
.detached_child();
|
||||
let gate = Gate::default();
|
||||
|
||||
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
|
||||
// port early during startup.
|
||||
let incoming = {
|
||||
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
|
||||
listener.set_nonblocking(true)?;
|
||||
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(
|
||||
listener,
|
||||
)?)
|
||||
.with_nodelay(Some(GRPC_TCP_NODELAY))
|
||||
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
|
||||
};
|
||||
|
||||
// Set up the gRPC server.
|
||||
//
|
||||
// TODO: consider tuning window sizes.
|
||||
let mut server = tonic::transport::Server::builder()
|
||||
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
|
||||
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
|
||||
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
|
||||
|
||||
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
|
||||
//
|
||||
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
|
||||
//
|
||||
// * Layers: allow async code, can run code after the service response. However, only has access
|
||||
// to the raw HTTP request/response, not the gRPC types.
|
||||
let page_service_handler = GrpcPageServiceHandler {
|
||||
tenant_manager,
|
||||
ctx,
|
||||
gate_guard: gate.enter().expect("gate was just created"),
|
||||
get_vectored_concurrent_io,
|
||||
};
|
||||
|
||||
let observability_layer = ObservabilityLayer;
|
||||
let mut tenant_interceptor = TenantMetadataInterceptor;
|
||||
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
|
||||
|
||||
let page_service = tower::ServiceBuilder::new()
|
||||
// Create tracing span and record request start time.
|
||||
.layer(observability_layer)
|
||||
// Intercept gRPC requests.
|
||||
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
|
||||
// Extract tenant metadata.
|
||||
req = tenant_interceptor.call(req)?;
|
||||
// Authenticate tenant JWT token.
|
||||
req = auth_interceptor.call(req)?;
|
||||
Ok(req)
|
||||
}))
|
||||
// Run the page service.
|
||||
.service(proto::PageServiceServer::new(page_service_handler));
|
||||
let server = server.add_service(page_service);
|
||||
|
||||
// Reflection service for use with e.g. grpcurl.
|
||||
let reflection_service = tonic_reflection::server::Builder::configure()
|
||||
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
|
||||
.build_v1()?;
|
||||
let server = server.add_service(reflection_service);
|
||||
|
||||
// Spawn server task.
|
||||
let task_cancel = cancel.clone();
|
||||
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
|
||||
"grpc listener",
|
||||
async move {
|
||||
let result = server
|
||||
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
|
||||
.await;
|
||||
if result.is_ok() {
|
||||
// TODO: revisit shutdown logic once page service is implemented.
|
||||
gate.close().await;
|
||||
}
|
||||
result
|
||||
},
|
||||
));
|
||||
|
||||
Ok(CancellableTask { task, cancel })
|
||||
}
|
||||
|
||||
/// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of
|
||||
/// relations and their sizes, as well as SLRU segments and similar data.
|
||||
#[allow(clippy::result_large_err)]
|
||||
|
||||
@@ -51,6 +51,7 @@ use secondary::heatmap::{HeatMapTenant, HeatMapTimeline};
|
||||
use storage_broker::BrokerClientChannel;
|
||||
use timeline::compaction::{CompactionOutcome, GcCompactionQueue};
|
||||
use timeline::import_pgdata::ImportingTimeline;
|
||||
use timeline::layer_manager::LayerManagerLockHolder;
|
||||
use timeline::offload::{OffloadError, offload_timeline};
|
||||
use timeline::{
|
||||
CompactFlags, CompactOptions, CompactionError, PreviousHeatmap, ShutdownMode, import_pgdata,
|
||||
@@ -89,7 +90,8 @@ use crate::l0_flush::L0FlushGlobalState;
|
||||
use crate::metrics::{
|
||||
BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS,
|
||||
INITDB_RUN_TIME, INITDB_SEMAPHORE_ACQUISITION_TIME, TENANT, TENANT_OFFLOADED_TIMELINES,
|
||||
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, remove_tenant_metrics,
|
||||
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, TIMELINE_STATE_METRIC,
|
||||
remove_tenant_metrics,
|
||||
};
|
||||
use crate::task_mgr::TaskKind;
|
||||
use crate::tenant::config::LocationMode;
|
||||
@@ -544,6 +546,28 @@ pub struct OffloadedTimeline {
|
||||
|
||||
/// Part of the `OffloadedTimeline` object's lifecycle: this needs to be set before we drop it
|
||||
pub deleted_from_ancestor: AtomicBool,
|
||||
|
||||
_metrics_guard: OffloadedTimelineMetricsGuard,
|
||||
}
|
||||
|
||||
/// Increases the offloaded timeline count metric when created, and decreases when dropped.
|
||||
struct OffloadedTimelineMetricsGuard;
|
||||
|
||||
impl OffloadedTimelineMetricsGuard {
|
||||
fn new() -> Self {
|
||||
TIMELINE_STATE_METRIC
|
||||
.with_label_values(&["offloaded"])
|
||||
.inc();
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for OffloadedTimelineMetricsGuard {
|
||||
fn drop(&mut self) {
|
||||
TIMELINE_STATE_METRIC
|
||||
.with_label_values(&["offloaded"])
|
||||
.dec();
|
||||
}
|
||||
}
|
||||
|
||||
impl OffloadedTimeline {
|
||||
@@ -576,6 +600,8 @@ impl OffloadedTimeline {
|
||||
|
||||
delete_progress: timeline.delete_progress.clone(),
|
||||
deleted_from_ancestor: AtomicBool::new(false),
|
||||
|
||||
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
|
||||
})
|
||||
}
|
||||
fn from_manifest(tenant_shard_id: TenantShardId, manifest: &OffloadedTimelineManifest) -> Self {
|
||||
@@ -595,6 +621,7 @@ impl OffloadedTimeline {
|
||||
archived_at,
|
||||
delete_progress: TimelineDeleteProgress::default(),
|
||||
deleted_from_ancestor: AtomicBool::new(false),
|
||||
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
|
||||
}
|
||||
}
|
||||
fn manifest(&self) -> OffloadedTimelineManifest {
|
||||
@@ -1289,7 +1316,7 @@ impl TenantShard {
|
||||
ancestor.is_some()
|
||||
|| timeline
|
||||
.layers
|
||||
.read()
|
||||
.read(LayerManagerLockHolder::LoadLayerMap)
|
||||
.await
|
||||
.layer_map()
|
||||
.expect(
|
||||
@@ -2617,7 +2644,7 @@ impl TenantShard {
|
||||
}
|
||||
let layer_names = tline
|
||||
.layers
|
||||
.read()
|
||||
.read(LayerManagerLockHolder::Testing)
|
||||
.await
|
||||
.layer_map()
|
||||
.unwrap()
|
||||
@@ -3132,7 +3159,12 @@ impl TenantShard {
|
||||
|
||||
for timeline in &compact {
|
||||
// Collect L0 counts. Can't await while holding lock above.
|
||||
if let Ok(lm) = timeline.layers.read().await.layer_map() {
|
||||
if let Ok(lm) = timeline
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::Compaction)
|
||||
.await
|
||||
.layer_map()
|
||||
{
|
||||
l0_counts.insert(timeline.timeline_id, lm.level0_deltas().len());
|
||||
}
|
||||
}
|
||||
@@ -4874,7 +4906,7 @@ impl TenantShard {
|
||||
}
|
||||
let layer_names = tline
|
||||
.layers
|
||||
.read()
|
||||
.read(LayerManagerLockHolder::Testing)
|
||||
.await
|
||||
.layer_map()
|
||||
.unwrap()
|
||||
@@ -6944,7 +6976,7 @@ mod tests {
|
||||
.await?;
|
||||
make_some_layers(tline.as_ref(), Lsn(0x20), &ctx).await?;
|
||||
|
||||
let layer_map = tline.layers.read().await;
|
||||
let layer_map = tline.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
let level0_deltas = layer_map
|
||||
.layer_map()?
|
||||
.level0_deltas()
|
||||
@@ -7180,7 +7212,7 @@ mod tests {
|
||||
let lsn = Lsn(0x10);
|
||||
let inserted = bulk_insert_compact_gc(&tenant, &tline, &ctx, lsn, 50, 10000).await?;
|
||||
|
||||
let guard = tline.layers.read().await;
|
||||
let guard = tline.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
let lm = guard.layer_map()?;
|
||||
|
||||
lm.dump(true, &ctx).await?;
|
||||
@@ -8208,12 +8240,23 @@ mod tests {
|
||||
tline.freeze_and_flush().await?; // force create a delta layer
|
||||
}
|
||||
|
||||
let before_num_l0_delta_files =
|
||||
tline.layers.read().await.layer_map()?.level0_deltas().len();
|
||||
let before_num_l0_delta_files = tline
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::Testing)
|
||||
.await
|
||||
.layer_map()?
|
||||
.level0_deltas()
|
||||
.len();
|
||||
|
||||
tline.compact(&cancel, EnumSet::default(), &ctx).await?;
|
||||
|
||||
let after_num_l0_delta_files = tline.layers.read().await.layer_map()?.level0_deltas().len();
|
||||
let after_num_l0_delta_files = tline
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::Testing)
|
||||
.await
|
||||
.layer_map()?
|
||||
.level0_deltas()
|
||||
.len();
|
||||
|
||||
assert!(
|
||||
after_num_l0_delta_files < before_num_l0_delta_files,
|
||||
|
||||
@@ -61,8 +61,8 @@ pub(crate) struct LocationConf {
|
||||
/// The detailed shard identity. This structure is already scoped within
|
||||
/// a TenantShardId, but we need the full ShardIdentity to enable calculating
|
||||
/// key->shard mappings.
|
||||
// TODO(vlad): Remove this default once all configs have a shard identity on disk.
|
||||
#[serde(default = "ShardIdentity::unsharded")]
|
||||
#[serde(skip_serializing_if = "ShardIdentity::is_unsharded")]
|
||||
pub(crate) shard: ShardIdentity,
|
||||
|
||||
/// The pan-cluster tenant configuration, the same on all locations
|
||||
@@ -149,7 +149,12 @@ impl LocationConf {
|
||||
/// For use when attaching/re-attaching: update the generation stored in this
|
||||
/// structure. If we were in a secondary state, promote to attached (posession
|
||||
/// of a fresh generation implies this).
|
||||
pub(crate) fn attach_in_generation(&mut self, mode: AttachmentMode, generation: Generation) {
|
||||
pub(crate) fn attach_in_generation(
|
||||
&mut self,
|
||||
mode: AttachmentMode,
|
||||
generation: Generation,
|
||||
stripe_size: ShardStripeSize,
|
||||
) {
|
||||
match &mut self.mode {
|
||||
LocationMode::Attached(attach_conf) => {
|
||||
attach_conf.generation = generation;
|
||||
@@ -163,6 +168,8 @@ impl LocationConf {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
self.shard.stripe_size = stripe_size;
|
||||
}
|
||||
|
||||
pub(crate) fn try_from(conf: &'_ models::LocationConfig) -> anyhow::Result<Self> {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
||||
//! Helper functions to upload files to remote storage with a RemoteStorage
|
||||
|
||||
use std::io::{ErrorKind, SeekFrom};
|
||||
use std::num::NonZeroU32;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use anyhow::{Context, bail};
|
||||
@@ -228,11 +229,25 @@ pub(crate) async fn time_travel_recover_tenant(
|
||||
let timelines_path = super::remote_timelines_path(tenant_shard_id);
|
||||
prefixes.push(timelines_path);
|
||||
}
|
||||
|
||||
// Limit the number of versions deletions, mostly so that we don't
|
||||
// keep requesting forever if the list is too long, as we'd put the
|
||||
// list in RAM.
|
||||
// Building a list of 100k entries that reaches the limit roughly takes
|
||||
// 40 seconds, and roughly corresponds to tenants of 2 TiB physical size.
|
||||
const COMPLEXITY_LIMIT: Option<NonZeroU32> = NonZeroU32::new(100_000);
|
||||
|
||||
for prefix in &prefixes {
|
||||
backoff::retry(
|
||||
|| async {
|
||||
storage
|
||||
.time_travel_recover(Some(prefix), timestamp, done_if_after, cancel)
|
||||
.time_travel_recover(
|
||||
Some(prefix),
|
||||
timestamp,
|
||||
done_if_after,
|
||||
cancel,
|
||||
COMPLEXITY_LIMIT,
|
||||
)
|
||||
.await
|
||||
},
|
||||
|e| !matches!(e, TimeTravelError::Other(_)),
|
||||
|
||||
@@ -1635,6 +1635,7 @@ pub(crate) mod test {
|
||||
use crate::tenant::disk_btree::tests::TestDisk;
|
||||
use crate::tenant::harness::{TIMELINE_ID, TenantHarness};
|
||||
use crate::tenant::storage_layer::{Layer, ResidentLayer};
|
||||
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
|
||||
use crate::tenant::{TenantShard, Timeline};
|
||||
|
||||
/// Construct an index for a fictional delta layer and and then
|
||||
@@ -2002,7 +2003,7 @@ pub(crate) mod test {
|
||||
|
||||
let initdb_layer = timeline
|
||||
.layers
|
||||
.read()
|
||||
.read(crate::tenant::timeline::layer_manager::LayerManagerLockHolder::Testing)
|
||||
.await
|
||||
.likely_resident_layers()
|
||||
.next()
|
||||
@@ -2078,7 +2079,7 @@ pub(crate) mod test {
|
||||
|
||||
let new_layer = timeline
|
||||
.layers
|
||||
.read()
|
||||
.read(LayerManagerLockHolder::Testing)
|
||||
.await
|
||||
.likely_resident_layers()
|
||||
.find(|&x| x != &initdb_layer)
|
||||
|
||||
@@ -10,6 +10,7 @@ use super::*;
|
||||
use crate::context::DownloadBehavior;
|
||||
use crate::tenant::harness::{TenantHarness, test_img};
|
||||
use crate::tenant::storage_layer::{IoConcurrency, LayerVisibilityHint};
|
||||
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
|
||||
|
||||
/// Used in tests to advance a future to wanted await point, and not futher.
|
||||
const ADVANCE: std::time::Duration = std::time::Duration::from_secs(3600);
|
||||
@@ -59,7 +60,7 @@ async fn smoke_test() {
|
||||
// there to avoid the timeline being illegally empty
|
||||
let (layer, dummy_layer) = {
|
||||
let mut layers = {
|
||||
let layers = timeline.layers.read().await;
|
||||
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
@@ -215,7 +216,7 @@ async fn smoke_test() {
|
||||
|
||||
// Simulate GC removing our test layer.
|
||||
{
|
||||
let mut g = timeline.layers.write().await;
|
||||
let mut g = timeline.layers.write(LayerManagerLockHolder::Testing).await;
|
||||
|
||||
let layers = &[layer];
|
||||
g.open_mut().unwrap().finish_gc_timeline(layers);
|
||||
@@ -261,7 +262,7 @@ async fn evict_and_wait_on_wanted_deleted() {
|
||||
|
||||
let layer = {
|
||||
let mut layers = {
|
||||
let layers = timeline.layers.read().await;
|
||||
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
@@ -305,7 +306,7 @@ async fn evict_and_wait_on_wanted_deleted() {
|
||||
// assert that once we remove the `layer` from the layer map and drop our reference,
|
||||
// the deletion of the layer in remote_storage happens.
|
||||
{
|
||||
let mut layers = timeline.layers.write().await;
|
||||
let mut layers = timeline.layers.write(LayerManagerLockHolder::Testing).await;
|
||||
layers.open_mut().unwrap().finish_gc_timeline(&[layer]);
|
||||
}
|
||||
|
||||
@@ -347,7 +348,7 @@ fn read_wins_pending_eviction() {
|
||||
|
||||
let layer = {
|
||||
let mut layers = {
|
||||
let layers = timeline.layers.read().await;
|
||||
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
@@ -480,7 +481,7 @@ fn multiple_pending_evictions_scenario(name: &'static str, in_order: bool) {
|
||||
|
||||
let layer = {
|
||||
let mut layers = {
|
||||
let layers = timeline.layers.read().await;
|
||||
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
@@ -655,7 +656,7 @@ async fn cancelled_get_or_maybe_download_does_not_cancel_eviction() {
|
||||
|
||||
let layer = {
|
||||
let mut layers = {
|
||||
let layers = timeline.layers.read().await;
|
||||
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
@@ -741,7 +742,7 @@ async fn evict_and_wait_does_not_wait_for_download() {
|
||||
|
||||
let layer = {
|
||||
let mut layers = {
|
||||
let layers = timeline.layers.read().await;
|
||||
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
@@ -862,7 +863,7 @@ async fn eviction_cancellation_on_drop() {
|
||||
|
||||
let (evicted_layer, not_evicted) = {
|
||||
let mut layers = {
|
||||
let mut guard = timeline.layers.write().await;
|
||||
let mut guard = timeline.layers.write(LayerManagerLockHolder::Testing).await;
|
||||
let layers = guard.likely_resident_layers().cloned().collect::<Vec<_>>();
|
||||
// remove the layers from layermap
|
||||
guard.open_mut().unwrap().finish_gc_timeline(&layers);
|
||||
|
||||
@@ -35,7 +35,11 @@ use fail::fail_point;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use futures::{FutureExt, StreamExt};
|
||||
use handle::ShardTimelineId;
|
||||
use layer_manager::Shutdown;
|
||||
use layer_manager::{
|
||||
LayerManagerLockHolder, LayerManagerReadGuard, LayerManagerWriteGuard, LockedLayerManager,
|
||||
Shutdown,
|
||||
};
|
||||
|
||||
use offload::OffloadError;
|
||||
use once_cell::sync::Lazy;
|
||||
use pageserver_api::config::tenant_conf_defaults::DEFAULT_PITR_INTERVAL;
|
||||
@@ -82,7 +86,6 @@ use wal_decoder::serialized_batch::{SerializedValueBatch, ValueMeta};
|
||||
use self::delete::DeleteTimelineFlow;
|
||||
pub(super) use self::eviction_task::EvictionTaskTenantState;
|
||||
use self::eviction_task::EvictionTaskTimelineState;
|
||||
use self::layer_manager::LayerManager;
|
||||
use self::logical_size::LogicalSize;
|
||||
use self::walreceiver::{WalReceiver, WalReceiverConf};
|
||||
use super::remote_timeline_client::RemoteTimelineClient;
|
||||
@@ -181,13 +184,13 @@ impl std::fmt::Display for ImageLayerCreationMode {
|
||||
|
||||
/// Temporary function for immutable storage state refactor, ensures we are dropping mutex guard instead of other things.
|
||||
/// Can be removed after all refactors are done.
|
||||
fn drop_rlock<T>(rlock: tokio::sync::RwLockReadGuard<T>) {
|
||||
fn drop_layer_manager_rlock(rlock: LayerManagerReadGuard<'_>) {
|
||||
drop(rlock)
|
||||
}
|
||||
|
||||
/// Temporary function for immutable storage state refactor, ensures we are dropping mutex guard instead of other things.
|
||||
/// Can be removed after all refactors are done.
|
||||
fn drop_wlock<T>(rlock: tokio::sync::RwLockWriteGuard<'_, T>) {
|
||||
fn drop_layer_manager_wlock(rlock: LayerManagerWriteGuard<'_>) {
|
||||
drop(rlock)
|
||||
}
|
||||
|
||||
@@ -241,7 +244,7 @@ pub struct Timeline {
|
||||
///
|
||||
/// In the future, we'll be able to split up the tuple of LayerMap and `LayerFileManager`,
|
||||
/// so that e.g. on-demand-download/eviction, and layer spreading, can operate just on `LayerFileManager`.
|
||||
pub(crate) layers: tokio::sync::RwLock<LayerManager>,
|
||||
pub(crate) layers: LockedLayerManager,
|
||||
|
||||
last_freeze_at: AtomicLsn,
|
||||
// Atomic would be more appropriate here.
|
||||
@@ -1055,8 +1058,8 @@ pub(crate) enum WaitLsnWaiter<'a> {
|
||||
/// Argument to [`Timeline::shutdown`].
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub(crate) enum ShutdownMode {
|
||||
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk and then
|
||||
/// also to remote storage. This method can easily take multiple seconds for a busy timeline.
|
||||
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk. This method can
|
||||
/// take multiple seconds for a busy timeline.
|
||||
///
|
||||
/// While we are flushing, we continue to accept read I/O for LSNs ingested before
|
||||
/// the call to [`Timeline::shutdown`].
|
||||
@@ -1535,7 +1538,10 @@ impl Timeline {
|
||||
/// This method makes no distinction between local and remote layers.
|
||||
/// Hence, the result **does not represent local filesystem usage**.
|
||||
pub(crate) async fn layer_size_sum(&self) -> u64 {
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await;
|
||||
guard.layer_size_sum()
|
||||
}
|
||||
|
||||
@@ -1845,7 +1851,7 @@ impl Timeline {
|
||||
// time, and this was missed.
|
||||
// if write_guard.is_none() { return; }
|
||||
|
||||
let Ok(layers_guard) = self.layers.try_read() else {
|
||||
let Ok(layers_guard) = self.layers.try_read(LayerManagerLockHolder::TryFreezeLayer) else {
|
||||
// Don't block if the layer lock is busy
|
||||
return;
|
||||
};
|
||||
@@ -2158,7 +2164,7 @@ impl Timeline {
|
||||
if let ShutdownMode::FreezeAndFlush = mode {
|
||||
let do_flush = if let Some((open, frozen)) = self
|
||||
.layers
|
||||
.read()
|
||||
.read(LayerManagerLockHolder::Shutdown)
|
||||
.await
|
||||
.layer_map()
|
||||
.map(|lm| (lm.open_layer.is_some(), lm.frozen_layers.len()))
|
||||
@@ -2262,7 +2268,10 @@ impl Timeline {
|
||||
// Allow any remaining in-memory layers to do cleanup -- until that, they hold the gate
|
||||
// open.
|
||||
let mut write_guard = self.write_lock.lock().await;
|
||||
self.layers.write().await.shutdown(&mut write_guard);
|
||||
self.layers
|
||||
.write(LayerManagerLockHolder::Shutdown)
|
||||
.await
|
||||
.shutdown(&mut write_guard);
|
||||
}
|
||||
|
||||
// Finally wait until any gate-holders are complete.
|
||||
@@ -2365,7 +2374,10 @@ impl Timeline {
|
||||
&self,
|
||||
reset: LayerAccessStatsReset,
|
||||
) -> Result<LayerMapInfo, layer_manager::Shutdown> {
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await;
|
||||
let layer_map = guard.layer_map()?;
|
||||
let mut in_memory_layers = Vec::with_capacity(layer_map.frozen_layers.len() + 1);
|
||||
if let Some(open_layer) = &layer_map.open_layer {
|
||||
@@ -3232,7 +3244,7 @@ impl Timeline {
|
||||
|
||||
/// Initialize with an empty layer map. Used when creating a new timeline.
|
||||
pub(super) fn init_empty_layer_map(&self, start_lsn: Lsn) {
|
||||
let mut layers = self.layers.try_write().expect(
|
||||
let mut layers = self.layers.try_write(LayerManagerLockHolder::Init).expect(
|
||||
"in the context where we call this function, no other task has access to the object",
|
||||
);
|
||||
layers
|
||||
@@ -3252,7 +3264,10 @@ impl Timeline {
|
||||
use init::Decision::*;
|
||||
use init::{Discovered, DismissedLayer};
|
||||
|
||||
let mut guard = self.layers.write().await;
|
||||
let mut guard = self
|
||||
.layers
|
||||
.write(LayerManagerLockHolder::LoadLayerMap)
|
||||
.await;
|
||||
|
||||
let timer = self.metrics.load_layer_map_histo.start_timer();
|
||||
|
||||
@@ -3869,7 +3884,10 @@ impl Timeline {
|
||||
&self,
|
||||
layer_name: &LayerName,
|
||||
) -> Result<Option<Layer>, layer_manager::Shutdown> {
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await;
|
||||
let layer = guard
|
||||
.layer_map()?
|
||||
.iter_historic_layers()
|
||||
@@ -3902,7 +3920,10 @@ impl Timeline {
|
||||
return None;
|
||||
}
|
||||
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GenerateHeatmap)
|
||||
.await;
|
||||
|
||||
// Firstly, if there's any heatmap left over from when this location
|
||||
// was a secondary, take that into account. Keep layers that are:
|
||||
@@ -4000,7 +4021,10 @@ impl Timeline {
|
||||
}
|
||||
|
||||
pub(super) async fn generate_unarchival_heatmap(&self, end_lsn: Lsn) -> PreviousHeatmap {
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GenerateHeatmap)
|
||||
.await;
|
||||
|
||||
let now = SystemTime::now();
|
||||
let mut heatmap_layers = Vec::default();
|
||||
@@ -4342,7 +4366,7 @@ impl Timeline {
|
||||
query: &VersionedKeySpaceQuery,
|
||||
) -> Result<LayerFringe, GetVectoredError> {
|
||||
let mut fringe = LayerFringe::new();
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self.layers.read(LayerManagerLockHolder::GetPage).await;
|
||||
|
||||
match query {
|
||||
VersionedKeySpaceQuery::Uniform { keyspace, lsn } => {
|
||||
@@ -4445,7 +4469,7 @@ impl Timeline {
|
||||
// required for correctness, but avoids visiting extra layers
|
||||
// which turns out to be a perf bottleneck in some cases.
|
||||
if !unmapped_keyspace.is_empty() {
|
||||
let guard = timeline.layers.read().await;
|
||||
let guard = timeline.layers.read(LayerManagerLockHolder::GetPage).await;
|
||||
guard.update_search_fringe(&unmapped_keyspace, cont_lsn, &mut fringe)?;
|
||||
|
||||
// It's safe to drop the layer map lock after planning the next round of reads.
|
||||
@@ -4555,7 +4579,10 @@ impl Timeline {
|
||||
_guard: &tokio::sync::MutexGuard<'_, Option<TimelineWriterState>>,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<Arc<InMemoryLayer>> {
|
||||
let mut guard = self.layers.write().await;
|
||||
let mut guard = self
|
||||
.layers
|
||||
.write(LayerManagerLockHolder::GetLayerForWrite)
|
||||
.await;
|
||||
|
||||
let last_record_lsn = self.get_last_record_lsn();
|
||||
ensure!(
|
||||
@@ -4597,7 +4624,10 @@ impl Timeline {
|
||||
write_lock: &mut tokio::sync::MutexGuard<'_, Option<TimelineWriterState>>,
|
||||
) -> Result<u64, FlushLayerError> {
|
||||
let frozen = {
|
||||
let mut guard = self.layers.write().await;
|
||||
let mut guard = self
|
||||
.layers
|
||||
.write(LayerManagerLockHolder::TryFreezeLayer)
|
||||
.await;
|
||||
guard
|
||||
.open_mut()?
|
||||
.try_freeze_in_memory_layer(at, &self.last_freeze_at, write_lock, &self.metrics)
|
||||
@@ -4638,7 +4668,12 @@ impl Timeline {
|
||||
ctx: &RequestContext,
|
||||
) {
|
||||
// Subscribe to L0 delta layer updates, for compaction backpressure.
|
||||
let mut watch_l0 = match self.layers.read().await.layer_map() {
|
||||
let mut watch_l0 = match self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::FlushLoop)
|
||||
.await
|
||||
.layer_map()
|
||||
{
|
||||
Ok(lm) => lm.watch_level0_deltas(),
|
||||
Err(Shutdown) => return,
|
||||
};
|
||||
@@ -4675,7 +4710,7 @@ impl Timeline {
|
||||
|
||||
// Fetch the next layer to flush, if any.
|
||||
let (layer, l0_count, frozen_count, frozen_size) = {
|
||||
let layers = self.layers.read().await;
|
||||
let layers = self.layers.read(LayerManagerLockHolder::FlushLoop).await;
|
||||
let Ok(lm) = layers.layer_map() else {
|
||||
info!("dropping out of flush loop for timeline shutdown");
|
||||
return;
|
||||
@@ -4971,7 +5006,10 @@ impl Timeline {
|
||||
// in-memory layer from the map now. The flushed layer is stored in
|
||||
// the mapping in `create_delta_layer`.
|
||||
{
|
||||
let mut guard = self.layers.write().await;
|
||||
let mut guard = self
|
||||
.layers
|
||||
.write(LayerManagerLockHolder::FlushFrozenLayer)
|
||||
.await;
|
||||
|
||||
guard.open_mut()?.finish_flush_l0_layer(
|
||||
delta_layer_to_add.as_ref(),
|
||||
@@ -5186,7 +5224,7 @@ impl Timeline {
|
||||
async fn time_for_new_image_layer(&self, partition: &KeySpace, lsn: Lsn) -> bool {
|
||||
let threshold = self.get_image_creation_threshold();
|
||||
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
|
||||
let Ok(layers) = guard.layer_map() else {
|
||||
return false;
|
||||
};
|
||||
@@ -5604,7 +5642,7 @@ impl Timeline {
|
||||
if let ImageLayerCreationMode::Force = mode {
|
||||
// When forced to create image layers, we might try and create them where they already
|
||||
// exist. This mode is only used in tests/debug.
|
||||
let layers = self.layers.read().await;
|
||||
let layers = self.layers.read(LayerManagerLockHolder::Compaction).await;
|
||||
if layers.contains_key(&PersistentLayerKey {
|
||||
key_range: img_range.clone(),
|
||||
lsn_range: PersistentLayerDesc::image_layer_lsn_range(lsn),
|
||||
@@ -5729,7 +5767,7 @@ impl Timeline {
|
||||
|
||||
let image_layers = batch_image_writer.finish(self, ctx).await?;
|
||||
|
||||
let mut guard = self.layers.write().await;
|
||||
let mut guard = self.layers.write(LayerManagerLockHolder::Compaction).await;
|
||||
|
||||
// FIXME: we could add the images to be uploaded *before* returning from here, but right
|
||||
// now they are being scheduled outside of write lock; current way is inconsistent with
|
||||
@@ -5737,7 +5775,7 @@ impl Timeline {
|
||||
guard
|
||||
.open_mut()?
|
||||
.track_new_image_layers(&image_layers, &self.metrics);
|
||||
drop_wlock(guard);
|
||||
drop_layer_manager_wlock(guard);
|
||||
let duration = timer.stop_and_record();
|
||||
|
||||
// Creating image layers may have caused some previously visible layers to be covered
|
||||
@@ -6107,7 +6145,7 @@ impl Timeline {
|
||||
layers_to_remove: &[Layer],
|
||||
) -> Result<(), CompactionError> {
|
||||
let mut guard = tokio::select! {
|
||||
guard = self.layers.write() => guard,
|
||||
guard = self.layers.write(LayerManagerLockHolder::Compaction) => guard,
|
||||
_ = self.cancel.cancelled() => {
|
||||
return Err(CompactionError::ShuttingDown);
|
||||
}
|
||||
@@ -6156,7 +6194,7 @@ impl Timeline {
|
||||
self.remote_client
|
||||
.schedule_compaction_update(&remove_layers, new_deltas)?;
|
||||
|
||||
drop_wlock(guard);
|
||||
drop_layer_manager_wlock(guard);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -6166,7 +6204,7 @@ impl Timeline {
|
||||
mut replace_layers: Vec<(Layer, ResidentLayer)>,
|
||||
mut drop_layers: Vec<Layer>,
|
||||
) -> Result<(), CompactionError> {
|
||||
let mut guard = self.layers.write().await;
|
||||
let mut guard = self.layers.write(LayerManagerLockHolder::Compaction).await;
|
||||
|
||||
// Trim our lists in case our caller (compaction) raced with someone else (GC) removing layers: we want
|
||||
// to avoid double-removing, and avoid rewriting something that was removed.
|
||||
@@ -6517,7 +6555,10 @@ impl Timeline {
|
||||
// 5. newer on-disk image layers cover the layer's whole key range
|
||||
//
|
||||
// TODO holding a write lock is too agressive and avoidable
|
||||
let mut guard = self.layers.write().await;
|
||||
let mut guard = self
|
||||
.layers
|
||||
.write(LayerManagerLockHolder::GarbageCollection)
|
||||
.await;
|
||||
let layers = guard.layer_map()?;
|
||||
'outer: for l in layers.iter_historic_layers() {
|
||||
result.layers_total += 1;
|
||||
@@ -6819,7 +6860,10 @@ impl Timeline {
|
||||
use pageserver_api::models::DownloadRemoteLayersTaskState;
|
||||
|
||||
let remaining = {
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await;
|
||||
let Ok(lm) = guard.layer_map() else {
|
||||
// technically here we could look into iterating accessible layers, but downloading
|
||||
// all layers of a shutdown timeline makes no sense regardless.
|
||||
@@ -6925,7 +6969,7 @@ impl Timeline {
|
||||
impl Timeline {
|
||||
/// Returns non-remote layers for eviction.
|
||||
pub(crate) async fn get_local_layers_for_disk_usage_eviction(&self) -> DiskUsageEvictionInfo {
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self.layers.read(LayerManagerLockHolder::Eviction).await;
|
||||
let mut max_layer_size: Option<u64> = None;
|
||||
|
||||
let resident_layers = guard
|
||||
@@ -7026,7 +7070,7 @@ impl Timeline {
|
||||
let image_layer = Layer::finish_creating(self.conf, self, desc, &path)?;
|
||||
info!("force created image layer {}", image_layer.local_path());
|
||||
{
|
||||
let mut guard = self.layers.write().await;
|
||||
let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await;
|
||||
guard
|
||||
.open_mut()
|
||||
.unwrap()
|
||||
@@ -7089,7 +7133,7 @@ impl Timeline {
|
||||
let delta_layer = Layer::finish_creating(self.conf, self, desc, &path)?;
|
||||
info!("force created delta layer {}", delta_layer.local_path());
|
||||
{
|
||||
let mut guard = self.layers.write().await;
|
||||
let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await;
|
||||
guard
|
||||
.open_mut()
|
||||
.unwrap()
|
||||
@@ -7184,7 +7228,7 @@ impl Timeline {
|
||||
|
||||
// Link the layer to the layer map
|
||||
{
|
||||
let mut guard = self.layers.write().await;
|
||||
let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await;
|
||||
let layer_map = guard.open_mut().unwrap();
|
||||
layer_map.force_insert_in_memory_layer(Arc::new(layer));
|
||||
}
|
||||
@@ -7201,7 +7245,7 @@ impl Timeline {
|
||||
io_concurrency: IoConcurrency,
|
||||
) -> anyhow::Result<Vec<(Key, Bytes)>> {
|
||||
let mut all_data = Vec::new();
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
for layer in guard.layer_map()?.iter_historic_layers() {
|
||||
if !layer.is_delta() && layer.image_layer_lsn() == lsn {
|
||||
let layer = guard.get_from_desc(&layer);
|
||||
@@ -7230,7 +7274,7 @@ impl Timeline {
|
||||
self: &Arc<Timeline>,
|
||||
) -> anyhow::Result<Vec<super::storage_layer::PersistentLayerKey>> {
|
||||
let mut layers = Vec::new();
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
for layer in guard.layer_map()?.iter_historic_layers() {
|
||||
layers.push(layer.key());
|
||||
}
|
||||
@@ -7342,7 +7386,7 @@ impl TimelineWriter<'_> {
|
||||
let l0_count = self
|
||||
.tl
|
||||
.layers
|
||||
.read()
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await
|
||||
.layer_map()?
|
||||
.level0_deltas()
|
||||
@@ -7561,6 +7605,7 @@ mod tests {
|
||||
use crate::tenant::harness::{TenantHarness, test_img};
|
||||
use crate::tenant::layer_map::LayerMap;
|
||||
use crate::tenant::storage_layer::{Layer, LayerName, LayerVisibilityHint};
|
||||
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
|
||||
use crate::tenant::timeline::{DeltaLayerTestDesc, EvictionError};
|
||||
use crate::tenant::{PreviousHeatmap, Timeline};
|
||||
|
||||
@@ -7668,7 +7713,7 @@ mod tests {
|
||||
// Evict all the layers and stash the old heatmap in the timeline.
|
||||
// This simulates a migration to a cold secondary location.
|
||||
|
||||
let guard = timeline.layers.read().await;
|
||||
let guard = timeline.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
let mut all_layers = Vec::new();
|
||||
let forever = std::time::Duration::from_secs(120);
|
||||
for layer in guard.likely_resident_layers() {
|
||||
@@ -7790,7 +7835,7 @@ mod tests {
|
||||
})));
|
||||
|
||||
// Evict all the layers in the previous heatmap
|
||||
let guard = timeline.layers.read().await;
|
||||
let guard = timeline.layers.read(LayerManagerLockHolder::Testing).await;
|
||||
let forever = std::time::Duration::from_secs(120);
|
||||
for layer in guard.likely_resident_layers() {
|
||||
layer.evict_and_wait(forever).await.unwrap();
|
||||
@@ -7853,7 +7898,10 @@ mod tests {
|
||||
}
|
||||
|
||||
async fn find_some_layer(timeline: &Timeline) -> Layer {
|
||||
let layers = timeline.layers.read().await;
|
||||
let layers = timeline
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await;
|
||||
let desc = layers
|
||||
.layer_map()
|
||||
.unwrap()
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::ops::Range;
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use super::Timeline;
|
||||
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
pub(crate) struct RangeAnalysis {
|
||||
@@ -24,7 +25,10 @@ impl Timeline {
|
||||
|
||||
let num_of_l0;
|
||||
let all_layer_files = {
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await;
|
||||
num_of_l0 = guard.layer_map().unwrap().level0_deltas().len();
|
||||
guard.all_persistent_layers()
|
||||
};
|
||||
|
||||
@@ -9,7 +9,7 @@ use std::ops::{Deref, Range};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use super::layer_manager::LayerManager;
|
||||
use super::layer_manager::{LayerManagerLockHolder, LayerManagerReadGuard};
|
||||
use super::{
|
||||
CompactFlags, CompactOptions, CompactionError, CreateImageLayersError, DurationRecorder,
|
||||
GetVectoredError, ImageLayerCreationMode, LastImageLayerCreationStatus, RecordedDuration,
|
||||
@@ -62,7 +62,7 @@ use crate::tenant::storage_layer::{
|
||||
use crate::tenant::tasks::log_compaction_error;
|
||||
use crate::tenant::timeline::{
|
||||
DeltaLayerWriter, ImageLayerCreationOutcome, ImageLayerWriter, IoConcurrency, Layer,
|
||||
ResidentLayer, drop_rlock,
|
||||
ResidentLayer, drop_layer_manager_rlock,
|
||||
};
|
||||
use crate::tenant::{DeltaLayer, MaybeOffloaded};
|
||||
use crate::virtual_file::{MaybeFatalIo, VirtualFile};
|
||||
@@ -314,7 +314,10 @@ impl GcCompactionQueue {
|
||||
.unwrap_or(Lsn::INVALID);
|
||||
|
||||
let layers = {
|
||||
let guard = timeline.layers.read().await;
|
||||
let guard = timeline
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await;
|
||||
let layer_map = guard.layer_map()?;
|
||||
layer_map.iter_historic_layers().collect_vec()
|
||||
};
|
||||
@@ -408,7 +411,10 @@ impl GcCompactionQueue {
|
||||
timeline: &Arc<Timeline>,
|
||||
lsn: Lsn,
|
||||
) -> Result<u64, CompactionError> {
|
||||
let guard = timeline.layers.read().await;
|
||||
let guard = timeline
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await;
|
||||
let layer_map = guard.layer_map()?;
|
||||
let layers = layer_map.iter_historic_layers().collect_vec();
|
||||
let mut size = 0;
|
||||
@@ -851,7 +857,7 @@ impl KeyHistoryRetention {
|
||||
}
|
||||
let layer_generation;
|
||||
{
|
||||
let guard = tline.layers.read().await;
|
||||
let guard = tline.layers.read(LayerManagerLockHolder::Compaction).await;
|
||||
if !guard.contains_key(key) {
|
||||
return false;
|
||||
}
|
||||
@@ -1282,7 +1288,10 @@ impl Timeline {
|
||||
// We do the repartition on the L0-L1 boundary. All data below the boundary
|
||||
// are compacted by L0 with low read amplification, thus making the `repartition`
|
||||
// function run fast.
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await;
|
||||
guard
|
||||
.all_persistent_layers()
|
||||
.iter()
|
||||
@@ -1461,7 +1470,7 @@ impl Timeline {
|
||||
let latest_gc_cutoff = self.get_applied_gc_cutoff_lsn();
|
||||
let pitr_cutoff = self.gc_info.read().unwrap().cutoffs.time;
|
||||
|
||||
let layers = self.layers.read().await;
|
||||
let layers = self.layers.read(LayerManagerLockHolder::Compaction).await;
|
||||
let layers_iter = layers.layer_map()?.iter_historic_layers();
|
||||
let (layers_total, mut layers_checked) = (layers_iter.len(), 0);
|
||||
for layer_desc in layers_iter {
|
||||
@@ -1722,7 +1731,10 @@ impl Timeline {
|
||||
// are implicitly left visible, because LayerVisibilityHint's default is Visible, and we never modify it here.
|
||||
// Note that L0 deltas _can_ be covered by image layers, but we consider them 'visible' because we anticipate that
|
||||
// they will be subject to L0->L1 compaction in the near future.
|
||||
let layer_manager = self.layers.read().await;
|
||||
let layer_manager = self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GetLayerMapInfo)
|
||||
.await;
|
||||
let layer_map = layer_manager.layer_map()?;
|
||||
|
||||
let readable_points = {
|
||||
@@ -1775,7 +1787,7 @@ impl Timeline {
|
||||
};
|
||||
|
||||
let begin = tokio::time::Instant::now();
|
||||
let phase1_layers_locked = self.layers.read().await;
|
||||
let phase1_layers_locked = self.layers.read(LayerManagerLockHolder::Compaction).await;
|
||||
let now = tokio::time::Instant::now();
|
||||
stats.read_lock_acquisition_micros =
|
||||
DurationRecorder::Recorded(RecordedDuration(now - begin), now);
|
||||
@@ -1803,7 +1815,7 @@ impl Timeline {
|
||||
/// Level0 files first phase of compaction, explained in the [`Self::compact_legacy`] comment.
|
||||
async fn compact_level0_phase1<'a>(
|
||||
self: &'a Arc<Self>,
|
||||
guard: tokio::sync::RwLockReadGuard<'a, LayerManager>,
|
||||
guard: LayerManagerReadGuard<'a>,
|
||||
mut stats: CompactLevel0Phase1StatsBuilder,
|
||||
target_file_size: u64,
|
||||
force_compaction_ignore_threshold: bool,
|
||||
@@ -2029,7 +2041,7 @@ impl Timeline {
|
||||
holes
|
||||
};
|
||||
stats.read_lock_held_compute_holes_micros = stats.read_lock_held_key_sort_micros.till_now();
|
||||
drop_rlock(guard);
|
||||
drop_layer_manager_rlock(guard);
|
||||
|
||||
if self.cancel.is_cancelled() {
|
||||
return Err(CompactionError::ShuttingDown);
|
||||
@@ -2469,7 +2481,7 @@ impl Timeline {
|
||||
|
||||
// Find the top of the historical layers
|
||||
let end_lsn = {
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
|
||||
let layers = guard.layer_map()?;
|
||||
|
||||
let l0_deltas = layers.level0_deltas();
|
||||
@@ -3008,7 +3020,7 @@ impl Timeline {
|
||||
}
|
||||
split_key_ranges.sort();
|
||||
let all_layers = {
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
|
||||
let layer_map = guard.layer_map()?;
|
||||
layer_map.iter_historic_layers().collect_vec()
|
||||
};
|
||||
@@ -3112,12 +3124,12 @@ impl Timeline {
|
||||
.await?;
|
||||
let jobs_len = jobs.len();
|
||||
for (idx, job) in jobs.into_iter().enumerate() {
|
||||
info!(
|
||||
"running enhanced gc bottom-most compaction, sub-compaction {}/{}",
|
||||
idx + 1,
|
||||
jobs_len
|
||||
);
|
||||
let sub_compaction_progress = format!("{}/{}", idx + 1, jobs_len);
|
||||
self.compact_with_gc_inner(cancel, job, ctx, yield_for_l0)
|
||||
.instrument(info_span!(
|
||||
"sub_compaction",
|
||||
sub_compaction_progress = sub_compaction_progress
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
if jobs_len == 0 {
|
||||
@@ -3185,7 +3197,10 @@ impl Timeline {
|
||||
// 1. If a layer is in the selection, all layers below it are in the selection.
|
||||
// 2. Inferred from (1), for each key in the layer selection, the value can be reconstructed only with the layers in the layer selection.
|
||||
let job_desc = {
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GarbageCollection)
|
||||
.await;
|
||||
let layers = guard.layer_map()?;
|
||||
let gc_info = self.gc_info.read().unwrap();
|
||||
let mut retain_lsns_below_horizon = Vec::new();
|
||||
@@ -3956,7 +3971,10 @@ impl Timeline {
|
||||
|
||||
// First, do a sanity check to ensure the newly-created layer map does not contain overlaps.
|
||||
let all_layers = {
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::GarbageCollection)
|
||||
.await;
|
||||
let layer_map = guard.layer_map()?;
|
||||
layer_map.iter_historic_layers().collect_vec()
|
||||
};
|
||||
@@ -4020,7 +4038,10 @@ impl Timeline {
|
||||
let update_guard = self.gc_compaction_layer_update_lock.write().await;
|
||||
// Acquiring the update guard ensures current read operations end and new read operations are blocked.
|
||||
// TODO: can we use `latest_gc_cutoff` Rcu to achieve the same effect?
|
||||
let mut guard = self.layers.write().await;
|
||||
let mut guard = self
|
||||
.layers
|
||||
.write(LayerManagerLockHolder::GarbageCollection)
|
||||
.await;
|
||||
guard
|
||||
.open_mut()?
|
||||
.finish_gc_compaction(&layer_selection, &compact_to, &self.metrics);
|
||||
@@ -4088,7 +4109,11 @@ impl TimelineAdaptor {
|
||||
|
||||
pub async fn flush_updates(&mut self) -> Result<(), CompactionError> {
|
||||
let layers_to_delete = {
|
||||
let guard = self.timeline.layers.read().await;
|
||||
let guard = self
|
||||
.timeline
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::Compaction)
|
||||
.await;
|
||||
self.layers_to_delete
|
||||
.iter()
|
||||
.map(|x| guard.get_from_desc(x))
|
||||
@@ -4133,7 +4158,11 @@ impl CompactionJobExecutor for TimelineAdaptor {
|
||||
) -> anyhow::Result<Vec<OwnArc<PersistentLayerDesc>>> {
|
||||
self.flush_updates().await?;
|
||||
|
||||
let guard = self.timeline.layers.read().await;
|
||||
let guard = self
|
||||
.timeline
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::Compaction)
|
||||
.await;
|
||||
let layer_map = guard.layer_map()?;
|
||||
|
||||
let result = layer_map
|
||||
@@ -4172,7 +4201,11 @@ impl CompactionJobExecutor for TimelineAdaptor {
|
||||
// this is a lot more complex than a simple downcast...
|
||||
if layer.is_delta() {
|
||||
let l = {
|
||||
let guard = self.timeline.layers.read().await;
|
||||
let guard = self
|
||||
.timeline
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::Compaction)
|
||||
.await;
|
||||
guard.get_from_desc(layer)
|
||||
};
|
||||
let result = l.download_and_keep_resident(ctx).await?;
|
||||
|
||||
@@ -19,7 +19,7 @@ use utils::id::TimelineId;
|
||||
use utils::lsn::Lsn;
|
||||
use utils::sync::gate::GateError;
|
||||
|
||||
use super::layer_manager::LayerManager;
|
||||
use super::layer_manager::{LayerManager, LayerManagerLockHolder};
|
||||
use super::{FlushLayerError, Timeline};
|
||||
use crate::context::{DownloadBehavior, RequestContext};
|
||||
use crate::task_mgr::TaskKind;
|
||||
@@ -199,7 +199,10 @@ pub(crate) async fn generate_tombstone_image_layer(
|
||||
let image_lsn = ancestor_lsn;
|
||||
|
||||
{
|
||||
let layers = detached.layers.read().await;
|
||||
let layers = detached
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::DetachAncestor)
|
||||
.await;
|
||||
for layer in layers.all_persistent_layers() {
|
||||
if !layer.is_delta
|
||||
&& layer.lsn_range.start == image_lsn
|
||||
@@ -423,7 +426,7 @@ pub(super) async fn prepare(
|
||||
// we do not need to start from our layers, because they can only be layers that come
|
||||
// *after* ancestor_lsn
|
||||
let layers = tokio::select! {
|
||||
guard = ancestor.layers.read() => guard,
|
||||
guard = ancestor.layers.read(LayerManagerLockHolder::DetachAncestor) => guard,
|
||||
_ = detached.cancel.cancelled() => {
|
||||
return Err(ShuttingDown);
|
||||
}
|
||||
@@ -869,7 +872,12 @@ async fn remote_copy(
|
||||
|
||||
// Double check that the file is orphan (probably from an earlier attempt), then delete it
|
||||
let key = file_name.clone().into();
|
||||
if adoptee.layers.read().await.contains_key(&key) {
|
||||
if adoptee
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::DetachAncestor)
|
||||
.await
|
||||
.contains_key(&key)
|
||||
{
|
||||
// We are supposed to filter out such cases before coming to this function
|
||||
return Err(Error::Prepare(anyhow::anyhow!(
|
||||
"layer file {file_name} already present and inside layer map"
|
||||
|
||||
@@ -33,6 +33,7 @@ use crate::tenant::size::CalculateSyntheticSizeError;
|
||||
use crate::tenant::storage_layer::LayerVisibilityHint;
|
||||
use crate::tenant::tasks::{BackgroundLoopKind, BackgroundLoopSemaphorePermit, sleep_random};
|
||||
use crate::tenant::timeline::EvictionError;
|
||||
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
|
||||
use crate::tenant::{LogicalSizeCalculationCause, TenantShard};
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -208,7 +209,7 @@ impl Timeline {
|
||||
|
||||
let mut js = tokio::task::JoinSet::new();
|
||||
{
|
||||
let guard = self.layers.read().await;
|
||||
let guard = self.layers.read(LayerManagerLockHolder::Eviction).await;
|
||||
|
||||
guard
|
||||
.likely_resident_layers()
|
||||
|
||||
@@ -15,6 +15,7 @@ use super::{Timeline, TimelineDeleteProgress};
|
||||
use crate::context::RequestContext;
|
||||
use crate::controller_upcall_client::{StorageControllerUpcallApi, StorageControllerUpcallClient};
|
||||
use crate::tenant::metadata::TimelineMetadata;
|
||||
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
|
||||
|
||||
mod flow;
|
||||
mod importbucket_client;
|
||||
@@ -163,7 +164,10 @@ async fn prepare_import(
|
||||
info!("wipe the slate clean");
|
||||
{
|
||||
// TODO: do we need to hold GC lock for this?
|
||||
let mut guard = timeline.layers.write().await;
|
||||
let mut guard = timeline
|
||||
.layers
|
||||
.write(LayerManagerLockHolder::ImportPgData)
|
||||
.await;
|
||||
assert!(
|
||||
guard.layer_map()?.open_layer.is_none(),
|
||||
"while importing, there should be no in-memory layer" // this just seems like a good place to assert it
|
||||
|
||||
@@ -56,6 +56,7 @@ use crate::pgdatadir_mapping::{
|
||||
};
|
||||
use crate::task_mgr::TaskKind;
|
||||
use crate::tenant::storage_layer::{AsLayerDesc, ImageLayerWriter, Layer};
|
||||
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
|
||||
|
||||
pub async fn run(
|
||||
timeline: Arc<Timeline>,
|
||||
@@ -984,7 +985,10 @@ impl ChunkProcessingJob {
|
||||
let (desc, path) = writer.finish(ctx).await?;
|
||||
|
||||
{
|
||||
let guard = timeline.layers.read().await;
|
||||
let guard = timeline
|
||||
.layers
|
||||
.read(LayerManagerLockHolder::ImportPgData)
|
||||
.await;
|
||||
let existing_layer = guard.try_get_from_key(&desc.key());
|
||||
if let Some(layer) = existing_layer {
|
||||
if layer.metadata().generation == timeline.generation {
|
||||
@@ -1007,7 +1011,10 @@ impl ChunkProcessingJob {
|
||||
// certain that the existing layer is identical to the new one, so in that case
|
||||
// we replace the old layer with the one we just generated.
|
||||
|
||||
let mut guard = timeline.layers.write().await;
|
||||
let mut guard = timeline
|
||||
.layers
|
||||
.write(LayerManagerLockHolder::ImportPgData)
|
||||
.await;
|
||||
|
||||
let existing_layer = guard
|
||||
.try_get_from_key(&resident_layer.layer_desc().key())
|
||||
@@ -1036,7 +1043,7 @@ impl ChunkProcessingJob {
|
||||
}
|
||||
}
|
||||
|
||||
crate::tenant::timeline::drop_wlock(guard);
|
||||
crate::tenant::timeline::drop_layer_manager_wlock(guard);
|
||||
|
||||
timeline
|
||||
.remote_client
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use std::collections::HashMap;
|
||||
use std::mem::ManuallyDrop;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, bail, ensure};
|
||||
use itertools::Itertools;
|
||||
@@ -20,6 +23,155 @@ use crate::tenant::storage_layer::{
|
||||
PersistentLayerKey, ReadableLayerWeak, ResidentLayer,
|
||||
};
|
||||
|
||||
/// Warn if the lock was held for longer than this threshold.
|
||||
/// It's very generous and we should bring this value down over time.
|
||||
const LAYER_MANAGER_LOCK_WARN_THRESHOLD: Duration = Duration::from_secs(5);
|
||||
const LAYER_MANAGER_LOCK_READ_WARN_THRESHOLD: Duration = Duration::from_secs(30);
|
||||
|
||||
/// Describes the operation that is holding the layer manager lock
|
||||
#[derive(Debug, Clone, Copy, strum_macros::Display)]
|
||||
#[strum(serialize_all = "kebab_case")]
|
||||
pub(crate) enum LayerManagerLockHolder {
|
||||
GetLayerMapInfo,
|
||||
GenerateHeatmap,
|
||||
GetPage,
|
||||
Init,
|
||||
LoadLayerMap,
|
||||
GetLayerForWrite,
|
||||
TryFreezeLayer,
|
||||
FlushFrozenLayer,
|
||||
FlushLoop,
|
||||
Compaction,
|
||||
GarbageCollection,
|
||||
Shutdown,
|
||||
ImportPgData,
|
||||
DetachAncestor,
|
||||
Eviction,
|
||||
#[cfg(test)]
|
||||
Testing,
|
||||
}
|
||||
|
||||
/// Wrapper for the layer manager that tracks the amount of time during which
|
||||
/// it was held under read or write lock
|
||||
#[derive(Default)]
|
||||
pub(crate) struct LockedLayerManager {
|
||||
locked: tokio::sync::RwLock<LayerManager>,
|
||||
}
|
||||
|
||||
pub(crate) struct LayerManagerReadGuard<'a> {
|
||||
guard: ManuallyDrop<tokio::sync::RwLockReadGuard<'a, LayerManager>>,
|
||||
acquired_at: std::time::Instant,
|
||||
holder: LayerManagerLockHolder,
|
||||
}
|
||||
|
||||
pub(crate) struct LayerManagerWriteGuard<'a> {
|
||||
guard: ManuallyDrop<tokio::sync::RwLockWriteGuard<'a, LayerManager>>,
|
||||
acquired_at: std::time::Instant,
|
||||
holder: LayerManagerLockHolder,
|
||||
}
|
||||
|
||||
impl Drop for LayerManagerReadGuard<'_> {
|
||||
fn drop(&mut self) {
|
||||
// Drop the lock first, before potentially warning if it was held for too long.
|
||||
// SAFETY: ManuallyDrop in Drop implementation
|
||||
unsafe { ManuallyDrop::drop(&mut self.guard) };
|
||||
|
||||
let held_for = self.acquired_at.elapsed();
|
||||
if held_for >= LAYER_MANAGER_LOCK_READ_WARN_THRESHOLD {
|
||||
tracing::warn!(
|
||||
holder=%self.holder,
|
||||
"Layer manager read lock held for {}s",
|
||||
held_for.as_secs_f64(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for LayerManagerWriteGuard<'_> {
|
||||
fn drop(&mut self) {
|
||||
// Drop the lock first, before potentially warning if it was held for too long.
|
||||
// SAFETY: ManuallyDrop in Drop implementation
|
||||
unsafe { ManuallyDrop::drop(&mut self.guard) };
|
||||
|
||||
let held_for = self.acquired_at.elapsed();
|
||||
if held_for >= LAYER_MANAGER_LOCK_WARN_THRESHOLD {
|
||||
tracing::warn!(
|
||||
holder=%self.holder,
|
||||
"Layer manager write lock held for {}s",
|
||||
held_for.as_secs_f64(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for LayerManagerReadGuard<'_> {
|
||||
type Target = LayerManager;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.guard.deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for LayerManagerWriteGuard<'_> {
|
||||
type Target = LayerManager;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.guard.deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for LayerManagerWriteGuard<'_> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.guard.deref_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl LockedLayerManager {
|
||||
pub(crate) async fn read(&self, holder: LayerManagerLockHolder) -> LayerManagerReadGuard {
|
||||
let guard = ManuallyDrop::new(self.locked.read().await);
|
||||
LayerManagerReadGuard {
|
||||
guard,
|
||||
acquired_at: std::time::Instant::now(),
|
||||
holder,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn try_read(
|
||||
&self,
|
||||
holder: LayerManagerLockHolder,
|
||||
) -> Result<LayerManagerReadGuard, tokio::sync::TryLockError> {
|
||||
let guard = ManuallyDrop::new(self.locked.try_read()?);
|
||||
|
||||
Ok(LayerManagerReadGuard {
|
||||
guard,
|
||||
acquired_at: std::time::Instant::now(),
|
||||
holder,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn write(&self, holder: LayerManagerLockHolder) -> LayerManagerWriteGuard {
|
||||
let guard = ManuallyDrop::new(self.locked.write().await);
|
||||
LayerManagerWriteGuard {
|
||||
guard,
|
||||
acquired_at: std::time::Instant::now(),
|
||||
holder,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn try_write(
|
||||
&self,
|
||||
holder: LayerManagerLockHolder,
|
||||
) -> Result<LayerManagerWriteGuard, tokio::sync::TryLockError> {
|
||||
let guard = ManuallyDrop::new(self.locked.try_write()?);
|
||||
|
||||
Ok(LayerManagerWriteGuard {
|
||||
guard,
|
||||
acquired_at: std::time::Instant::now(),
|
||||
holder,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides semantic APIs to manipulate the layer map.
|
||||
pub(crate) enum LayerManager {
|
||||
/// Open as in not shutdown layer manager; we still have in-memory layers and we can manipulate
|
||||
|
||||
@@ -1092,13 +1092,15 @@ communicator_prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
|
||||
MyPState->ring_last <= ring_index);
|
||||
}
|
||||
|
||||
/* internal version. Returns the ring index */
|
||||
/* Internal version. Returns the ring index of the last block (result of this function is used only
|
||||
* when nblocks==1)
|
||||
*/
|
||||
static uint64
|
||||
prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
|
||||
BlockNumber nblocks, const bits8 *mask,
|
||||
bool is_prefetch)
|
||||
{
|
||||
uint64 min_ring_index;
|
||||
uint64 last_ring_index;
|
||||
PrefetchRequest hashkey;
|
||||
#ifdef USE_ASSERT_CHECKING
|
||||
bool any_hits = false;
|
||||
@@ -1122,13 +1124,12 @@ Retry:
|
||||
MyPState->ring_unused - MyPState->ring_receive;
|
||||
MyNeonCounters->getpage_prefetches_buffered =
|
||||
MyPState->n_responses_buffered;
|
||||
last_ring_index = UINT64_MAX;
|
||||
|
||||
min_ring_index = UINT64_MAX;
|
||||
for (int i = 0; i < nblocks; i++)
|
||||
{
|
||||
PrefetchRequest *slot = NULL;
|
||||
PrfHashEntry *entry = NULL;
|
||||
uint64 ring_index;
|
||||
neon_request_lsns *lsns;
|
||||
|
||||
if (PointerIsValid(mask) && BITMAP_ISSET(mask, i))
|
||||
@@ -1152,12 +1153,12 @@ Retry:
|
||||
if (entry != NULL)
|
||||
{
|
||||
slot = entry->slot;
|
||||
ring_index = slot->my_ring_index;
|
||||
Assert(slot == GetPrfSlot(ring_index));
|
||||
last_ring_index = slot->my_ring_index;
|
||||
Assert(slot == GetPrfSlot(last_ring_index));
|
||||
|
||||
Assert(slot->status != PRFS_UNUSED);
|
||||
Assert(MyPState->ring_last <= ring_index &&
|
||||
ring_index < MyPState->ring_unused);
|
||||
Assert(MyPState->ring_last <= last_ring_index &&
|
||||
last_ring_index < MyPState->ring_unused);
|
||||
Assert(BufferTagsEqual(&slot->buftag, &hashkey.buftag));
|
||||
|
||||
/*
|
||||
@@ -1169,9 +1170,9 @@ Retry:
|
||||
if (!neon_prefetch_response_usable(lsns, slot))
|
||||
{
|
||||
/* Wait for the old request to finish and discard it */
|
||||
if (!prefetch_wait_for(ring_index))
|
||||
if (!prefetch_wait_for(last_ring_index))
|
||||
goto Retry;
|
||||
prefetch_set_unused(ring_index);
|
||||
prefetch_set_unused(last_ring_index);
|
||||
entry = NULL;
|
||||
slot = NULL;
|
||||
pgBufferUsage.prefetch.expired += 1;
|
||||
@@ -1188,13 +1189,12 @@ Retry:
|
||||
*/
|
||||
if (slot->status == PRFS_TAG_REMAINS)
|
||||
{
|
||||
prefetch_set_unused(ring_index);
|
||||
prefetch_set_unused(last_ring_index);
|
||||
entry = NULL;
|
||||
slot = NULL;
|
||||
}
|
||||
else
|
||||
{
|
||||
min_ring_index = Min(min_ring_index, ring_index);
|
||||
/* The buffered request is good enough, return that index */
|
||||
if (is_prefetch)
|
||||
pgBufferUsage.prefetch.duplicates++;
|
||||
@@ -1283,12 +1283,12 @@ Retry:
|
||||
* The next buffer pointed to by `ring_unused` is now definitely empty, so
|
||||
* we can insert the new request to it.
|
||||
*/
|
||||
ring_index = MyPState->ring_unused;
|
||||
last_ring_index = MyPState->ring_unused;
|
||||
|
||||
Assert(MyPState->ring_last <= ring_index &&
|
||||
ring_index <= MyPState->ring_unused);
|
||||
Assert(MyPState->ring_last <= last_ring_index &&
|
||||
last_ring_index <= MyPState->ring_unused);
|
||||
|
||||
slot = GetPrfSlotNoCheck(ring_index);
|
||||
slot = GetPrfSlotNoCheck(last_ring_index);
|
||||
|
||||
Assert(slot->status == PRFS_UNUSED);
|
||||
|
||||
@@ -1298,11 +1298,9 @@ Retry:
|
||||
*/
|
||||
slot->buftag = hashkey.buftag;
|
||||
slot->shard_no = get_shard_number(&tag);
|
||||
slot->my_ring_index = ring_index;
|
||||
slot->my_ring_index = last_ring_index;
|
||||
slot->flags = 0;
|
||||
|
||||
min_ring_index = Min(min_ring_index, ring_index);
|
||||
|
||||
if (is_prefetch)
|
||||
MyNeonCounters->getpage_prefetch_requests_total++;
|
||||
else
|
||||
@@ -1315,11 +1313,12 @@ Retry:
|
||||
MyPState->ring_unused - MyPState->ring_receive;
|
||||
|
||||
Assert(any_hits);
|
||||
Assert(last_ring_index != UINT64_MAX);
|
||||
|
||||
Assert(GetPrfSlot(min_ring_index)->status == PRFS_REQUESTED ||
|
||||
GetPrfSlot(min_ring_index)->status == PRFS_RECEIVED);
|
||||
Assert(MyPState->ring_last <= min_ring_index &&
|
||||
min_ring_index < MyPState->ring_unused);
|
||||
Assert(GetPrfSlot(last_ring_index)->status == PRFS_REQUESTED ||
|
||||
GetPrfSlot(last_ring_index)->status == PRFS_RECEIVED);
|
||||
Assert(MyPState->ring_last <= last_ring_index &&
|
||||
last_ring_index < MyPState->ring_unused);
|
||||
|
||||
if (flush_every_n_requests > 0 &&
|
||||
MyPState->ring_unused - MyPState->ring_flush >= flush_every_n_requests)
|
||||
@@ -1335,7 +1334,7 @@ Retry:
|
||||
MyPState->ring_flush = MyPState->ring_unused;
|
||||
}
|
||||
|
||||
return min_ring_index;
|
||||
return last_ring_index;
|
||||
}
|
||||
|
||||
static bool
|
||||
|
||||
@@ -2,6 +2,6 @@ DROP FUNCTION IF EXISTS get_prewarm_info(out total_pages integer, out prewarmed_
|
||||
|
||||
DROP FUNCTION IF EXISTS get_local_cache_state(max_chunks integer);
|
||||
|
||||
DROP FUNCTION IF EXISTS prewarm_local_cache(state bytea, n_workers integer default 1);
|
||||
DROP FUNCTION IF EXISTS prewarm_local_cache(state bytea, n_workers integer);
|
||||
|
||||
|
||||
|
||||
@@ -1135,7 +1135,7 @@ VotesCollectedMset(WalProposer *wp, MemberSet *mset, Safekeeper **msk, StringInf
|
||||
wp->propTermStartLsn = sk->voteResponse.flushLsn;
|
||||
wp->donor = sk;
|
||||
}
|
||||
wp->truncateLsn = Max(wp->safekeeper[i].voteResponse.truncateLsn, wp->truncateLsn);
|
||||
wp->truncateLsn = Max(sk->voteResponse.truncateLsn, wp->truncateLsn);
|
||||
|
||||
if (n_votes > 0)
|
||||
appendStringInfoString(s, ", ");
|
||||
|
||||
10
poetry.lock
generated
10
poetry.lock
generated
@@ -3051,19 +3051,19 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.3"
|
||||
version = "2.32.4"
|
||||
description = "Python HTTP for Humans."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
|
||||
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
|
||||
{file = "requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c"},
|
||||
{file = "requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = ">=2017.4.17"
|
||||
charset-normalizer = ">=2,<4"
|
||||
charset_normalizer = ">=2,<4"
|
||||
idna = ">=2.5,<4"
|
||||
urllib3 = ">=1.21.1,<3"
|
||||
|
||||
@@ -3846,4 +3846,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "7ab1e7b975af34b3271b7c6018fa22a261d3f73c7c0a0403b6b2bb86b5fbd36e"
|
||||
content-hash = "bd93313f110110aa53b24a3ed47ba2d7f60e2c658a79cdff7320fed1bb1b57b5"
|
||||
|
||||
@@ -18,11 +18,6 @@ pub(super) async fn authenticate(
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
let scram_keys = match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => {
|
||||
debug!("auth endpoint chooses MD5");
|
||||
return Err(auth::AuthError::MalformedPassword("MD5 not supported"));
|
||||
}
|
||||
AuthSecret::Scram(secret) => {
|
||||
debug!("auth endpoint chooses SCRAM");
|
||||
|
||||
|
||||
@@ -6,18 +6,17 @@ use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, info_span};
|
||||
|
||||
use super::ComputeCredentialKeys;
|
||||
use crate::auth::IpPattern;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::cache::Cached;
|
||||
use crate::compute::AuthInfo;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::cplane_proxy_v1;
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::pglb::connect_compute::ComputeConnectBackend;
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::wake_compute::WakeComputeBackend;
|
||||
use crate::stream::PqStream;
|
||||
use crate::types::RoleName;
|
||||
use crate::{auth, compute, waiters};
|
||||
@@ -98,15 +97,11 @@ impl ConsoleRedirectBackend {
|
||||
ctx: &RequestContext,
|
||||
auth_config: &'static AuthenticationConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<(
|
||||
ConsoleRedirectNodeInfo,
|
||||
ComputeUserInfo,
|
||||
Option<Vec<IpPattern>>,
|
||||
)> {
|
||||
) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> {
|
||||
authenticate(ctx, auth_config, &self.console_uri, client)
|
||||
.await
|
||||
.map(|(node_info, user_info, ip_allowlist)| {
|
||||
(ConsoleRedirectNodeInfo(node_info), user_info, ip_allowlist)
|
||||
.map(|(node_info, auth_info, user_info)| {
|
||||
(ConsoleRedirectNodeInfo(node_info), auth_info, user_info)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -114,17 +109,13 @@ impl ConsoleRedirectBackend {
|
||||
pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
|
||||
|
||||
#[async_trait]
|
||||
impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
|
||||
impl WakeComputeBackend for ConsoleRedirectNodeInfo {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
|
||||
Ok(Cached::new_uncached(self.0.clone()))
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys {
|
||||
&ComputeCredentialKeys::None
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
@@ -132,7 +123,7 @@ async fn authenticate(
|
||||
auth_config: &'static AuthenticationConfig,
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<(NodeInfo, ComputeUserInfo, Option<Vec<IpPattern>>)> {
|
||||
) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
|
||||
|
||||
// registering waiter can fail if we get unlucky with rng.
|
||||
@@ -192,10 +183,24 @@ async fn authenticate(
|
||||
|
||||
client.write_message(BeMessage::NoticeResponse("Connecting to database."));
|
||||
|
||||
// This config should be self-contained, because we won't
|
||||
// take username or dbname from client's startup message.
|
||||
let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port);
|
||||
config.dbname(&db_info.dbname).user(&db_info.user);
|
||||
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
|
||||
// while direct connections do not. Once we migrate to pg_sni_proxy
|
||||
// everywhere, we can remove this.
|
||||
let ssl_mode = if db_info.host.contains("--") {
|
||||
// we need TLS connection with SNI info to properly route it
|
||||
SslMode::Require
|
||||
} else {
|
||||
SslMode::Disable
|
||||
};
|
||||
|
||||
let conn_info = compute::ConnectInfo {
|
||||
host: db_info.host.into(),
|
||||
port: db_info.port,
|
||||
ssl_mode,
|
||||
host_addr: None,
|
||||
};
|
||||
let auth_info =
|
||||
AuthInfo::for_console_redirect(&db_info.dbname, &db_info.user, db_info.password.as_deref());
|
||||
|
||||
let user: RoleName = db_info.user.into();
|
||||
let user_info = ComputeUserInfo {
|
||||
@@ -209,26 +214,12 @@ async fn authenticate(
|
||||
ctx.set_project(db_info.aux.clone());
|
||||
info!("woken up a compute node");
|
||||
|
||||
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
|
||||
// while direct connections do not. Once we migrate to pg_sni_proxy
|
||||
// everywhere, we can remove this.
|
||||
if db_info.host.contains("--") {
|
||||
// we need TLS connection with SNI info to properly route it
|
||||
config.ssl_mode(SslMode::Require);
|
||||
} else {
|
||||
config.ssl_mode(SslMode::Disable);
|
||||
}
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password.as_ref());
|
||||
}
|
||||
|
||||
Ok((
|
||||
NodeInfo {
|
||||
config,
|
||||
conn_info,
|
||||
aux: db_info.aux,
|
||||
},
|
||||
auth_info,
|
||||
user_info,
|
||||
db_info.allowed_ips,
|
||||
))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
|
||||
use clashmap::ClashMap;
|
||||
use jose_jwk::crypto::KeyInfo;
|
||||
use reqwest::{Client, redirect};
|
||||
@@ -347,17 +349,17 @@ impl JwkCacheEntryLock {
|
||||
.split_once('.')
|
||||
.ok_or(JwtEncodingError::InvalidCompactForm)?;
|
||||
|
||||
let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)?;
|
||||
let header = BASE64_URL_SAFE_NO_PAD.decode(header)?;
|
||||
let header = serde_json::from_slice::<JwtHeader<'_>>(&header)?;
|
||||
|
||||
let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)?;
|
||||
let payloadb = BASE64_URL_SAFE_NO_PAD.decode(payload)?;
|
||||
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)?;
|
||||
|
||||
if let Some(iss) = &payload.issuer {
|
||||
ctx.set_jwt_issuer(iss.as_ref().to_owned());
|
||||
}
|
||||
|
||||
let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)?;
|
||||
let sig = BASE64_URL_SAFE_NO_PAD.decode(signature)?;
|
||||
|
||||
let kid = header.key_id.ok_or(JwtError::MissingKeyId)?;
|
||||
|
||||
@@ -796,7 +798,6 @@ mod tests {
|
||||
use std::net::SocketAddr;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use base64::URL_SAFE_NO_PAD;
|
||||
use bytes::Bytes;
|
||||
use http::Response;
|
||||
use http_body_util::Full;
|
||||
@@ -871,9 +872,8 @@ mod tests {
|
||||
key_id: Some(Cow::Owned(kid)),
|
||||
};
|
||||
|
||||
let header =
|
||||
base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
|
||||
let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
|
||||
let header = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
|
||||
let body = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&body).unwrap());
|
||||
|
||||
format!("{header}.{body}")
|
||||
}
|
||||
@@ -883,7 +883,7 @@ mod tests {
|
||||
|
||||
let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
|
||||
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
|
||||
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
|
||||
let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
|
||||
|
||||
format!("{payload}.{sig}")
|
||||
}
|
||||
@@ -893,7 +893,7 @@ mod tests {
|
||||
|
||||
let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
|
||||
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
|
||||
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
|
||||
let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
|
||||
|
||||
format!("{payload}.{sig}")
|
||||
}
|
||||
@@ -904,7 +904,7 @@ mod tests {
|
||||
|
||||
let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
|
||||
let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
|
||||
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
|
||||
let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
|
||||
|
||||
format!("{payload}.{sig}")
|
||||
}
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use postgres_client::config::SslMode;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
use super::jwt::{AuthRule, FetchAuthRules};
|
||||
use crate::auth::backend::jwt::FetchAuthRulesError;
|
||||
use crate::compute::ConnCfg;
|
||||
use crate::compute::ConnectInfo;
|
||||
use crate::compute_ctl::ComputeCtlApi;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::NodeInfo;
|
||||
@@ -29,7 +30,12 @@ impl LocalBackend {
|
||||
api: http::Endpoint::new(compute_ctl, http::new_client()),
|
||||
},
|
||||
node_info: NodeInfo {
|
||||
config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()),
|
||||
conn_info: ConnectInfo {
|
||||
host_addr: Some(postgres_addr.ip()),
|
||||
host: postgres_addr.ip().to_string().into(),
|
||||
port: postgres_addr.port(),
|
||||
ssl_mode: SslMode::Disable,
|
||||
},
|
||||
// TODO(conrad): make this better reflect compute info rather than endpoint info.
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),
|
||||
|
||||
@@ -14,20 +14,21 @@ use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
|
||||
use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
|
||||
use crate::cache::Cached;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ControlPlaneClient;
|
||||
use crate::control_plane::errors::GetAuthInfoError;
|
||||
use crate::control_plane::messages::EndpointRateLimitConfig;
|
||||
use crate::control_plane::{
|
||||
self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
|
||||
RoleAccessControl,
|
||||
};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::pglb::connect_compute::ComputeConnectBackend;
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::wake_compute::WakeComputeBackend;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::Stream;
|
||||
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
|
||||
@@ -168,8 +169,6 @@ impl ComputeUserInfo {
|
||||
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
pub(crate) enum ComputeCredentialKeys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Password(Vec<u8>),
|
||||
AuthKeys(AuthKeys),
|
||||
JwtPayload(Vec<u8>),
|
||||
None,
|
||||
@@ -232,11 +231,8 @@ async fn auth_quirks(
|
||||
config.is_vpc_acccess_proxy,
|
||||
)?;
|
||||
|
||||
let endpoint = EndpointIdInt::from(&info.endpoint);
|
||||
let rate_limit_config = None;
|
||||
if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) {
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
access_controls.connection_attempt_rate_limit(ctx, &info.endpoint, &endpoint_rate_limiter)?;
|
||||
|
||||
let role_access = api
|
||||
.get_role_access_control(ctx, &info.endpoint, &info.user)
|
||||
.await?;
|
||||
@@ -403,29 +399,23 @@ impl Backend<'_, ComputeUserInfo> {
|
||||
allowed_ips: Arc::new(vec![]),
|
||||
allowed_vpce: Arc::new(vec![]),
|
||||
flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
|
||||
impl WakeComputeBackend for Backend<'_, ComputeUserInfo> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
|
||||
match self {
|
||||
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Self::ControlPlane(api, info) => api.wake_compute(ctx, info).await,
|
||||
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys {
|
||||
match self {
|
||||
Self::ControlPlane(_, creds) => &creds.keys,
|
||||
Self::Local(_) => &ComputeCredentialKeys::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -448,6 +438,7 @@ mod tests {
|
||||
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::EndpointRateLimitConfig;
|
||||
use crate::control_plane::{
|
||||
self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
|
||||
};
|
||||
@@ -486,6 +477,7 @@ mod tests {
|
||||
allowed_ips: Arc::new(self.ips.clone()),
|
||||
allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()),
|
||||
flags: self.access_blocker_flags,
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -169,13 +169,6 @@ pub(crate) async fn validate_password_and_exchange(
|
||||
secret: AuthSecret,
|
||||
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => {
|
||||
// test only
|
||||
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
|
||||
password.to_owned(),
|
||||
)))
|
||||
}
|
||||
// perform scram authentication as both client and server to validate the keys
|
||||
AuthSecret::Scram(scram_secret) => {
|
||||
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
|
||||
|
||||
@@ -28,10 +28,9 @@ use crate::context::RequestContext;
|
||||
use crate::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::proxy::{
|
||||
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
|
||||
};
|
||||
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
|
||||
@@ -11,11 +11,13 @@ use anyhow::Context;
|
||||
use anyhow::{bail, ensure};
|
||||
use arc_swap::ArcSwapOption;
|
||||
use futures::future::Either;
|
||||
use itertools::{Itertools, Position};
|
||||
use rand::{Rng, thread_rng};
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, info, warn};
|
||||
use tracing::{Instrument, error, info, warn};
|
||||
use utils::sentry_init::init_sentry;
|
||||
use utils::{project_build_tag, project_git_version};
|
||||
|
||||
@@ -314,7 +316,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let jemalloc = match crate::jemalloc::MetricRecorder::new() {
|
||||
Ok(t) => Some(t),
|
||||
Err(e) => {
|
||||
tracing::error!(error = ?e, "could not start jemalloc metrics loop");
|
||||
error!(error = ?e, "could not start jemalloc metrics loop");
|
||||
None
|
||||
}
|
||||
};
|
||||
@@ -520,23 +522,44 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
|
||||
// This prevents immediate exit and pod restart,
|
||||
// which can cause hammering of the redis in case of connection issues.
|
||||
if let Some(mut redis_kv_client) = redis_kv_client {
|
||||
maintenance_tasks.spawn(async move {
|
||||
redis_kv_client.try_connect().await?;
|
||||
handle_cancel_messages(
|
||||
&mut redis_kv_client,
|
||||
rx_cancel,
|
||||
args.cancellation_batch_size,
|
||||
)
|
||||
.await?;
|
||||
for attempt in (0..3).with_position() {
|
||||
match redis_kv_client.try_connect().await {
|
||||
Ok(()) => {
|
||||
info!("Connected to Redis KV client");
|
||||
maintenance_tasks.spawn(async move {
|
||||
handle_cancel_messages(
|
||||
&mut redis_kv_client,
|
||||
rx_cancel,
|
||||
args.cancellation_batch_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
drop(redis_kv_client);
|
||||
drop(redis_kv_client);
|
||||
|
||||
// `handle_cancel_messages` was terminated due to the tx_cancel
|
||||
// being dropped. this is not worthy of an error, and this task can only return `Err`,
|
||||
// so let's wait forever instead.
|
||||
std::future::pending().await
|
||||
});
|
||||
// `handle_cancel_messages` was terminated due to the tx_cancel
|
||||
// being dropped. this is not worthy of an error, and this task can only return `Err`,
|
||||
// so let's wait forever instead.
|
||||
std::future::pending().await
|
||||
});
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to connect to Redis KV client: {e}");
|
||||
if matches!(attempt, Position::Last(_)) {
|
||||
bail!(
|
||||
"Failed to connect to Redis KV client after {} attempts",
|
||||
attempt.into_inner()
|
||||
);
|
||||
}
|
||||
let jitter = thread_rng().gen_range(0..100);
|
||||
tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(regional_redis_client) = regional_redis_client {
|
||||
|
||||
12
proxy/src/cache/project_info.rs
vendored
12
proxy/src/cache/project_info.rs
vendored
@@ -18,6 +18,7 @@ use crate::types::{EndpointId, RoleName};
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait ProjectInfoCache {
|
||||
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
|
||||
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
|
||||
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
|
||||
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
|
||||
@@ -100,6 +101,13 @@ pub struct ProjectInfoCacheImpl {
|
||||
|
||||
#[async_trait]
|
||||
impl ProjectInfoCache for ProjectInfoCacheImpl {
|
||||
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
|
||||
info!("invalidating endpoint access for `{endpoint_id}`");
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
|
||||
endpoint_info.invalidate_endpoint();
|
||||
}
|
||||
}
|
||||
|
||||
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
|
||||
info!("invalidating endpoint access for project `{project_id}`");
|
||||
let endpoints = self
|
||||
@@ -356,6 +364,7 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use crate::control_plane::messages::EndpointRateLimitConfig;
|
||||
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
|
||||
use crate::scram::ServerSecret;
|
||||
use crate::types::ProjectId;
|
||||
@@ -391,6 +400,7 @@ mod tests {
|
||||
allowed_ips: allowed_ips.clone(),
|
||||
allowed_vpce: Arc::new(vec![]),
|
||||
flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
},
|
||||
RoleAccessControl {
|
||||
secret: secret1.clone(),
|
||||
@@ -406,6 +416,7 @@ mod tests {
|
||||
allowed_ips: allowed_ips.clone(),
|
||||
allowed_vpce: Arc::new(vec![]),
|
||||
flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
},
|
||||
RoleAccessControl {
|
||||
secret: secret2.clone(),
|
||||
@@ -431,6 +442,7 @@ mod tests {
|
||||
allowed_ips: allowed_ips.clone(),
|
||||
allowed_vpce: Arc::new(vec![]),
|
||||
flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
},
|
||||
RoleAccessControl {
|
||||
secret: secret3.clone(),
|
||||
|
||||
@@ -24,7 +24,6 @@ use crate::pqproto::CancelKeyData;
|
||||
use crate::rate_limiter::LeakyBucketRateLimiter;
|
||||
use crate::redis::keys::KeyPrefix;
|
||||
use crate::redis::kv_ops::RedisKVClient;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
type IpSubnetKey = IpNet;
|
||||
|
||||
@@ -497,10 +496,8 @@ impl CancelClosure {
|
||||
) -> Result<(), CancelError> {
|
||||
let socket = TcpStream::connect(self.socket_addr).await?;
|
||||
|
||||
let mut mk_tls =
|
||||
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
compute_config,
|
||||
&self.hostname,
|
||||
)
|
||||
.map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;
|
||||
|
||||
@@ -1,21 +1,24 @@
|
||||
mod tls;
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use postgres_client::config::{AuthKeys, SslMode};
|
||||
use postgres_client::maybe_tls_stream::MaybeTlsStream;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{CancelToken, RawConnection};
|
||||
use postgres_client::{CancelToken, NoTls, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::parse_endpoint_param;
|
||||
use crate::cancellation::CancelClosure;
|
||||
use crate::compute::tls::TlsError;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ApiLockError;
|
||||
@@ -25,7 +28,6 @@ use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumDbConnectionsGuard};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::neon_option;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::types::Host;
|
||||
|
||||
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
@@ -38,10 +40,7 @@ pub(crate) enum ConnectionError {
|
||||
Postgres(#[from] postgres_client::Error),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
CouldNotConnect(#[from] io::Error),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
TlsError(#[from] InvalidDnsNameError),
|
||||
TlsError(#[from] TlsError),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
WakeComputeError(#[from] WakeComputeError),
|
||||
@@ -73,7 +72,7 @@ impl UserFacingError for ConnectionError {
|
||||
ConnectionError::TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
}
|
||||
_ => COULD_NOT_CONNECT.to_owned(),
|
||||
ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -85,7 +84,6 @@ impl ReportableError for ConnectionError {
|
||||
crate::error::ErrorKind::Postgres
|
||||
}
|
||||
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
|
||||
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
@@ -96,34 +94,85 @@ impl ReportableError for ConnectionError {
|
||||
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
|
||||
pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
|
||||
|
||||
/// A config for establishing a connection to compute node.
|
||||
/// Eventually, `postgres_client` will be replaced with something better.
|
||||
/// Newtype allows us to implement methods on top of it.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ConnCfg(Box<postgres_client::Config>);
|
||||
pub enum Auth {
|
||||
/// Only used during console-redirect.
|
||||
Password(Vec<u8>),
|
||||
/// Used by sql-over-http, ws, tcp.
|
||||
Scram(Box<ScramKeys>),
|
||||
}
|
||||
|
||||
/// A config for authenticating to the compute node.
|
||||
pub(crate) struct AuthInfo {
|
||||
/// None for local-proxy, as we use trust-based localhost auth.
|
||||
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
|
||||
/// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
|
||||
auth: Option<Auth>,
|
||||
server_params: StartupMessageParams,
|
||||
|
||||
/// Console redirect sets user and database, we shouldn't re-use those from the params.
|
||||
skip_db_user: bool,
|
||||
}
|
||||
|
||||
/// Contains only the data needed to establish a secure connection to compute.
|
||||
#[derive(Clone)]
|
||||
pub struct ConnectInfo {
|
||||
pub host_addr: Option<IpAddr>,
|
||||
pub host: Host,
|
||||
pub port: u16,
|
||||
pub ssl_mode: SslMode,
|
||||
}
|
||||
|
||||
/// Creation and initialization routines.
|
||||
impl ConnCfg {
|
||||
pub(crate) fn new(host: String, port: u16) -> Self {
|
||||
Self(Box::new(postgres_client::Config::new(host, port)))
|
||||
}
|
||||
|
||||
/// Reuse password or auth keys from the other config.
|
||||
pub(crate) fn reuse_password(&mut self, other: Self) {
|
||||
if let Some(password) = other.get_password() {
|
||||
self.password(password);
|
||||
}
|
||||
|
||||
if let Some(keys) = other.get_auth_keys() {
|
||||
self.auth_keys(keys);
|
||||
impl AuthInfo {
|
||||
pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
|
||||
let mut server_params = StartupMessageParams::default();
|
||||
server_params.insert("database", db);
|
||||
server_params.insert("user", user);
|
||||
Self {
|
||||
auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
|
||||
server_params,
|
||||
skip_db_user: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_host(&self) -> Host {
|
||||
match self.0.get_host() {
|
||||
postgres_client::config::Host::Tcp(s) => s.into(),
|
||||
pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
|
||||
Self {
|
||||
auth: match keys {
|
||||
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
|
||||
Some(Auth::Scram(Box::new(auth_keys)))
|
||||
}
|
||||
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
|
||||
},
|
||||
server_params: StartupMessageParams::default(),
|
||||
skip_db_user: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectInfo {
|
||||
pub fn to_postgres_client_config(&self) -> postgres_client::Config {
|
||||
let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
|
||||
config.ssl_mode(self.ssl_mode);
|
||||
if let Some(host_addr) = self.host_addr {
|
||||
config.set_host_addr(host_addr);
|
||||
}
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthInfo {
|
||||
fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
|
||||
match &self.auth {
|
||||
Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
|
||||
Some(Auth::Password(pw)) => config.password(pw),
|
||||
None => &mut config,
|
||||
};
|
||||
for (k, v) in self.server_params.iter() {
|
||||
config.set_param(k, v);
|
||||
}
|
||||
config
|
||||
}
|
||||
|
||||
/// Apply startup message params to the connection config.
|
||||
pub(crate) fn set_startup_params(
|
||||
@@ -132,27 +181,26 @@ impl ConnCfg {
|
||||
arbitrary_params: bool,
|
||||
) {
|
||||
if !arbitrary_params {
|
||||
self.set_param("client_encoding", "UTF8");
|
||||
self.server_params.insert("client_encoding", "UTF8");
|
||||
}
|
||||
for (k, v) in params.iter() {
|
||||
match k {
|
||||
// Only set `user` if it's not present in the config.
|
||||
// Console redirect auth flow takes username from the console's response.
|
||||
"user" if self.user_is_set() => {}
|
||||
"database" if self.db_is_set() => {}
|
||||
"user" | "database" if self.skip_db_user => {}
|
||||
"options" => {
|
||||
if let Some(options) = filtered_options(v) {
|
||||
self.set_param(k, &options);
|
||||
self.server_params.insert(k, &options);
|
||||
}
|
||||
}
|
||||
"user" | "database" | "application_name" | "replication" => {
|
||||
self.set_param(k, v);
|
||||
self.server_params.insert(k, v);
|
||||
}
|
||||
|
||||
// if we allow arbitrary params, then we forward them through.
|
||||
// this is a flag for a period of backwards compatibility
|
||||
k if arbitrary_params => {
|
||||
self.set_param(k, v);
|
||||
self.server_params.insert(k, v);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
@@ -160,25 +208,13 @@ impl ConnCfg {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for ConnCfg {
|
||||
type Target = postgres_client::Config;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// For now, let's make it easier to setup the config.
|
||||
impl std::ops::DerefMut for ConnCfg {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnCfg {
|
||||
/// Establish a raw TCP connection to the compute node.
|
||||
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
|
||||
use postgres_client::config::Host;
|
||||
impl ConnectInfo {
|
||||
/// Establish a raw TCP+TLS connection to the compute node.
|
||||
async fn connect_raw(
|
||||
&self,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
|
||||
let timeout = config.timeout;
|
||||
|
||||
// wrap TcpStream::connect with timeout
|
||||
let connect_with_timeout = |addrs| {
|
||||
@@ -208,34 +244,32 @@ impl ConnCfg {
|
||||
// We can't reuse connection establishing logic from `postgres_client` here,
|
||||
// because it has no means for extracting the underlying socket which we
|
||||
// require for our business.
|
||||
let port = self.0.get_port();
|
||||
let host = self.0.get_host();
|
||||
let port = self.port;
|
||||
let host = &*self.host;
|
||||
|
||||
let host = match host {
|
||||
Host::Tcp(host) => host.as_str(),
|
||||
};
|
||||
|
||||
let addrs = match self.0.get_host_addr() {
|
||||
let addrs = match self.host_addr {
|
||||
Some(addr) => vec![SocketAddr::new(addr, port)],
|
||||
None => lookup_host((host, port)).await?.collect(),
|
||||
};
|
||||
|
||||
match connect_once(&*addrs).await {
|
||||
Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
|
||||
Ok((sockaddr, stream)) => Ok((
|
||||
sockaddr,
|
||||
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
|
||||
)),
|
||||
Err(err) => {
|
||||
warn!("couldn't connect to compute node at {host}:{port}: {err}");
|
||||
Err(err)
|
||||
Err(TlsError::Connection(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
|
||||
pub(crate) struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub(crate) stream:
|
||||
postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
/// PostgreSQL connection parameters.
|
||||
pub(crate) params: std::collections::HashMap<String, String>,
|
||||
/// Query cancellation token.
|
||||
@@ -248,28 +282,23 @@ pub(crate) struct PostgresConnection {
|
||||
_guage: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl ConnCfg {
|
||||
impl ConnectInfo {
|
||||
/// Connect to a corresponding compute node.
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
aux: MetricsAuxInfo,
|
||||
auth: &AuthInfo,
|
||||
config: &ComputeConfig,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
|
||||
drop(pause);
|
||||
let mut tmp_config = auth.enrich(self.to_postgres_client_config());
|
||||
// we setup SSL early in `ConnectInfo::connect_raw`.
|
||||
tmp_config.ssl_mode(SslMode::Disable);
|
||||
|
||||
let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
host,
|
||||
)?;
|
||||
|
||||
// connect_raw() will not use TLS if sslmode is "disable"
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let connection = self.0.connect_raw(stream, tls).await?;
|
||||
let (socket_addr, stream) = self.connect_raw(config).await?;
|
||||
let connection = tmp_config.connect_raw(stream, NoTls).await?;
|
||||
drop(pause);
|
||||
|
||||
let RawConnection {
|
||||
@@ -282,13 +311,14 @@ impl ConnCfg {
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(process_id));
|
||||
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
|
||||
let stream = stream.into_inner();
|
||||
let MaybeTlsStream::Raw(stream) = stream.into_inner();
|
||||
|
||||
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
|
||||
info!(
|
||||
cold_start_info = ctx.cold_start_info().as_str(),
|
||||
"connected to compute node at {host} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
|
||||
self.0.get_ssl_mode(),
|
||||
"connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
|
||||
self.host,
|
||||
self.ssl_mode,
|
||||
ctx.get_proxy_latency(),
|
||||
ctx.get_testodrome_id().unwrap_or_default(),
|
||||
);
|
||||
@@ -299,11 +329,11 @@ impl ConnCfg {
|
||||
socket_addr,
|
||||
CancelToken {
|
||||
socket_config: None,
|
||||
ssl_mode: self.0.get_ssl_mode(),
|
||||
ssl_mode: self.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
host.to_string(),
|
||||
self.host.to_string(),
|
||||
user_info,
|
||||
);
|
||||
|
||||
63
proxy/src/compute/tls.rs
Normal file
63
proxy/src/compute/tls.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use futures::FutureExt;
|
||||
use postgres_client::config::SslMode;
|
||||
use postgres_client::maybe_tls_stream::MaybeTlsStream;
|
||||
use postgres_client::tls::{MakeTlsConnect, TlsConnect};
|
||||
use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::pqproto::request_tls;
|
||||
use crate::proxy::retry::CouldRetry;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TlsError {
|
||||
#[error(transparent)]
|
||||
Dns(#[from] InvalidDnsNameError),
|
||||
#[error(transparent)]
|
||||
Connection(#[from] std::io::Error),
|
||||
#[error("TLS required but not provided")]
|
||||
Required,
|
||||
}
|
||||
|
||||
impl CouldRetry for TlsError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
TlsError::Dns(_) => false,
|
||||
TlsError::Connection(err) => err.could_retry(),
|
||||
// perhaps compute didn't realise it supports TLS?
|
||||
TlsError::Required => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect_tls<S, T>(
|
||||
mut stream: S,
|
||||
mode: SslMode,
|
||||
tls: &T,
|
||||
host: &str,
|
||||
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
T: MakeTlsConnect<
|
||||
S,
|
||||
Error = InvalidDnsNameError,
|
||||
TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
|
||||
>,
|
||||
{
|
||||
match mode {
|
||||
SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
|
||||
SslMode::Prefer | SslMode::Require => {}
|
||||
}
|
||||
|
||||
if !request_tls(&mut stream).await? {
|
||||
if SslMode::Require == mode {
|
||||
return Err(TlsError::Required);
|
||||
}
|
||||
|
||||
return Ok(MaybeTlsStream::Raw(stream));
|
||||
}
|
||||
|
||||
Ok(MaybeTlsStream::Tls(
|
||||
tls.make_tls_connect(host)?.connect(stream).boxed().await?,
|
||||
))
|
||||
}
|
||||
@@ -11,13 +11,12 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::pglb::handshake::{HandshakeData, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::{
|
||||
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
|
||||
};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
@@ -210,20 +209,20 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let (node_info, user_info, _ip_allowlist) = match backend
|
||||
let (node_info, mut auth_info, user_info) = match backend
|
||||
.authenticate(ctx, &config.authentication_config, &mut stream)
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
auth_info.set_startup_params(¶ms, true);
|
||||
|
||||
let node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info,
|
||||
params_compat: true,
|
||||
params: ¶ms,
|
||||
auth: auth_info,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
&node_info,
|
||||
|
||||
@@ -146,6 +146,7 @@ impl NeonControlPlaneClient {
|
||||
public_access_blocked: block_public_connections,
|
||||
vpc_access_blocked: block_vpc_connections,
|
||||
},
|
||||
rate_limits: body.rate_limits,
|
||||
})
|
||||
}
|
||||
.inspect_err(|e| tracing::debug!(error = ?e))
|
||||
@@ -261,24 +262,18 @@ impl NeonControlPlaneClient {
|
||||
Some(_) => SslMode::Require,
|
||||
None => SslMode::Disable,
|
||||
};
|
||||
let host_name = match body.server_name {
|
||||
Some(host) => host,
|
||||
None => host.to_owned(),
|
||||
let host = match body.server_name {
|
||||
Some(host) => host.into(),
|
||||
None => host.into(),
|
||||
};
|
||||
|
||||
// Don't set anything but host and port! This config will be cached.
|
||||
// We'll set username and such later using the startup message.
|
||||
// TODO: add more type safety (in progress).
|
||||
let mut config = compute::ConnCfg::new(host_name, port);
|
||||
|
||||
if let Some(addr) = host_addr {
|
||||
config.set_host_addr(addr);
|
||||
}
|
||||
|
||||
config.ssl_mode(ssl_mode);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
conn_info: compute::ConnectInfo {
|
||||
host_addr,
|
||||
host,
|
||||
port,
|
||||
ssl_mode,
|
||||
},
|
||||
aux: body.aux,
|
||||
};
|
||||
|
||||
@@ -318,6 +313,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
rate_limits: auth_info.rate_limits,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
@@ -363,6 +359,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
rate_limits: auth_info.rate_limits,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::TryFutureExt;
|
||||
use postgres_client::config::SslMode;
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::Client;
|
||||
use tracing::{Instrument, error, info, info_span, warn};
|
||||
@@ -14,19 +15,20 @@ use crate::auth::IpPattern;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::cache::Cached;
|
||||
use crate::compute::ConnectInfo;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::errors::{
|
||||
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
|
||||
};
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo};
|
||||
use crate::control_plane::{
|
||||
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
|
||||
RoleAccessControl,
|
||||
};
|
||||
use crate::intern::RoleNameInt;
|
||||
use crate::scram;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
|
||||
use crate::url::ApiUrl;
|
||||
use crate::{compute, scram};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum MockApiError {
|
||||
@@ -87,8 +89,7 @@ impl MockControlPlane {
|
||||
.await?
|
||||
{
|
||||
info!("got a secret: {entry}"); // safe since it's not a prod scenario
|
||||
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
|
||||
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
|
||||
scram::ServerSecret::parse(&entry).map(AuthSecret::Scram)
|
||||
} else {
|
||||
warn!("user '{role}' does not exist");
|
||||
None
|
||||
@@ -129,6 +130,7 @@ impl MockControlPlane {
|
||||
project_id: None,
|
||||
account_id: None,
|
||||
access_blocker_flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -170,25 +172,23 @@ impl MockControlPlane {
|
||||
|
||||
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||
let port = self.endpoint.port().unwrap_or(5432);
|
||||
let mut config = match self.endpoint.host_str() {
|
||||
None => {
|
||||
let mut config = compute::ConnCfg::new("localhost".to_string(), port);
|
||||
config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST));
|
||||
config
|
||||
}
|
||||
Some(host) => {
|
||||
let mut config = compute::ConnCfg::new(host.to_string(), port);
|
||||
if let Ok(addr) = IpAddr::from_str(host) {
|
||||
config.set_host_addr(addr);
|
||||
}
|
||||
config
|
||||
}
|
||||
let conn_info = match self.endpoint.host_str() {
|
||||
None => ConnectInfo {
|
||||
host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
|
||||
host: "localhost".into(),
|
||||
port,
|
||||
ssl_mode: SslMode::Disable,
|
||||
},
|
||||
Some(host) => ConnectInfo {
|
||||
host_addr: IpAddr::from_str(host).ok(),
|
||||
host: host.into(),
|
||||
port,
|
||||
ssl_mode: SslMode::Disable,
|
||||
},
|
||||
};
|
||||
|
||||
config.ssl_mode(postgres_client::config::SslMode::Disable);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
conn_info,
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
@@ -234,6 +234,7 @@ impl super::ControlPlaneApi for MockControlPlane {
|
||||
allowed_ips: Arc::new(info.allowed_ips),
|
||||
allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
|
||||
flags: info.access_blocker_flags,
|
||||
rate_limits: info.rate_limits,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -266,12 +267,3 @@ impl super::ControlPlaneApi for MockControlPlane {
|
||||
self.do_wake_compute().map_ok(Cached::new_uncached).await
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_md5(input: &str) -> Option<[u8; 16]> {
|
||||
let text = input.strip_prefix("md5")?;
|
||||
|
||||
let mut bytes = [0u8; 16];
|
||||
hex::decode_to_slice(text, &mut bytes).ok()?;
|
||||
|
||||
Some(bytes)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use clashmap::ClashMap;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::{EndpointAccessControl, RoleAccessControl};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
|
||||
use crate::cache::endpoints::EndpointsCache;
|
||||
@@ -22,8 +23,6 @@ use crate::metrics::ApiLockMetrics;
|
||||
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
|
||||
use crate::types::EndpointId;
|
||||
|
||||
use super::{EndpointAccessControl, RoleAccessControl};
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone)]
|
||||
pub enum ControlPlaneClient {
|
||||
|
||||
@@ -227,12 +227,35 @@ pub(crate) struct UserFacingMessage {
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct GetEndpointAccessControl {
|
||||
pub(crate) role_secret: Box<str>,
|
||||
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
|
||||
|
||||
pub(crate) project_id: Option<ProjectIdInt>,
|
||||
pub(crate) account_id: Option<AccountIdInt>,
|
||||
|
||||
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
|
||||
pub(crate) block_public_connections: Option<bool>,
|
||||
pub(crate) block_vpc_connections: Option<bool>,
|
||||
|
||||
#[serde(default)]
|
||||
pub(crate) rate_limits: EndpointRateLimitConfig,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize, Default)]
|
||||
pub struct EndpointRateLimitConfig {
|
||||
pub connection_attempts: ConnectionAttemptsLimit,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize, Default)]
|
||||
pub struct ConnectionAttemptsLimit {
|
||||
pub tcp: Option<LeakyBucketSetting>,
|
||||
pub ws: Option<LeakyBucketSetting>,
|
||||
pub http: Option<LeakyBucketSetting>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize)]
|
||||
pub struct LeakyBucketSetting {
|
||||
pub rps: f64,
|
||||
pub burst: f64,
|
||||
}
|
||||
|
||||
/// Response which holds compute node's `host:port` pair.
|
||||
|
||||
@@ -11,15 +11,18 @@ pub(crate) mod errors;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use messages::EndpointRateLimitConfig;
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
|
||||
use crate::cache::{Cached, TimedLru};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
|
||||
use crate::intern::{AccountIdInt, ProjectIdInt};
|
||||
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt};
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig};
|
||||
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
|
||||
use crate::{compute, scram};
|
||||
|
||||
@@ -39,10 +42,6 @@ pub mod mgmt;
|
||||
/// Auth secret which is managed by the cloud.
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub(crate) enum AuthSecret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
@@ -60,16 +59,14 @@ pub(crate) struct AuthInfo {
|
||||
pub(crate) account_id: Option<AccountIdInt>,
|
||||
/// Are public connections or VPC connections blocked?
|
||||
pub(crate) access_blocker_flags: AccessBlockerFlags,
|
||||
/// The rate limits for this endpoint.
|
||||
pub(crate) rate_limits: EndpointRateLimitConfig,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
/// This is what we get after auth succeeded, but not before!
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct NodeInfo {
|
||||
/// Compute node connection params.
|
||||
/// It's sad that we have to clone this, but this will improve
|
||||
/// once we migrate to a bespoke connection logic.
|
||||
pub(crate) config: compute::ConnCfg,
|
||||
pub(crate) conn_info: compute::ConnectInfo,
|
||||
|
||||
/// Labels for proxy's metrics.
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
@@ -79,26 +76,14 @@ impl NodeInfo {
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
auth: &compute::AuthInfo,
|
||||
config: &ComputeConfig,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.config
|
||||
.connect(ctx, self.aux.clone(), config, user_info)
|
||||
self.conn_info
|
||||
.connect(ctx, self.aux.clone(), auth, config, user_info)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) fn reuse_settings(&mut self, other: Self) {
|
||||
self.config.reuse_password(other.config);
|
||||
}
|
||||
|
||||
pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
|
||||
match keys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default)]
|
||||
@@ -121,6 +106,8 @@ pub struct EndpointAccessControl {
|
||||
pub allowed_ips: Arc<Vec<IpPattern>>,
|
||||
pub allowed_vpce: Arc<Vec<String>>,
|
||||
pub flags: AccessBlockerFlags,
|
||||
|
||||
pub rate_limits: EndpointRateLimitConfig,
|
||||
}
|
||||
|
||||
impl EndpointAccessControl {
|
||||
@@ -159,6 +146,36 @@ impl EndpointAccessControl {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn connection_attempt_rate_limit(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
endpoint: &EndpointId,
|
||||
rate_limiter: &EndpointRateLimiter,
|
||||
) -> Result<(), AuthError> {
|
||||
let endpoint = EndpointIdInt::from(endpoint);
|
||||
|
||||
let limits = &self.rate_limits.connection_attempts;
|
||||
let config = match ctx.protocol() {
|
||||
crate::metrics::Protocol::Http => limits.http,
|
||||
crate::metrics::Protocol::Ws => limits.ws,
|
||||
crate::metrics::Protocol::Tcp => limits.tcp,
|
||||
crate::metrics::Protocol::SniRouter => return Ok(()),
|
||||
};
|
||||
let config = config.and_then(|config| {
|
||||
if config.rps <= 0.0 || config.burst <= 0.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(LeakyBucketConfig::new(config.rps, config.burst))
|
||||
});
|
||||
|
||||
if !rate_limiter.check(endpoint, config, 1) {
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
|
||||
@@ -106,4 +106,5 @@ mod tls;
|
||||
mod types;
|
||||
mod url;
|
||||
mod usage_metrics;
|
||||
mod util;
|
||||
mod waiters;
|
||||
|
||||
@@ -610,11 +610,11 @@ pub enum RedisEventsCount {
|
||||
BranchCreated,
|
||||
ProjectCreated,
|
||||
CancelSession,
|
||||
PasswordUpdate,
|
||||
AllowedIpsUpdate,
|
||||
AllowedVpcEndpointIdsUpdateForProjects,
|
||||
AllowedVpcEndpointIdsUpdateForAllProjectsInOrg,
|
||||
BlockPublicOrVpcAccessUpdate,
|
||||
InvalidateRole,
|
||||
InvalidateEndpoint,
|
||||
InvalidateProject,
|
||||
InvalidateProjects,
|
||||
InvalidateOrg,
|
||||
}
|
||||
|
||||
pub struct ThreadPoolWorkers(usize);
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
pub mod connect_compute;
|
||||
pub mod copy_bidirectional;
|
||||
pub mod handshake;
|
||||
pub mod inprocess;
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::io::{self, Cursor};
|
||||
use bytes::{Buf, BufMut};
|
||||
use itertools::Itertools;
|
||||
use rand::distributions::{Distribution, Standard};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
|
||||
|
||||
pub type ErrorCode = [u8; 5];
|
||||
@@ -53,6 +53,28 @@ impl fmt::Debug for ProtocolVersion {
|
||||
}
|
||||
}
|
||||
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
|
||||
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
|
||||
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
|
||||
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
|
||||
|
||||
/// This first reads the startup message header, is 8 bytes.
|
||||
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
|
||||
///
|
||||
/// The length value is inclusive of the header. For example,
|
||||
/// an empty message will always have length 8.
|
||||
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
|
||||
#[repr(C)]
|
||||
struct StartupHeader {
|
||||
len: big_endian::U32,
|
||||
version: ProtocolVersion,
|
||||
}
|
||||
|
||||
/// read the type from the stream using zerocopy.
|
||||
///
|
||||
/// not cancel safe.
|
||||
@@ -66,32 +88,38 @@ macro_rules! read {
|
||||
}};
|
||||
}
|
||||
|
||||
/// Returns true if TLS is supported.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn request_tls<S>(stream: &mut S) -> io::Result<bool>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let payload = StartupHeader {
|
||||
len: 8.into(),
|
||||
version: NEGOTIATE_SSL_CODE,
|
||||
};
|
||||
stream.write_all(payload.as_bytes()).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
// we expect back either `S` or `N` as a single byte.
|
||||
let mut res = *b"0";
|
||||
stream.read_exact(&mut res).await?;
|
||||
|
||||
debug_assert!(
|
||||
res == *b"S" || res == *b"N",
|
||||
"unexpected SSL negotiation response: {}",
|
||||
char::from(res[0]),
|
||||
);
|
||||
|
||||
// S for SSL.
|
||||
Ok(res == *b"S")
|
||||
}
|
||||
|
||||
pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
|
||||
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
|
||||
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
|
||||
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
|
||||
|
||||
/// This first reads the startup message header, is 8 bytes.
|
||||
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
|
||||
///
|
||||
/// The length value is inclusive of the header. For example,
|
||||
/// an empty message will always have length 8.
|
||||
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
|
||||
#[repr(C)]
|
||||
struct StartupHeader {
|
||||
len: big_endian::U32,
|
||||
version: ProtocolVersion,
|
||||
}
|
||||
|
||||
let header = read!(stream => StartupHeader);
|
||||
|
||||
// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
|
||||
@@ -564,9 +592,8 @@ mod tests {
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use zerocopy::IntoBytes;
|
||||
|
||||
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
|
||||
|
||||
use super::ProtocolVersion;
|
||||
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_large_startup() {
|
||||
|
||||
@@ -2,26 +2,25 @@ use async_trait::async_trait;
|
||||
use tokio::time;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::compute::{self, COULD_NOT_CONNECT, PostgresConnection};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
|
||||
use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::control_plane::{self, NodeInfo};
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
|
||||
};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
|
||||
use crate::proxy::wake_compute::wake_compute;
|
||||
use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute};
|
||||
use crate::types::Host;
|
||||
|
||||
/// If we couldn't connect, a cached connection info might be to blame
|
||||
/// (e.g. the compute node's address might've changed at the wrong time).
|
||||
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||
#[tracing::instrument(name = "invalidate_cache", skip_all)]
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo {
|
||||
let is_cached = node_info.cached();
|
||||
if is_cached {
|
||||
@@ -48,34 +47,17 @@ pub(crate) trait ConnectMechanism {
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError>;
|
||||
|
||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait ComputeConnectBackend {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
|
||||
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys;
|
||||
}
|
||||
|
||||
pub(crate) struct TcpMechanism<'a> {
|
||||
pub(crate) params_compat: bool,
|
||||
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
pub(crate) params: &'a StartupMessageParams,
|
||||
|
||||
pub(crate) struct TcpMechanism {
|
||||
pub(crate) auth: AuthInfo,
|
||||
/// connect_to_compute concurrency lock
|
||||
pub(crate) locks: &'static ApiLocks<Host>,
|
||||
|
||||
pub(crate) user_info: ComputeUserInfo,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for TcpMechanism<'_> {
|
||||
impl ConnectMechanism for TcpMechanism {
|
||||
type Connection = PostgresConnection;
|
||||
type ConnectError = compute::ConnectionError;
|
||||
type Error = compute::ConnectionError;
|
||||
@@ -90,19 +72,18 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
let host = node_info.config.get_host();
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
permit.release_result(node_info.connect(ctx, config, self.user_info.clone()).await)
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||
config.set_startup_params(self.params, self.params_compat);
|
||||
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
|
||||
permit.release_result(
|
||||
node_info
|
||||
.connect(ctx, &self.auth, config, self.user_info.clone())
|
||||
.await,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to connect to the compute node, retrying if necessary.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
||||
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBackend>(
|
||||
ctx: &RequestContext,
|
||||
mechanism: &M,
|
||||
user_info: &B,
|
||||
@@ -114,12 +95,9 @@ where
|
||||
M::Error: From<WakeComputeError>,
|
||||
{
|
||||
let mut num_retries = 0;
|
||||
let mut node_info =
|
||||
let node_info =
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
|
||||
node_info.set_keys(user_info.get_keys());
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
// try once
|
||||
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
|
||||
Ok(res) => {
|
||||
@@ -155,14 +133,9 @@ where
|
||||
} else {
|
||||
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||
debug!("compute node's state has likely changed; requesting a wake-up");
|
||||
let old_node_info = invalidate_cache(node_info);
|
||||
invalidate_cache(node_info);
|
||||
// TODO: increment num_retries?
|
||||
let mut node_info =
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
node_info.reuse_settings(old_node_info);
|
||||
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
node_info
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
|
||||
};
|
||||
|
||||
// now that we have a new node, try connect to it repeatedly.
|
||||
@@ -1,8 +1,10 @@
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub(crate) mod connect_compute;
|
||||
pub(crate) mod retry;
|
||||
pub(crate) mod wake_compute;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::FutureExt;
|
||||
@@ -21,15 +23,16 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::util::run_until_cancelled;
|
||||
use crate::{auth, compute};
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
@@ -46,21 +49,6 @@ impl ReportableError for TlsRequired {
|
||||
|
||||
impl UserFacingError for TlsRequired {}
|
||||
|
||||
pub async fn run_until_cancelled<F: std::future::Future>(
|
||||
f: F,
|
||||
cancellation_token: &CancellationToken,
|
||||
) -> Option<F::Output> {
|
||||
match futures::future::select(
|
||||
std::pin::pin!(f),
|
||||
std::pin::pin!(cancellation_token.cancelled()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
futures::future::Either::Left((f, _)) => Some(f),
|
||||
futures::future::Either::Right(((), _)) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
@@ -358,24 +346,22 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
}
|
||||
};
|
||||
|
||||
let compute_user_info = match &user_info {
|
||||
auth::Backend::ControlPlane(_, info) => &info.info,
|
||||
let (cplane, creds) = match user_info {
|
||||
auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
|
||||
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
|
||||
};
|
||||
let params_compat = compute_user_info
|
||||
.options
|
||||
.get(NeonOptions::PARAMS_COMPAT)
|
||||
.is_some();
|
||||
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
|
||||
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
|
||||
auth_info.set_startup_params(¶ms, params_compat);
|
||||
|
||||
let res = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info: compute_user_info.clone(),
|
||||
params_compat,
|
||||
params: ¶ms,
|
||||
user_info: creds.info.clone(),
|
||||
auth: auth_info,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
&user_info,
|
||||
&auth::Backend::ControlPlane(cplane, creds.info),
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
|
||||
@@ -100,9 +100,9 @@ impl CouldRetry for compute::ConnectionError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
compute::ConnectionError::Postgres(err) => err.could_retry(),
|
||||
compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
|
||||
compute::ConnectionError::TlsError(err) => err.could_retry(),
|
||||
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
|
||||
_ => false,
|
||||
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,17 +19,14 @@ use tracing_test::traced_test;
|
||||
|
||||
use super::retry::CouldRetry;
|
||||
use super::*;
|
||||
use crate::auth::backend::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
|
||||
};
|
||||
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
|
||||
use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::pglb::connect_compute::ConnectMechanism;
|
||||
use crate::proxy::connect_compute::ConnectMechanism;
|
||||
use crate::tls::client_config::compute_client_config_with_certs;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::tls::server_config::CertResolver;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId};
|
||||
use crate::{sasl, scram};
|
||||
@@ -72,13 +69,14 @@ struct ClientConfig<'a> {
|
||||
hostname: &'a str,
|
||||
}
|
||||
|
||||
type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
|
||||
type TlsConnect<S> = <ComputeConfig as MakeTlsConnect<S>>::TlsConnect;
|
||||
|
||||
impl ClientConfig<'_> {
|
||||
fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
|
||||
let mut mk = MakeRustlsConnect::new(self.config);
|
||||
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
|
||||
Ok(tls)
|
||||
Ok(crate::tls::postgres_rustls::make_tls_connect(
|
||||
&self.config,
|
||||
self.hostname,
|
||||
)?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -497,8 +495,6 @@ impl ConnectMechanism for TestConnectMechanism {
|
||||
x => panic!("expecting action {x:?}, connect is called instead"),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
impl TestControlPlaneClient for TestConnectMechanism {
|
||||
@@ -557,7 +553,12 @@ impl TestControlPlaneClient for TestConnectMechanism {
|
||||
|
||||
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||
let node = NodeInfo {
|
||||
config: compute::ConnCfg::new("test".to_owned(), 5432),
|
||||
conn_info: compute::ConnectInfo {
|
||||
host: "test".into(),
|
||||
port: 5432,
|
||||
ssl_mode: SslMode::Disable,
|
||||
host_addr: None,
|
||||
},
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
@@ -572,16 +573,13 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
|
||||
|
||||
fn helper_create_connect_info(
|
||||
mechanism: &TestConnectMechanism,
|
||||
) -> auth::Backend<'static, ComputeCredentials> {
|
||||
) -> auth::Backend<'static, ComputeUserInfo> {
|
||||
auth::Backend::ControlPlane(
|
||||
MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
|
||||
ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
endpoint: "endpoint".into(),
|
||||
user: "user".into(),
|
||||
options: NeonOptions::parse_options_raw(""),
|
||||
},
|
||||
keys: ComputeCredentialKeys::Password("password".into()),
|
||||
ComputeUserInfo {
|
||||
endpoint: "endpoint".into(),
|
||||
user: "user".into(),
|
||||
options: NeonOptions::parse_options_raw(""),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use async_trait::async_trait;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::config::RetryConfig;
|
||||
@@ -8,7 +9,6 @@ use crate::error::ReportableError;
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
|
||||
};
|
||||
use crate::pglb::connect_compute::ComputeConnectBackend;
|
||||
use crate::proxy::retry::{retry_after, should_retry};
|
||||
|
||||
// Use macro to retain original callsite.
|
||||
@@ -23,7 +23,12 @@ macro_rules! log_wake_compute_error {
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
|
||||
#[async_trait]
|
||||
pub(crate) trait WakeComputeBackend {
|
||||
async fn wake_compute(&self, ctx: &RequestContext) -> Result<CachedNodeInfo, WakeComputeError>;
|
||||
}
|
||||
|
||||
pub(crate) async fn wake_compute<B: WakeComputeBackend>(
|
||||
num_retries: &mut u32,
|
||||
ctx: &RequestContext,
|
||||
api: &B,
|
||||
|
||||
@@ -69,9 +69,8 @@ pub struct LeakyBucketConfig {
|
||||
pub max: f64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl LeakyBucketConfig {
|
||||
pub(crate) fn new(rps: f64, max: f64) -> Self {
|
||||
pub fn new(rps: f64, max: f64) -> Self {
|
||||
assert!(rps > 0.0, "rps must be positive");
|
||||
assert!(max > 0.0, "max must be positive");
|
||||
Self { rps, max }
|
||||
|
||||
@@ -12,11 +12,10 @@ use rand::{Rng, SeedableRng};
|
||||
use tokio::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use super::LeakyBucketConfig;
|
||||
use crate::ext::LockExt;
|
||||
use crate::intern::EndpointIdInt;
|
||||
|
||||
use super::LeakyBucketConfig;
|
||||
|
||||
pub struct GlobalRateLimiter {
|
||||
data: Vec<RateBucket>,
|
||||
info: Vec<RateBucketInfo>,
|
||||
|
||||
@@ -3,12 +3,12 @@ use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
use redis::aio::PubSub;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::Deserialize;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::cache::project_info::ProjectInfoCache;
|
||||
use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
|
||||
|
||||
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||
@@ -27,42 +27,37 @@ struct NotificationHeader<'a> {
|
||||
topic: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[serde(tag = "topic", content = "data")]
|
||||
pub(crate) enum Notification {
|
||||
enum Notification {
|
||||
#[serde(
|
||||
rename = "/allowed_ips_updated",
|
||||
rename = "/account_settings_update",
|
||||
alias = "/allowed_vpc_endpoints_updated_for_org",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
AllowedIpsUpdate {
|
||||
allowed_ips_update: AllowedIpsUpdate,
|
||||
},
|
||||
AccountSettingsUpdate(InvalidateAccount),
|
||||
|
||||
#[serde(
|
||||
rename = "/block_public_or_vpc_access_updated",
|
||||
rename = "/endpoint_settings_update",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
BlockPublicOrVpcAccessUpdated {
|
||||
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
|
||||
},
|
||||
EndpointSettingsUpdate(InvalidateEndpoint),
|
||||
|
||||
#[serde(
|
||||
rename = "/allowed_vpc_endpoints_updated_for_org",
|
||||
rename = "/project_settings_update",
|
||||
alias = "/allowed_ips_updated",
|
||||
alias = "/block_public_or_vpc_access_updated",
|
||||
alias = "/allowed_vpc_endpoints_updated_for_projects",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
AllowedVpcEndpointsUpdatedForOrg {
|
||||
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
|
||||
},
|
||||
ProjectSettingsUpdate(InvalidateProject),
|
||||
|
||||
#[serde(
|
||||
rename = "/allowed_vpc_endpoints_updated_for_projects",
|
||||
rename = "/role_setting_update",
|
||||
alias = "/password_updated",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
AllowedVpcEndpointsUpdatedForProjects {
|
||||
allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
|
||||
},
|
||||
#[serde(
|
||||
rename = "/password_updated",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
PasswordUpdate { password_update: PasswordUpdate },
|
||||
RoleSettingUpdate(InvalidateRole),
|
||||
|
||||
#[serde(
|
||||
other,
|
||||
@@ -72,28 +67,56 @@ pub(crate) enum Notification {
|
||||
UnknownTopic,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct AllowedIpsUpdate {
|
||||
project_id: ProjectIdInt,
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum InvalidateEndpoint {
|
||||
EndpointId(EndpointIdInt),
|
||||
EndpointIds(Vec<EndpointIdInt>),
|
||||
}
|
||||
impl std::ops::Deref for InvalidateEndpoint {
|
||||
type Target = [EndpointIdInt];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
Self::EndpointId(id) => std::slice::from_ref(id),
|
||||
Self::EndpointIds(ids) => ids,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct BlockPublicOrVpcAccessUpdated {
|
||||
project_id: ProjectIdInt,
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum InvalidateProject {
|
||||
ProjectId(ProjectIdInt),
|
||||
ProjectIds(Vec<ProjectIdInt>),
|
||||
}
|
||||
impl std::ops::Deref for InvalidateProject {
|
||||
type Target = [ProjectIdInt];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
Self::ProjectId(id) => std::slice::from_ref(id),
|
||||
Self::ProjectIds(ids) => ids,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
|
||||
account_id: AccountIdInt,
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum InvalidateAccount {
|
||||
AccountId(AccountIdInt),
|
||||
AccountIds(Vec<AccountIdInt>),
|
||||
}
|
||||
impl std::ops::Deref for InvalidateAccount {
|
||||
type Target = [AccountIdInt];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
Self::AccountId(id) => std::slice::from_ref(id),
|
||||
Self::AccountIds(ids) => ids,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
|
||||
project_ids: Vec<ProjectIdInt>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct PasswordUpdate {
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
struct InvalidateRole {
|
||||
project_id: ProjectIdInt,
|
||||
role_name: RoleNameInt,
|
||||
}
|
||||
@@ -177,41 +200,29 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
|
||||
tracing::debug!(?msg, "received a message");
|
||||
match msg {
|
||||
Notification::AllowedIpsUpdate { .. }
|
||||
| Notification::PasswordUpdate { .. }
|
||||
| Notification::BlockPublicOrVpcAccessUpdated { .. }
|
||||
| Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
|
||||
| Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
|
||||
Notification::RoleSettingUpdate { .. }
|
||||
| Notification::EndpointSettingsUpdate { .. }
|
||||
| Notification::ProjectSettingsUpdate { .. }
|
||||
| Notification::AccountSettingsUpdate { .. } => {
|
||||
invalidate_cache(self.cache.clone(), msg.clone());
|
||||
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::AllowedIpsUpdate);
|
||||
} else if matches!(msg, Notification::PasswordUpdate { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::PasswordUpdate);
|
||||
} else if matches!(
|
||||
msg,
|
||||
Notification::AllowedVpcEndpointsUpdatedForProjects { .. }
|
||||
) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects);
|
||||
} else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg);
|
||||
} else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate);
|
||||
|
||||
let m = &Metrics::get().proxy.redis_events_count;
|
||||
match msg {
|
||||
Notification::RoleSettingUpdate { .. } => {
|
||||
m.inc(RedisEventsCount::InvalidateRole);
|
||||
}
|
||||
Notification::EndpointSettingsUpdate { .. } => {
|
||||
m.inc(RedisEventsCount::InvalidateEndpoint);
|
||||
}
|
||||
Notification::ProjectSettingsUpdate { .. } => {
|
||||
m.inc(RedisEventsCount::InvalidateProject);
|
||||
}
|
||||
Notification::AccountSettingsUpdate { .. } => {
|
||||
m.inc(RedisEventsCount::InvalidateOrg);
|
||||
}
|
||||
Notification::UnknownTopic => {}
|
||||
}
|
||||
|
||||
// TODO: add additional metrics for the other event types.
|
||||
|
||||
// It might happen that the invalid entry is on the way to be cached.
|
||||
@@ -233,30 +244,23 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
|
||||
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||
match msg {
|
||||
Notification::AllowedIpsUpdate {
|
||||
allowed_ips_update: AllowedIpsUpdate { project_id },
|
||||
}
|
||||
| Notification::BlockPublicOrVpcAccessUpdated {
|
||||
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id },
|
||||
} => cache.invalidate_endpoint_access_for_project(project_id),
|
||||
Notification::AllowedVpcEndpointsUpdatedForOrg {
|
||||
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id },
|
||||
} => cache.invalidate_endpoint_access_for_org(account_id),
|
||||
Notification::AllowedVpcEndpointsUpdatedForProjects {
|
||||
allowed_vpc_endpoints_updated_for_projects:
|
||||
AllowedVpcEndpointsUpdatedForProjects { project_ids },
|
||||
} => {
|
||||
for project in project_ids {
|
||||
cache.invalidate_endpoint_access_for_project(project);
|
||||
}
|
||||
}
|
||||
Notification::PasswordUpdate {
|
||||
password_update:
|
||||
PasswordUpdate {
|
||||
project_id,
|
||||
role_name,
|
||||
},
|
||||
} => cache.invalidate_role_secret_for_project(project_id, role_name),
|
||||
Notification::EndpointSettingsUpdate(ids) => ids
|
||||
.iter()
|
||||
.for_each(|&id| cache.invalidate_endpoint_access(id)),
|
||||
|
||||
Notification::AccountSettingsUpdate(ids) => ids
|
||||
.iter()
|
||||
.for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
|
||||
|
||||
Notification::ProjectSettingsUpdate(ids) => ids
|
||||
.iter()
|
||||
.for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
|
||||
|
||||
Notification::RoleSettingUpdate(InvalidateRole {
|
||||
project_id,
|
||||
role_name,
|
||||
}) => cache.invalidate_role_secret_for_project(project_id, role_name),
|
||||
|
||||
Notification::UnknownTopic => unreachable!(),
|
||||
}
|
||||
}
|
||||
@@ -353,11 +357,32 @@ mod tests {
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(
|
||||
result,
|
||||
Notification::AllowedIpsUpdate {
|
||||
allowed_ips_update: AllowedIpsUpdate {
|
||||
project_id: (&project_id).into()
|
||||
}
|
||||
}
|
||||
Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_multiple_projects() -> anyhow::Result<()> {
|
||||
let project_id1: ProjectId = "new_project1".into();
|
||||
let project_id2: ProjectId = "new_project2".into();
|
||||
let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}");
|
||||
let text = json!({
|
||||
"type": "message",
|
||||
"topic": "/allowed_vpc_endpoints_updated_for_projects",
|
||||
"data": data,
|
||||
"extre_fields": "something"
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(
|
||||
result,
|
||||
Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
|
||||
(&project_id1).into(),
|
||||
(&project_id2).into()
|
||||
]))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
@@ -379,12 +404,10 @@ mod tests {
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(
|
||||
result,
|
||||
Notification::PasswordUpdate {
|
||||
password_update: PasswordUpdate {
|
||||
project_id: (&project_id).into(),
|
||||
role_name: (&role_name).into(),
|
||||
}
|
||||
}
|
||||
Notification::RoleSettingUpdate(InvalidateRole {
|
||||
project_id: (&project_id).into(),
|
||||
role_name: (&role_name).into(),
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
//! Definition and parser for channel binding flag (a part of the `GS2` header).
|
||||
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
|
||||
/// Channel binding flag (possibly with params).
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub(crate) enum ChannelBinding<T> {
|
||||
@@ -55,7 +58,7 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
|
||||
let mut cbind_input = vec![];
|
||||
write!(&mut cbind_input, "p={mode},,",).unwrap();
|
||||
cbind_input.extend_from_slice(get_cbind_data(mode)?);
|
||||
base64::encode(&cbind_input).into()
|
||||
BASE64_STANDARD.encode(&cbind_input).into()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -70,9 +73,9 @@ mod tests {
|
||||
use ChannelBinding::*;
|
||||
|
||||
let cases = [
|
||||
(NotSupportedClient, base64::encode("n,,")),
|
||||
(NotSupportedServer, base64::encode("y,,")),
|
||||
(Required("foo"), base64::encode("p=foo,,bar")),
|
||||
(NotSupportedClient, BASE64_STANDARD.encode("n,,")),
|
||||
(NotSupportedServer, BASE64_STANDARD.encode("y,,")),
|
||||
(Required("foo"), BASE64_STANDARD.encode("p=foo,,bar")),
|
||||
];
|
||||
|
||||
for (cb, input) in cases {
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
@@ -105,7 +107,7 @@ pub(crate) async fn exchange(
|
||||
secret: &ServerSecret,
|
||||
password: &[u8],
|
||||
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
||||
let salt = base64::decode(&secret.salt_base64)?;
|
||||
let salt = BASE64_STANDARD.decode(&secret.salt_base64)?;
|
||||
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||
|
||||
if secret.is_password_invalid(&client_key).into() {
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
use std::fmt;
|
||||
use std::ops::Range;
|
||||
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
|
||||
use super::base64_decode_array;
|
||||
use super::key::{SCRAM_KEY_LEN, ScramKey};
|
||||
use super::signature::SignatureBuilder;
|
||||
@@ -88,7 +91,7 @@ impl<'a> ClientFirstMessage<'a> {
|
||||
|
||||
let mut message = String::new();
|
||||
write!(&mut message, "r={}", self.nonce).unwrap();
|
||||
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
|
||||
BASE64_STANDARD.encode_string(nonce, &mut message);
|
||||
let combined_nonce = 2..message.len();
|
||||
write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
|
||||
|
||||
@@ -142,11 +145,7 @@ impl<'a> ClientFinalMessage<'a> {
|
||||
server_key: &ScramKey,
|
||||
) -> String {
|
||||
let mut buf = String::from("v=");
|
||||
base64::encode_config_buf(
|
||||
signature_builder.build(server_key),
|
||||
base64::STANDARD,
|
||||
&mut buf,
|
||||
);
|
||||
BASE64_STANDARD.encode_string(signature_builder.build(server_key), &mut buf);
|
||||
|
||||
buf
|
||||
}
|
||||
@@ -251,7 +250,7 @@ mod tests {
|
||||
"iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
|
||||
);
|
||||
assert_eq!(
|
||||
base64::encode(msg.proof),
|
||||
BASE64_STANDARD.encode(msg.proof),
|
||||
"SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
|
||||
);
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ mod secret;
|
||||
mod signature;
|
||||
pub mod threadpool;
|
||||
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
pub(crate) use exchange::{Exchange, exchange};
|
||||
use hmac::{Hmac, Mac};
|
||||
pub(crate) use key::ScramKey;
|
||||
@@ -32,7 +34,7 @@ pub(crate) const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
|
||||
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
|
||||
let mut bytes = [0u8; N];
|
||||
|
||||
let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?;
|
||||
let size = BASE64_STANDARD.decode_slice(input, &mut bytes).ok()?;
|
||||
if size != N {
|
||||
return None;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
//! Tools for SCRAM server secret management.
|
||||
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use subtle::{Choice, ConstantTimeEq};
|
||||
|
||||
use super::base64_decode_array;
|
||||
@@ -56,7 +58,7 @@ impl ServerSecret {
|
||||
// iteration count 1 for our generated passwords going forward.
|
||||
// PG16 users can set iteration count=1 already today.
|
||||
iterations: 1,
|
||||
salt_base64: base64::encode(nonce),
|
||||
salt_base64: BASE64_STANDARD.encode(nonce),
|
||||
stored_key: ScramKey::default(),
|
||||
server_key: ScramKey::default(),
|
||||
doomed: true,
|
||||
@@ -88,7 +90,7 @@ mod tests {
|
||||
assert_eq!(parsed.iterations, iterations);
|
||||
assert_eq!(parsed.salt_base64, salt);
|
||||
|
||||
assert_eq!(base64::encode(parsed.stored_key), stored_key);
|
||||
assert_eq!(base64::encode(parsed.server_key), server_key);
|
||||
assert_eq!(BASE64_STANDARD.encode(parsed.stored_key), stored_key);
|
||||
assert_eq!(BASE64_STANDARD.encode(parsed.server_key), server_key);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,9 +21,8 @@ use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
|
||||
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
|
||||
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
|
||||
use crate::auth::backend::local::StaticAuthRules;
|
||||
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
|
||||
use crate::auth::{self, AuthError};
|
||||
use crate::compute;
|
||||
use crate::compute_ctl::{
|
||||
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
|
||||
};
|
||||
@@ -35,7 +34,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::pglb::connect_compute::ConnectMechanism;
|
||||
use crate::proxy::connect_compute::ConnectMechanism;
|
||||
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
|
||||
@@ -69,17 +68,20 @@ impl PoolingBackend {
|
||||
self.config.authentication_config.is_vpc_acccess_proxy,
|
||||
)?;
|
||||
|
||||
let ep = EndpointIdInt::from(&user_info.endpoint);
|
||||
let rate_limit_config = None;
|
||||
if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) {
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
access_control.connection_attempt_rate_limit(
|
||||
ctx,
|
||||
&user_info.endpoint,
|
||||
&self.endpoint_rate_limiter,
|
||||
)?;
|
||||
|
||||
let role_access = backend.get_role_secret(ctx).await?;
|
||||
let Some(secret) = role_access.secret else {
|
||||
// If we don't have an authentication secret, for the http flow we can just return an error.
|
||||
info!("authentication info not found");
|
||||
return Err(AuthError::password_failed(&*user_info.user));
|
||||
};
|
||||
|
||||
let ep = EndpointIdInt::from(&user_info.endpoint);
|
||||
let auth_outcome = crate::auth::validate_password_and_exchange(
|
||||
&self.config.authentication_config.thread_pool,
|
||||
ep,
|
||||
@@ -181,14 +183,15 @@ impl PoolingBackend {
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self.auth_backend.as_ref().map(|()| keys);
|
||||
crate::pglb::connect_compute::connect_to_compute(
|
||||
let backend = self.auth_backend.as_ref().map(|()| keys.info);
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&TokioMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
pool: self.pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
keys: keys.keys,
|
||||
},
|
||||
&backend,
|
||||
self.config.wake_compute_retry_config,
|
||||
@@ -215,18 +218,15 @@ impl PoolingBackend {
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
user: conn_info.user_info.user.clone(),
|
||||
endpoint: EndpointId::from(format!(
|
||||
"{}{LOCAL_PROXY_SUFFIX}",
|
||||
conn_info.user_info.endpoint.normalize()
|
||||
)),
|
||||
options: conn_info.user_info.options.clone(),
|
||||
},
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
|
||||
user: conn_info.user_info.user.clone(),
|
||||
endpoint: EndpointId::from(format!(
|
||||
"{}{LOCAL_PROXY_SUFFIX}",
|
||||
conn_info.user_info.endpoint.normalize()
|
||||
)),
|
||||
options: conn_info.user_info.options.clone(),
|
||||
});
|
||||
crate::pglb::connect_compute::connect_to_compute(
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&HyperMechanism {
|
||||
conn_id,
|
||||
@@ -305,12 +305,13 @@ impl PoolingBackend {
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
|
||||
|
||||
let mut node_info = local_backend.node_info.clone();
|
||||
|
||||
let (key, jwk) = create_random_jwk();
|
||||
|
||||
let config = node_info
|
||||
.config
|
||||
let mut config = local_backend
|
||||
.node_info
|
||||
.conn_info
|
||||
.to_postgres_client_config();
|
||||
config
|
||||
.user(&conn_info.user_info.user)
|
||||
.dbname(&conn_info.dbname)
|
||||
.set_param(
|
||||
@@ -322,7 +323,7 @@ impl PoolingBackend {
|
||||
);
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = config.connect(postgres_client::NoTls).await?;
|
||||
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
|
||||
drop(pause);
|
||||
|
||||
let pid = client.get_process_id();
|
||||
@@ -336,7 +337,7 @@ impl PoolingBackend {
|
||||
connection,
|
||||
key,
|
||||
conn_id,
|
||||
node_info.aux.clone(),
|
||||
local_backend.node_info.aux.clone(),
|
||||
);
|
||||
|
||||
{
|
||||
@@ -495,6 +496,7 @@ struct TokioMechanism {
|
||||
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
|
||||
conn_info: ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
keys: ComputeCredentialKeys,
|
||||
|
||||
/// connect_to_compute concurrency lock
|
||||
locks: &'static ApiLocks<Host>,
|
||||
@@ -512,19 +514,20 @@ impl ConnectMechanism for TokioMechanism {
|
||||
node_info: &CachedNodeInfo,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host = node_info.config.get_host();
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
|
||||
|
||||
let mut config = (*node_info.config).clone();
|
||||
let mut config = node_info.conn_info.to_postgres_client_config();
|
||||
let config = config
|
||||
.user(&self.conn_info.user_info.user)
|
||||
.dbname(&self.conn_info.dbname)
|
||||
.connect_timeout(compute_config.timeout);
|
||||
|
||||
let mk_tls =
|
||||
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
|
||||
if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
|
||||
config.auth_keys(auth_keys);
|
||||
}
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let res = config.connect(mk_tls).await;
|
||||
let res = config.connect(compute_config).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
@@ -548,8 +551,6 @@ impl ConnectMechanism for TokioMechanism {
|
||||
node_info.aux.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
struct HyperMechanism {
|
||||
@@ -573,20 +574,20 @@ impl ConnectMechanism for HyperMechanism {
|
||||
node_info: &CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host_addr = node_info.config.get_host_addr();
|
||||
let host = node_info.config.get_host();
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
let host_addr = node_info.conn_info.host_addr;
|
||||
let host = &node_info.conn_info.host;
|
||||
let permit = self.locks.get_permit(host).await?;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
|
||||
let tls = if node_info.config.get_ssl_mode() == SslMode::Disable {
|
||||
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
|
||||
None
|
||||
} else {
|
||||
Some(&config.tls)
|
||||
};
|
||||
|
||||
let port = node_info.config.get_port();
|
||||
let res = connect_http2(host_addr, &host, port, config.timeout, tls).await;
|
||||
let port = node_info.conn_info.port;
|
||||
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
@@ -609,8 +610,6 @@ impl ConnectMechanism for HyperMechanism {
|
||||
node_info.aux.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
async fn connect_http2(
|
||||
|
||||
@@ -23,12 +23,12 @@ use super::conn_pool_lib::{
|
||||
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool,
|
||||
GlobalConnPool,
|
||||
};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
type TlsStream = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::Stream;
|
||||
type TlsStream = <ComputeConfig as MakeTlsConnect<TcpStream>>::Stream;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ConnInfoWithAuth {
|
||||
|
||||
@@ -16,6 +16,8 @@ use std::sync::atomic::AtomicUsize;
|
||||
use std::task::{Poll, ready};
|
||||
use std::time::Duration;
|
||||
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
|
||||
use ed25519_dalek::{Signature, Signer, SigningKey};
|
||||
use futures::Future;
|
||||
use futures::future::poll_fn;
|
||||
@@ -346,7 +348,7 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
|
||||
jwt.push_str("eyJhbGciOiJFZERTQSJ9.");
|
||||
|
||||
// encode the jwt payload in-place
|
||||
base64::encode_config_buf(payload, base64::URL_SAFE_NO_PAD, &mut jwt);
|
||||
BASE64_URL_SAFE_NO_PAD.encode_string(payload, &mut jwt);
|
||||
|
||||
// create the signature from the encoded header || payload
|
||||
let sig: Signature = sk.sign(jwt.as_bytes());
|
||||
@@ -354,7 +356,7 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
|
||||
jwt.push('.');
|
||||
|
||||
// encode the jwt signature in-place
|
||||
base64::encode_config_buf(sig.to_bytes(), base64::URL_SAFE_NO_PAD, &mut jwt);
|
||||
BASE64_URL_SAFE_NO_PAD.encode_string(sig.to_bytes(), &mut jwt);
|
||||
|
||||
debug_assert_eq!(
|
||||
jwt.len(),
|
||||
|
||||
@@ -50,10 +50,10 @@ use crate::context::RequestContext;
|
||||
use crate::ext::TaskExt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
use crate::serverless::http_util::{api_error_into_response, json_response};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
|
||||
pub(crate) const AUTH_BROKER_SNI: &str = "apiauth";
|
||||
|
||||
@@ -41,10 +41,11 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::{ReadBodyError, read_body_with_limit};
|
||||
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::{NeonOptions, run_until_cancelled};
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::types::{DbName, RoleName};
|
||||
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
||||
@@ -3,6 +3,8 @@ pub mod postgres_rustls;
|
||||
pub mod server_config;
|
||||
|
||||
use anyhow::Context;
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use rustls::pki_types::CertificateDer;
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::{error, info};
|
||||
@@ -58,7 +60,7 @@ impl TlsServerEndPoint {
|
||||
let oid = certificate.signature_algorithm.oid;
|
||||
if SHA256_OIDS.contains(&oid) {
|
||||
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
|
||||
info!(%subject, tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
|
||||
info!(%subject, tls_server_end_point = %BASE64_STANDARD.encode(tls_server_end_point), "determined channel binding");
|
||||
Ok(Self::Sha256(tls_server_end_point))
|
||||
} else {
|
||||
error!(%subject, "unknown channel binding");
|
||||
|
||||
@@ -2,10 +2,11 @@ use std::convert::TryFrom;
|
||||
use std::sync::Arc;
|
||||
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use rustls::ClientConfig;
|
||||
use rustls::pki_types::ServerName;
|
||||
use rustls::pki_types::{InvalidDnsNameError, ServerName};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::config::ComputeConfig;
|
||||
|
||||
mod private {
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
@@ -123,36 +124,27 @@ mod private {
|
||||
}
|
||||
}
|
||||
|
||||
/// A `MakeTlsConnect` implementation using `rustls`.
|
||||
///
|
||||
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
|
||||
#[derive(Clone)]
|
||||
pub struct MakeRustlsConnect {
|
||||
pub config: Arc<ClientConfig>,
|
||||
}
|
||||
|
||||
impl MakeRustlsConnect {
|
||||
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
|
||||
#[must_use]
|
||||
pub fn new(config: Arc<ClientConfig>) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
|
||||
impl<S> MakeTlsConnect<S> for ComputeConfig
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
type Stream = private::RustlsStream<S>;
|
||||
type TlsConnect = private::RustlsConnect;
|
||||
type Error = rustls::pki_types::InvalidDnsNameError;
|
||||
type Error = InvalidDnsNameError;
|
||||
|
||||
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
|
||||
ServerName::try_from(hostname).map(|dns_name| {
|
||||
private::RustlsConnect(private::RustlsConnectData {
|
||||
hostname: dns_name.to_owned(),
|
||||
connector: Arc::clone(&self.config).into(),
|
||||
})
|
||||
})
|
||||
fn make_tls_connect(&self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
|
||||
make_tls_connect(&self.tls, hostname)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make_tls_connect(
|
||||
tls: &Arc<rustls::ClientConfig>,
|
||||
hostname: &str,
|
||||
) -> Result<private::RustlsConnect, InvalidDnsNameError> {
|
||||
ServerName::try_from(hostname).map(|dns_name| {
|
||||
private::RustlsConnect(private::RustlsConnectData {
|
||||
hostname: dns_name.to_owned(),
|
||||
connector: tls.clone().into(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
14
proxy/src/util.rs
Normal file
14
proxy/src/util.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use std::pin::pin;
|
||||
|
||||
use futures::future::{Either, select};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
pub async fn run_until_cancelled<F: Future>(
|
||||
f: F,
|
||||
cancellation_token: &CancellationToken,
|
||||
) -> Option<F::Output> {
|
||||
match select(pin!(f), pin!(cancellation_token.cancelled())).await {
|
||||
Either::Left((f, _)) => Some(f),
|
||||
Either::Right(((), _)) => None,
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user