Compare commits

..

8 Commits

100 changed files with 1381 additions and 2146 deletions

45
Cargo.lock generated
View File

@@ -753,7 +753,6 @@ dependencies = [
"axum",
"axum-core",
"bytes",
"form_urlencoded",
"futures-util",
"headers",
"http 1.1.0",
@@ -762,8 +761,6 @@ dependencies = [
"mime",
"pin-project-lite",
"serde",
"serde_html_form",
"serde_path_to_error",
"tower 0.5.2",
"tower-layer",
"tower-service",
@@ -903,6 +900,12 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5"
[[package]]
name = "base64"
version = "0.21.7"
@@ -1294,7 +1297,7 @@ dependencies = [
"aws-smithy-types",
"axum",
"axum-extra",
"base64 0.22.1",
"base64 0.13.1",
"bytes",
"camino",
"cfg-if",
@@ -1420,7 +1423,7 @@ name = "control_plane"
version = "0.1.0"
dependencies = [
"anyhow",
"base64 0.22.1",
"base64 0.13.1",
"camino",
"clap",
"comfy-table",
@@ -2052,7 +2055,6 @@ dependencies = [
"axum-extra",
"camino",
"camino-tempfile",
"clap",
"futures",
"http-body-util",
"itertools 0.10.5",
@@ -4812,7 +4814,7 @@ dependencies = [
name = "postgres-protocol2"
version = "0.1.0"
dependencies = [
"base64 0.22.1",
"base64 0.20.0",
"byteorder",
"bytes",
"fallible-iterator",
@@ -5184,7 +5186,7 @@ dependencies = [
"aws-config",
"aws-sdk-iam",
"aws-sigv4",
"base64 0.22.1",
"base64 0.13.1",
"bstr",
"bytes",
"camino",
@@ -5271,6 +5273,7 @@ dependencies = [
"tokio-rustls 0.26.2",
"tokio-tungstenite 0.21.0",
"tokio-util",
"toml",
"tracing",
"tracing-log",
"tracing-opentelemetry",
@@ -6419,19 +6422,6 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "serde_html_form"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4"
dependencies = [
"form_urlencoded",
"indexmap 2.9.0",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "serde_json"
version = "1.0.125"
@@ -6488,17 +6478,15 @@ dependencies = [
[[package]]
name = "serde_with"
version = "3.12.0"
version = "2.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa"
checksum = "07ff71d2c147a7b57362cead5e22f772cd52f6ab31cfcd9edcd7f6aeb2a0afbe"
dependencies = [
"base64 0.22.1",
"base64 0.13.1",
"chrono",
"hex",
"indexmap 1.9.3",
"indexmap 2.9.0",
"serde",
"serde_derive",
"serde_json",
"serde_with_macros",
"time",
@@ -6506,9 +6494,9 @@ dependencies = [
[[package]]
name = "serde_with_macros"
version = "3.12.0"
version = "2.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e"
checksum = "881b6f881b17d13214e5d494c939ebab463d01264ce1811e9d4ac3a882e7695f"
dependencies = [
"darling",
"proc-macro2",
@@ -8579,6 +8567,7 @@ dependencies = [
"anyhow",
"axum",
"axum-core",
"base64 0.13.1",
"base64 0.21.7",
"base64ct",
"bytes",

View File

@@ -71,8 +71,8 @@ aws-credential-types = "1.2.0"
aws-sigv4 = { version = "1.2", features = ["sign-http"] }
aws-types = "1.3"
axum = { version = "0.8.1", features = ["ws"] }
axum-extra = { version = "0.10.0", features = ["typed-header", "query"] }
base64 = "0.22"
axum-extra = { version = "0.10.0", features = ["typed-header"] }
base64 = "0.13.0"
bincode = "1.3"
bindgen = "0.71"
bit_field = "0.10.2"
@@ -171,7 +171,7 @@ sentry = { version = "0.37", default-features = false, features = ["backtrace",
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
serde_path_to_error = "0.1"
serde_with = { version = "3", features = [ "base64" ] }
serde_with = { version = "2.0", features = [ "base64" ] }
serde_assert = "0.5.0"
sha2 = "0.10.2"
signal-hook = "0.3"

View File

@@ -785,7 +785,7 @@ impl ComputeNode {
self.spawn_extension_stats_task();
if pspec.spec.autoprewarm {
self.prewarm_lfc(None);
self.prewarm_lfc();
}
Ok(())
}

View File

@@ -25,16 +25,11 @@ struct EndpointStoragePair {
}
const KEY: &str = "lfc_state";
impl EndpointStoragePair {
/// endpoint_id is set to None while prewarming from other endpoint, see replica promotion
/// If not None, takes precedence over pspec.spec.endpoint_id
fn from_spec_and_endpoint(
pspec: &crate::compute::ParsedSpec,
endpoint_id: Option<String>,
) -> Result<Self> {
let endpoint_id = endpoint_id.as_ref().or(pspec.spec.endpoint_id.as_ref());
let Some(ref endpoint_id) = endpoint_id else {
bail!("pspec.endpoint_id missing, other endpoint_id not provided")
impl TryFrom<&crate::compute::ParsedSpec> for EndpointStoragePair {
type Error = anyhow::Error;
fn try_from(pspec: &crate::compute::ParsedSpec) -> Result<Self, Self::Error> {
let Some(ref endpoint_id) = pspec.spec.endpoint_id else {
bail!("pspec.endpoint_id missing")
};
let Some(ref base_uri) = pspec.endpoint_storage_addr else {
bail!("pspec.endpoint_storage_addr missing")
@@ -89,7 +84,7 @@ impl ComputeNode {
}
/// Returns false if there is a prewarm request ongoing, true otherwise
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
pub fn prewarm_lfc(self: &Arc<Self>) -> bool {
crate::metrics::LFC_PREWARM_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
@@ -102,7 +97,7 @@ impl ComputeNode {
let cloned = self.clone();
spawn(async move {
let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
let Err(err) = cloned.prewarm_impl().await else {
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
return;
};
@@ -114,14 +109,13 @@ impl ComputeNode {
true
}
/// from_endpoint: None for endpoint managed by this compute_ctl
fn endpoint_storage_pair(&self, from_endpoint: Option<String>) -> Result<EndpointStoragePair> {
fn endpoint_storage_pair(&self) -> Result<EndpointStoragePair> {
let state = self.state.lock().unwrap();
EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
state.pspec.as_ref().unwrap().try_into()
}
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
async fn prewarm_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
info!(%url, "requesting LFC state from endpoint storage");
let request = Client::new().get(&url).bearer_auth(token);
@@ -179,7 +173,7 @@ impl ComputeNode {
}
async fn offload_lfc_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
info!(%url, "requesting LFC state from postgres");
let mut compressed = Vec::new();

View File

@@ -2,7 +2,6 @@ use crate::compute_prewarm::LfcPrewarmStateWithProgress;
use crate::http::JsonResponse;
use axum::response::{IntoResponse, Response};
use axum::{Json, http::StatusCode};
use axum_extra::extract::OptionalQuery;
use compute_api::responses::LfcOffloadState;
type Compute = axum::extract::State<std::sync::Arc<crate::compute::ComputeNode>>;
@@ -17,16 +16,8 @@ pub(in crate::http) async fn offload_state(compute: Compute) -> Json<LfcOffloadS
Json(compute.lfc_offload_state())
}
#[derive(serde::Deserialize)]
pub struct PrewarmQuery {
pub from_endpoint: String,
}
pub(in crate::http) async fn prewarm(
compute: Compute,
OptionalQuery(query): OptionalQuery<PrewarmQuery>,
) -> Response {
if compute.prewarm_lfc(query.map(|q| q.from_endpoint)) {
pub(in crate::http) async fn prewarm(compute: Compute) -> Response {
if compute.prewarm_lfc() {
StatusCode::ACCEPTED.into_response()
} else {
JsonResponse::error(

View File

@@ -45,8 +45,6 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result, anyhow, bail};
use base64::Engine;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use compute_api::requests::{
COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest,
};
@@ -166,7 +164,7 @@ impl ComputeControlPlane {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::EdDSA),
key_id: Some(BASE64_URL_SAFE_NO_PAD.encode(key_hash)),
key_id: Some(base64::encode_config(key_hash, base64::URL_SAFE_NO_PAD)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
@@ -175,7 +173,7 @@ impl ComputeControlPlane {
algorithm: AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters {
key_type: OctetKeyPairType::OctetKeyPair,
curve: EllipticCurve::Ed25519,
x: BASE64_URL_SAFE_NO_PAD.encode(public_key),
x: base64::encode_config(public_key, base64::URL_SAFE_NO_PAD),
}),
}],
})

View File

@@ -8,7 +8,6 @@ anyhow.workspace = true
axum-extra.workspace = true
axum.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
jsonwebtoken.workspace = true
prometheus.workspace = true

View File

@@ -3,8 +3,7 @@
//! This service is deployed either as a separate component or as part of compute image
//! for large computes.
mod app;
use anyhow::Context;
use clap::Parser;
use anyhow::{Context, bail};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tracing::info;
use utils::logging;
@@ -18,18 +17,6 @@ const fn listen() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
}
#[derive(Parser)]
struct Args {
#[arg(exclusive = true)]
config_file: Option<String>,
#[arg(long, default_value = "false", requires = "config")]
/// to allow testing k8s helm chart where we don't have s3 credentials
no_s3_check_on_startup: bool,
#[arg(long, value_name = "FILE")]
/// inline config mode for k8s helm chart
config: Option<String>,
}
#[derive(serde::Deserialize)]
#[serde(tag = "type")]
struct Config {
@@ -50,16 +37,19 @@ async fn main() -> anyhow::Result<()> {
logging::Output::Stdout,
)?;
let args = Args::parse();
let config: Config = if let Some(config_path) = args.config_file {
info!("Reading config from {config_path}");
let config = std::fs::read_to_string(config_path)?;
// Allow either passing filename or inline config (for k8s helm chart)
let args: Vec<String> = std::env::args().skip(1).collect();
let config: Config = if args.len() == 1 && args[0].ends_with(".json") {
info!("Reading config from {}", args[0]);
let config = std::fs::read_to_string(args[0].clone())?;
serde_json::from_str(&config).context("parsing config")?
} else if let Some(config) = args.config {
} else if !args.is_empty() && args[0].starts_with("--config=") {
info!("Reading inline config");
serde_json::from_str(&config).context("parsing config")?
let config = args.join(" ");
let config = config.strip_prefix("--config=").unwrap();
serde_json::from_str(config).context("parsing config")?
} else {
anyhow::bail!("Supply either config file path or --config=inline-config");
bail!("Usage: endpoint_storage config.json or endpoint_storage --config=JSON");
};
info!("Reading pemfile from {}", config.pemfile.clone());
@@ -72,9 +62,7 @@ async fn main() -> anyhow::Result<()> {
let storage = remote_storage::GenericRemoteStorage::from_config(&config.storage_config).await?;
let cancel = tokio_util::sync::CancellationToken::new();
if !args.no_s3_check_on_startup {
app::check_storage_permissions(&storage, cancel.clone()).await?;
}
app::check_storage_permissions(&storage, cancel.clone()).await?;
let proxy = std::sync::Arc::new(endpoint_storage::Storage {
auth,

View File

@@ -9,7 +9,7 @@ use utils::id::{NodeId, TimelineId};
use crate::controller_api::NodeRegisterRequest;
use crate::models::{LocationConfigMode, ShardImportStatus};
use crate::shard::{ShardStripeSize, TenantShardId};
use crate::shard::TenantShardId;
/// Upcall message sent by the pageserver to the configured `control_plane_api` on
/// startup.
@@ -36,10 +36,6 @@ pub struct ReAttachResponseTenant {
/// Default value only for backward compat: this field should be set
#[serde(default = "default_mode")]
pub mode: LocationConfigMode,
// Default value only for backward compat: this field should be set
#[serde(default = "ShardStripeSize::default")]
pub stripe_size: ShardStripeSize,
}
#[derive(Serialize, Deserialize)]
pub struct ReAttachResponse {

View File

@@ -55,16 +55,9 @@ impl FeatureResolverBackgroundLoop {
continue;
}
};
let project_id = this.posthog_client.config.project_id.parse::<u64>().ok();
match FeatureStore::new_with_flags(resp.flags, project_id) {
Ok(feature_store) => {
this.feature_store.store(Arc::new(feature_store));
tracing::info!("Feature flag updated");
}
Err(e) => {
tracing::warn!("Cannot process feature flag spec: {}", e);
}
}
let feature_store = FeatureStore::new_with_flags(resp.flags);
this.feature_store.store(Arc::new(feature_store));
tracing::info!("Feature flag updated");
}
tracing::info!("PostHog feature resolver stopped");
}

View File

@@ -39,9 +39,6 @@ pub struct LocalEvaluationResponse {
#[derive(Deserialize)]
pub struct LocalEvaluationFlag {
#[allow(dead_code)]
id: u64,
team_id: u64,
key: String,
filters: LocalEvaluationFlagFilters,
active: bool,
@@ -110,32 +107,17 @@ impl FeatureStore {
}
}
pub fn new_with_flags(
flags: Vec<LocalEvaluationFlag>,
project_id: Option<u64>,
) -> Result<Self, &'static str> {
pub fn new_with_flags(flags: Vec<LocalEvaluationFlag>) -> Self {
let mut store = Self::new();
store.set_flags(flags, project_id)?;
Ok(store)
store.set_flags(flags);
store
}
pub fn set_flags(
&mut self,
flags: Vec<LocalEvaluationFlag>,
project_id: Option<u64>,
) -> Result<(), &'static str> {
pub fn set_flags(&mut self, flags: Vec<LocalEvaluationFlag>) {
self.flags.clear();
for flag in flags {
if let Some(project_id) = project_id {
if flag.team_id != project_id {
return Err(
"Retrieved a spec with different project id, wrong config? Discarding the feature flags.",
);
}
}
self.flags.insert(flag.key.clone(), flag);
}
Ok(())
}
/// Generate a consistent hash for a user ID (e.g., tenant ID).
@@ -552,13 +534,6 @@ impl PostHogClient {
})
}
/// Check if the server API key is a feature flag secure API key. This key can only be
/// used to fetch the feature flag specs and can only be used on a undocumented API
/// endpoint.
fn is_feature_flag_secure_api_key(&self) -> bool {
self.config.server_api_key.starts_with("phs_")
}
/// Fetch the feature flag specs from the server.
///
/// This is unfortunately an undocumented API at:
@@ -572,22 +547,10 @@ impl PostHogClient {
) -> anyhow::Result<LocalEvaluationResponse> {
// BASE_URL/api/projects/:project_id/feature_flags/local_evaluation
// with bearer token of self.server_api_key
// OR
// BASE_URL/api/feature_flag/local_evaluation/
// with bearer token of feature flag specific self.server_api_key
let url = if self.is_feature_flag_secure_api_key() {
// The new feature local evaluation secure API token
format!(
"{}/api/feature_flag/local_evaluation",
self.config.private_api_url
)
} else {
// The old personal API token
format!(
"{}/api/projects/{}/feature_flags/local_evaluation",
self.config.private_api_url, self.config.project_id
)
};
let url = format!(
"{}/api/projects/{}/feature_flags/local_evaluation",
self.config.private_api_url, self.config.project_id
);
let response = self
.client
.get(url)
@@ -840,7 +803,7 @@ mod tests {
fn evaluate_multivariate() {
let mut store = FeatureStore::new();
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
store.set_flags(response.flags, None).unwrap();
store.set_flags(response.flags);
// This lacks the required properties and cannot be evaluated.
let variant =
@@ -910,7 +873,7 @@ mod tests {
let mut store = FeatureStore::new();
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
store.set_flags(response.flags, None).unwrap();
store.set_flags(response.flags);
// This lacks the required properties and cannot be evaluated.
let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &HashMap::new());
@@ -966,7 +929,7 @@ mod tests {
let mut store = FeatureStore::new();
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
store.set_flags(response.flags, None).unwrap();
store.set_flags(response.flags);
// This lacks the required properties and cannot be evaluated.
let variant =

View File

@@ -5,7 +5,7 @@ edition = "2024"
license = "MIT/Apache-2.0"
[dependencies]
base64.workspace = true
base64 = "0.20"
byteorder.workspace = true
bytes.workspace = true
fallible-iterator.workspace = true

View File

@@ -3,8 +3,6 @@
use std::fmt::Write;
use std::{io, iter, mem, str};
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use hmac::{Hmac, Mac};
use rand::{self, Rng};
use sha2::digest::FixedOutput;
@@ -228,7 +226,7 @@ impl ScramSha256 {
let (client_key, server_key) = match password {
Credentials::Password(password) => {
let salt = match BASE64_STANDARD.decode(parsed.salt) {
let salt = match base64::decode(parsed.salt) {
Ok(salt) => salt,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
};
@@ -257,7 +255,7 @@ impl ScramSha256 {
let mut cbind_input = vec![];
cbind_input.extend(channel_binding.gs2_header().as_bytes());
cbind_input.extend(channel_binding.cbind_data());
let cbind_input = BASE64_STANDARD.encode(&cbind_input);
let cbind_input = base64::encode(&cbind_input);
self.message.clear();
write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
@@ -274,12 +272,7 @@ impl ScramSha256 {
*proof ^= signature;
}
write!(
&mut self.message,
",p={}",
BASE64_STANDARD.encode(client_proof)
)
.unwrap();
write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
self.state = State::Finish {
server_key,
@@ -313,7 +306,7 @@ impl ScramSha256 {
ServerFinalMessage::Verifier(verifier) => verifier,
};
let verifier = match BASE64_STANDARD.decode(verifier) {
let verifier = match base64::decode(verifier) {
Ok(verifier) => verifier,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
};

View File

@@ -6,8 +6,6 @@
//! side. This is good because it ensures the cleartext password won't
//! end up in logs pg_stat displays, etc.
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use hmac::{Hmac, Mac};
use rand::RngCore;
use sha2::digest::FixedOutput;
@@ -85,8 +83,8 @@ pub(crate) async fn scram_sha_256_salt(
format!(
"SCRAM-SHA-256${}:{}${}:{}",
SCRAM_DEFAULT_ITERATIONS,
BASE64_STANDARD.encode(salt),
BASE64_STANDARD.encode(stored_key),
BASE64_STANDARD.encode(server_key)
base64::encode(salt),
base64::encode(stored_key),
base64::encode(server_key)
)
}

View File

@@ -10,7 +10,7 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime};
use std::{env, io};
use anyhow::{Context, Result, anyhow};
use anyhow::{Context, Result};
use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
use azure_core::{Continuable, HttpClient, RetryOptions, TransportOptions};
use azure_storage::StorageCredentials;
@@ -37,7 +37,6 @@ use crate::metrics::{AttemptOutcome, RequestKind, start_measuring_requests};
use crate::{
ConcurrencyLimiter, Download, DownloadError, DownloadKind, DownloadOpts, Listing, ListingMode,
ListingObject, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
Version, VersionKind,
};
pub struct AzureBlobStorage {
@@ -406,39 +405,6 @@ impl AzureBlobStorage {
pub fn container_name(&self) -> &str {
&self.container_name
}
async fn list_versions_with_permit(
&self,
_permit: &tokio::sync::SemaphorePermit<'_>,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> Result<crate::VersionListing, DownloadError> {
let customize_builder = |mut builder: ListBlobsBuilder| {
builder = builder.include_versions(true);
// We do not return this info back to `VersionListing` yet.
builder = builder.include_deleted(true);
builder
};
let kind = RequestKind::ListVersions;
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
prefix,
mode,
max_keys,
cancel,
kind,
customize_builder
));
let mut combined: crate::VersionListing =
stream.next().await.expect("At least one item required")?;
while let Some(list) = stream.next().await {
let list = list?;
combined.versions.extend(list.versions.into_iter());
}
Ok(combined)
}
}
trait ListingCollector {
@@ -522,10 +488,27 @@ impl RemoteStorage for AzureBlobStorage {
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> std::result::Result<crate::VersionListing, DownloadError> {
let customize_builder = |mut builder: ListBlobsBuilder| {
builder = builder.include_versions(true);
builder
};
let kind = RequestKind::ListVersions;
let permit = self.permit(kind, cancel).await?;
self.list_versions_with_permit(&permit, prefix, mode, max_keys, cancel)
.await
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
prefix,
mode,
max_keys,
cancel,
kind,
customize_builder
));
let mut combined: crate::VersionListing =
stream.next().await.expect("At least one item required")?;
while let Some(list) = stream.next().await {
let list = list?;
combined.versions.extend(list.versions.into_iter());
}
Ok(combined)
}
async fn head_object(
@@ -820,159 +803,14 @@ impl RemoteStorage for AzureBlobStorage {
async fn time_travel_recover(
&self,
prefix: Option<&RemotePath>,
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
_complexity_limit: Option<NonZeroU32>,
_prefix: Option<&RemotePath>,
_timestamp: SystemTime,
_done_if_after: SystemTime,
_cancel: &CancellationToken,
) -> Result<(), TimeTravelError> {
let msg = "PLEASE NOTE: Azure Blob storage time-travel recovery may not work as expected "
.to_string()
+ "for some specific files. If a file gets deleted but then overwritten and we want to recover "
+ "to the time during the file was not present, this functionality will recover the file. Only "
+ "use the functionality for services that can tolerate this. For example, recovering a state of the "
+ "pageserver tenants.";
tracing::error!("{}", msg);
let kind = RequestKind::TimeTravel;
let permit = self.permit(kind, cancel).await?;
let mode = ListingMode::NoDelimiter;
let version_listing = self
.list_versions_with_permit(&permit, prefix, mode, None, cancel)
.await
.map_err(|err| match err {
DownloadError::Other(e) => TimeTravelError::Other(e),
DownloadError::Cancelled => TimeTravelError::Cancelled,
other => TimeTravelError::Other(other.into()),
})?;
let versions_and_deletes = version_listing.versions;
tracing::info!(
"Built list for time travel with {} versions and deletions",
versions_and_deletes.len()
);
// Work on the list of references instead of the objects directly,
// otherwise we get lifetime errors in the sort_by_key call below.
let mut versions_and_deletes = versions_and_deletes.iter().collect::<Vec<_>>();
versions_and_deletes.sort_by_key(|vd| (&vd.key, &vd.last_modified));
let mut vds_for_key = HashMap::<_, Vec<_>>::new();
for vd in &versions_and_deletes {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"
)));
}
tracing::trace!("Parsing version key={key} kind={:?}", vd.kind);
vds_for_key.entry(key).or_default().push(vd);
}
let warn_threshold = 3;
let max_retries = 10;
let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled);
for (key, versions) in vds_for_key {
let last_vd = versions.last().unwrap();
let key = self.relative_path_to_name(key);
if last_vd.last_modified > done_if_after {
tracing::debug!("Key {key} has version later than done_if_after, skipping");
continue;
}
// the version we want to restore to.
let version_to_restore_to =
match versions.binary_search_by_key(&timestamp, |tpl| tpl.last_modified) {
Ok(v) => v,
Err(e) => e,
};
if version_to_restore_to == versions.len() {
tracing::debug!("Key {key} has no changes since timestamp, skipping");
continue;
}
let mut do_delete = false;
if version_to_restore_to == 0 {
// All versions more recent, so the key didn't exist at the specified time point.
tracing::debug!(
"All {} versions more recent for {key}, deleting",
versions.len()
);
do_delete = true;
} else {
match &versions[version_to_restore_to - 1] {
Version {
kind: VersionKind::Version(version_id),
..
} => {
let source_url = format!(
"{}/{}?versionid={}",
self.client
.url()
.map_err(|e| TimeTravelError::Other(anyhow!("{e}")))?,
key,
version_id.0
);
tracing::debug!(
"Promoting old version {} for {key} at {}...",
version_id.0,
source_url
);
backoff::retry(
|| async {
let blob_client = self.client.blob_client(key.clone());
let op = blob_client.copy(Url::from_str(&source_url).unwrap());
tokio::select! {
res = op => res.map_err(|e| TimeTravelError::Other(e.into())),
_ = cancel.cancelled() => Err(TimeTravelError::Cancelled),
}
},
is_permanent,
warn_threshold,
max_retries,
"copying object version for time_travel_recover",
cancel,
)
.await
.ok_or_else(|| TimeTravelError::Cancelled)
.and_then(|x| x)?;
tracing::info!(?version_id, %key, "Copied old version in Azure blob storage");
}
Version {
kind: VersionKind::DeletionMarker,
..
} => {
do_delete = true;
}
}
};
if do_delete {
if matches!(last_vd.kind, VersionKind::DeletionMarker) {
// Key has since been deleted (but there was some history), no need to do anything
tracing::debug!("Key {key} already deleted, skipping.");
} else {
tracing::debug!("Deleting {key}...");
self.delete(&RemotePath::from_string(&key).unwrap(), cancel)
.await
.map_err(|e| {
// delete_oid0 will use TimeoutOrCancel
if TimeoutOrCancel::caused_by_cancel(&e) {
TimeTravelError::Cancelled
} else {
TimeTravelError::Other(e)
}
})?;
}
}
}
Ok(())
// TODO use Azure point in time recovery feature for this
// https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
Err(TimeTravelError::Unimplemented)
}
}

View File

@@ -440,7 +440,6 @@ pub trait RemoteStorage: Send + Sync + 'static {
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
complexity_limit: Option<NonZeroU32>,
) -> Result<(), TimeTravelError>;
}
@@ -652,23 +651,22 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
complexity_limit: Option<NonZeroU32>,
) -> Result<(), TimeTravelError> {
match self {
Self::LocalFs(s) => {
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await
}
Self::AwsS3(s) => {
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await
}
Self::AzureBlob(s) => {
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await
}
Self::Unreliable(s) => {
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await
}
}

View File

@@ -610,7 +610,6 @@ impl RemoteStorage for LocalFs {
_timestamp: SystemTime,
_done_if_after: SystemTime,
_cancel: &CancellationToken,
_complexity_limit: Option<NonZeroU32>,
) -> Result<(), TimeTravelError> {
Err(TimeTravelError::Unimplemented)
}

View File

@@ -981,16 +981,22 @@ impl RemoteStorage for S3Bucket {
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
complexity_limit: Option<NonZeroU32>,
) -> Result<(), TimeTravelError> {
let kind = RequestKind::TimeTravel;
let permit = self.permit(kind, cancel).await?;
tracing::trace!("Target time: {timestamp:?}, done_if_after {done_if_after:?}");
// Limit the number of versions deletions, mostly so that we don't
// keep requesting forever if the list is too long, as we'd put the
// list in RAM.
// Building a list of 100k entries that reaches the limit roughly takes
// 40 seconds, and roughly corresponds to tenants of 2 TiB physical size.
const COMPLEXITY_LIMIT: Option<NonZeroU32> = NonZeroU32::new(100_000);
let mode = ListingMode::NoDelimiter;
let version_listing = self
.list_versions_with_permit(&permit, prefix, mode, complexity_limit, cancel)
.list_versions_with_permit(&permit, prefix, mode, COMPLEXITY_LIMIT, cancel)
.await
.map_err(|err| match err {
DownloadError::Other(e) => TimeTravelError::Other(e),
@@ -1016,7 +1022,6 @@ impl RemoteStorage for S3Bucket {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
// TODO: check the behavior of using the SDK on a non-versioned container
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"

View File

@@ -240,12 +240,11 @@ impl RemoteStorage for UnreliableWrapper {
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
complexity_limit: Option<NonZeroU32>,
) -> Result<(), TimeTravelError> {
self.attempt(RemoteOp::TimeTravelRecover(prefix.map(|p| p.to_owned())))
.map_err(TimeTravelError::Other)?;
self.inner
.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await
}
}

View File

@@ -157,7 +157,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
// No changes after recovery to t2 (no-op)
let t_final = time_point().await;
ctx.client
.time_travel_recover(None, t2, t_final, &cancel, None)
.time_travel_recover(None, t2, t_final, &cancel)
.await?;
let t2_files_recovered = list_files(&ctx.client, &cancel).await?;
println!("after recovery to t2: {t2_files_recovered:?}");
@@ -173,7 +173,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
// after recovery to t1: path1 is back, path2 has the old content
let t_final = time_point().await;
ctx.client
.time_travel_recover(None, t1, t_final, &cancel, None)
.time_travel_recover(None, t1, t_final, &cancel)
.await?;
let t1_files_recovered = list_files(&ctx.client, &cancel).await?;
println!("after recovery to t1: {t1_files_recovered:?}");
@@ -189,7 +189,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
// after recovery to t0: everything is gone except for path1
let t_final = time_point().await;
ctx.client
.time_travel_recover(None, t0, t_final, &cancel, None)
.time_travel_recover(None, t0, t_final, &cancel)
.await?;
let t0_files_recovered = list_files(&ctx.client, &cancel).await?;
println!("after recovery to t0: {t0_files_recovered:?}");

View File

@@ -176,11 +176,9 @@ async fn main() -> anyhow::Result<()> {
let config = RemoteStorageConfig::from_toml_str(&cmd.config_toml_str)?;
let storage = remote_storage::GenericRemoteStorage::from_config(&config).await;
let cancel = CancellationToken::new();
// Complexity limit: as we are running this command locally, we should have a lot of memory available, and we do not
// need to limit the number of versions we are going to delete.
storage
.unwrap()
.time_travel_recover(Some(&prefix), timestamp, done_if_after, &cancel, None)
.time_travel_recover(Some(&prefix), timestamp, done_if_after, &cancel)
.await?;
}
Commands::Key(dkc) => dkc.execute(),

View File

@@ -1,6 +1,5 @@
use std::{collections::HashMap, sync::Arc};
use anyhow::Context;
use async_compression::tokio::write::GzipEncoder;
use camino::{Utf8Path, Utf8PathBuf};
use metrics::core::{AtomicU64, GenericCounter};
@@ -168,17 +167,14 @@ impl BasebackupCache {
.join(Self::entry_filename(tenant_id, timeline_id, lsn))
}
fn tmp_dir(&self) -> Utf8PathBuf {
self.data_dir.join("tmp")
}
fn entry_tmp_path(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Utf8PathBuf {
self.tmp_dir()
self.data_dir
.join("tmp")
.join(Self::entry_filename(tenant_id, timeline_id, lsn))
}
@@ -198,18 +194,15 @@ impl BasebackupCache {
Some((tenant_id, timeline_id, lsn))
}
// Recreate the tmp directory to clear all files in it.
async fn clean_tmp_dir(&self) -> anyhow::Result<()> {
let tmp_dir = self.tmp_dir();
if tmp_dir.exists() {
tokio::fs::remove_dir_all(&tmp_dir).await?;
}
tokio::fs::create_dir_all(&tmp_dir).await?;
Ok(())
}
async fn cleanup(&self) -> anyhow::Result<()> {
self.clean_tmp_dir().await?;
// Cleanup tmp directory.
let tmp_dir = self.data_dir.join("tmp");
let mut tmp_dir = tokio::fs::read_dir(&tmp_dir).await?;
while let Some(dir_entry) = tmp_dir.next_entry().await? {
if let Err(e) = tokio::fs::remove_file(dir_entry.path()).await {
tracing::warn!("Failed to remove basebackup cache tmp file: {:#}", e);
}
}
// Remove outdated entries.
let entries_old = self.entries.lock().unwrap().clone();
@@ -248,14 +241,16 @@ impl BasebackupCache {
}
async fn on_startup(&self) -> anyhow::Result<()> {
// Create data_dir if it does not exist.
tokio::fs::create_dir_all(&self.data_dir)
// Create data_dir and tmp directory if they do not exist.
tokio::fs::create_dir_all(&self.data_dir.join("tmp"))
.await
.context("Failed to create basebackup cache data directory")?;
self.clean_tmp_dir()
.await
.context("Failed to clean tmp directory")?;
.map_err(|e| {
anyhow::anyhow!(
"Failed to create basebackup cache data_dir {:?}: {:?}",
self.data_dir,
e
)
})?;
// Read existing entries from the data_dir and add them to in-memory state.
let mut entries = HashMap::new();
@@ -456,11 +451,6 @@ impl BasebackupCache {
}
// Move the tmp file to the final location atomically.
// The tmp file is fsynced, so it's guaranteed that we will not have a partial file
// in the main directory.
// It's not necessary to fsync the inode after renaming, because the worst case is that
// the rename operation will be rolled back on the disk failure, the entry will disappear
// from the main directory, and the entry access will cause a cache miss.
let entry_path = self.entry_path(tenant_shard_id.tenant_id, timeline_id, req_lsn);
tokio::fs::rename(&entry_tmp_path, &entry_path).await?;
@@ -478,17 +468,16 @@ impl BasebackupCache {
}
/// Prepares a basebackup in a temporary file.
/// Guarantees that the tmp file is fsynced before returning.
async fn prepare_basebackup_tmp(
&self,
entry_tmp_path: &Utf8Path,
emptry_tmp_path: &Utf8Path,
timeline: &Arc<Timeline>,
req_lsn: Lsn,
) -> anyhow::Result<()> {
let ctx = RequestContext::new(TaskKind::BasebackupCache, DownloadBehavior::Download);
let ctx = ctx.with_scope_timeline(timeline);
let file = tokio::fs::File::create(entry_tmp_path).await?;
let file = tokio::fs::File::create(emptry_tmp_path).await?;
let mut writer = BufWriter::new(file);
let mut encoder = GzipEncoder::with_quality(

View File

@@ -573,8 +573,7 @@ fn start_pageserver(
tokio::sync::mpsc::unbounded_channel();
let deletion_queue_client = deletion_queue.new_client();
let background_purges = mgr::BackgroundPurges::default();
let tenant_manager = mgr::init(
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
conf,
background_purges.clone(),
TenantSharedResources {
@@ -585,10 +584,10 @@ fn start_pageserver(
basebackup_prepare_sender,
feature_resolver,
},
order,
shutdown_pageserver.clone(),
);
))?;
let tenant_manager = Arc::new(tenant_manager);
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(tenant_manager.clone(), order))?;
let basebackup_cache = BasebackupCache::spawn(
BACKGROUND_RUNTIME.handle(),

View File

@@ -1,6 +1,5 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use pageserver_api::config::NodeMetadata;
use posthog_client_lite::{
CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
PostHogFlagFilterPropertyValue,
@@ -87,35 +86,7 @@ impl FeatureResolver {
}
}
}
// TODO: move this to a background task so that we don't block startup in case of slow disk
let metadata_path = conf.metadata_path();
match std::fs::read_to_string(&metadata_path) {
Ok(metadata_str) => match serde_json::from_str::<NodeMetadata>(&metadata_str) {
Ok(metadata) => {
properties.insert(
"hostname".to_string(),
PostHogFlagFilterPropertyValue::String(metadata.http_host),
);
if let Some(cplane_region) = metadata.other.get("region_id") {
if let Some(cplane_region) = cplane_region.as_str() {
// This region contains the cell number
properties.insert(
"neon_region".to_string(),
PostHogFlagFilterPropertyValue::String(
cplane_region.to_string(),
),
);
}
}
}
Err(e) => {
tracing::warn!("Failed to parse metadata.json: {}", e);
}
},
Err(e) => {
tracing::warn!("Failed to read metadata.json: {}", e);
}
}
// TODO: add pageserver URL.
Arc::new(properties)
};
let fake_tenants = {

View File

@@ -73,7 +73,6 @@ use crate::tenant::remote_timeline_client::{
use crate::tenant::secondary::SecondaryController;
use crate::tenant::size::ModelInputs;
use crate::tenant::storage_layer::{IoConcurrency, LayerAccessStatsReset, LayerName};
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
use crate::tenant::timeline::offload::{OffloadError, offload_timeline};
use crate::tenant::timeline::{
CompactFlags, CompactOptions, CompactRequest, CompactionError, MarkInvisibleRequest, Timeline,
@@ -1452,10 +1451,7 @@ async fn timeline_layer_scan_disposable_keys(
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download)
.with_scope_timeline(&timeline);
let guard = timeline
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = timeline.layers.read().await;
let Some(layer) = guard.try_get_from_key(&layer_name.clone().into()) else {
return Err(ApiError::NotFound(
anyhow::anyhow!("Layer {tenant_shard_id}/{timeline_id}/{layer_name} not found").into(),

View File

@@ -1053,15 +1053,6 @@ pub(crate) static TENANT_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
.expect("Failed to register pageserver_tenant_states_count metric")
});
pub(crate) static TIMELINE_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_timeline_states_count",
"Count of timelines per state",
&["state"]
)
.expect("Failed to register pageserver_timeline_states_count metric")
});
/// A set of broken tenants.
///
/// These are expected to be so rare that a set is fine. Set as in a new timeseries per each broken
@@ -3334,8 +3325,6 @@ impl TimelineMetrics {
&timeline_id,
);
TIMELINE_STATE_METRIC.with_label_values(&["active"]).inc();
TimelineMetrics {
tenant_id,
shard_id,
@@ -3490,8 +3479,6 @@ impl TimelineMetrics {
return;
}
TIMELINE_STATE_METRIC.with_label_values(&["active"]).dec();
let tenant_id = &self.tenant_id;
let timeline_id = &self.timeline_id;
let shard_id = &self.shard_id;

View File

@@ -51,7 +51,6 @@ use secondary::heatmap::{HeatMapTenant, HeatMapTimeline};
use storage_broker::BrokerClientChannel;
use timeline::compaction::{CompactionOutcome, GcCompactionQueue};
use timeline::import_pgdata::ImportingTimeline;
use timeline::layer_manager::LayerManagerLockHolder;
use timeline::offload::{OffloadError, offload_timeline};
use timeline::{
CompactFlags, CompactOptions, CompactionError, PreviousHeatmap, ShutdownMode, import_pgdata,
@@ -90,8 +89,7 @@ use crate::l0_flush::L0FlushGlobalState;
use crate::metrics::{
BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS,
INITDB_RUN_TIME, INITDB_SEMAPHORE_ACQUISITION_TIME, TENANT, TENANT_OFFLOADED_TIMELINES,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, TIMELINE_STATE_METRIC,
remove_tenant_metrics,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, remove_tenant_metrics,
};
use crate::task_mgr::TaskKind;
use crate::tenant::config::LocationMode;
@@ -546,28 +544,6 @@ pub struct OffloadedTimeline {
/// Part of the `OffloadedTimeline` object's lifecycle: this needs to be set before we drop it
pub deleted_from_ancestor: AtomicBool,
_metrics_guard: OffloadedTimelineMetricsGuard,
}
/// Increases the offloaded timeline count metric when created, and decreases when dropped.
struct OffloadedTimelineMetricsGuard;
impl OffloadedTimelineMetricsGuard {
fn new() -> Self {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.inc();
Self
}
}
impl Drop for OffloadedTimelineMetricsGuard {
fn drop(&mut self) {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.dec();
}
}
impl OffloadedTimeline {
@@ -600,8 +576,6 @@ impl OffloadedTimeline {
delete_progress: timeline.delete_progress.clone(),
deleted_from_ancestor: AtomicBool::new(false),
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
})
}
fn from_manifest(tenant_shard_id: TenantShardId, manifest: &OffloadedTimelineManifest) -> Self {
@@ -621,7 +595,6 @@ impl OffloadedTimeline {
archived_at,
delete_progress: TimelineDeleteProgress::default(),
deleted_from_ancestor: AtomicBool::new(false),
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
}
}
fn manifest(&self) -> OffloadedTimelineManifest {
@@ -1316,7 +1289,7 @@ impl TenantShard {
ancestor.is_some()
|| timeline
.layers
.read(LayerManagerLockHolder::LoadLayerMap)
.read()
.await
.layer_map()
.expect(
@@ -2644,7 +2617,7 @@ impl TenantShard {
}
let layer_names = tline
.layers
.read(LayerManagerLockHolder::Testing)
.read()
.await
.layer_map()
.unwrap()
@@ -3159,12 +3132,7 @@ impl TenantShard {
for timeline in &compact {
// Collect L0 counts. Can't await while holding lock above.
if let Ok(lm) = timeline
.layers
.read(LayerManagerLockHolder::Compaction)
.await
.layer_map()
{
if let Ok(lm) = timeline.layers.read().await.layer_map() {
l0_counts.insert(timeline.timeline_id, lm.level0_deltas().len());
}
}
@@ -4906,7 +4874,7 @@ impl TenantShard {
}
let layer_names = tline
.layers
.read(LayerManagerLockHolder::Testing)
.read()
.await
.layer_map()
.unwrap()
@@ -6976,7 +6944,7 @@ mod tests {
.await?;
make_some_layers(tline.as_ref(), Lsn(0x20), &ctx).await?;
let layer_map = tline.layers.read(LayerManagerLockHolder::Testing).await;
let layer_map = tline.layers.read().await;
let level0_deltas = layer_map
.layer_map()?
.level0_deltas()
@@ -7212,7 +7180,7 @@ mod tests {
let lsn = Lsn(0x10);
let inserted = bulk_insert_compact_gc(&tenant, &tline, &ctx, lsn, 50, 10000).await?;
let guard = tline.layers.read(LayerManagerLockHolder::Testing).await;
let guard = tline.layers.read().await;
let lm = guard.layer_map()?;
lm.dump(true, &ctx).await?;
@@ -8240,23 +8208,12 @@ mod tests {
tline.freeze_and_flush().await?; // force create a delta layer
}
let before_num_l0_delta_files = tline
.layers
.read(LayerManagerLockHolder::Testing)
.await
.layer_map()?
.level0_deltas()
.len();
let before_num_l0_delta_files =
tline.layers.read().await.layer_map()?.level0_deltas().len();
tline.compact(&cancel, EnumSet::default(), &ctx).await?;
let after_num_l0_delta_files = tline
.layers
.read(LayerManagerLockHolder::Testing)
.await
.layer_map()?
.level0_deltas()
.len();
let after_num_l0_delta_files = tline.layers.read().await.layer_map()?.level0_deltas().len();
assert!(
after_num_l0_delta_files < before_num_l0_delta_files,

View File

@@ -61,8 +61,8 @@ pub(crate) struct LocationConf {
/// The detailed shard identity. This structure is already scoped within
/// a TenantShardId, but we need the full ShardIdentity to enable calculating
/// key->shard mappings.
// TODO(vlad): Remove this default once all configs have a shard identity on disk.
#[serde(default = "ShardIdentity::unsharded")]
#[serde(skip_serializing_if = "ShardIdentity::is_unsharded")]
pub(crate) shard: ShardIdentity,
/// The pan-cluster tenant configuration, the same on all locations
@@ -149,12 +149,7 @@ impl LocationConf {
/// For use when attaching/re-attaching: update the generation stored in this
/// structure. If we were in a secondary state, promote to attached (posession
/// of a fresh generation implies this).
pub(crate) fn attach_in_generation(
&mut self,
mode: AttachmentMode,
generation: Generation,
stripe_size: ShardStripeSize,
) {
pub(crate) fn attach_in_generation(&mut self, mode: AttachmentMode, generation: Generation) {
match &mut self.mode {
LocationMode::Attached(attach_conf) => {
attach_conf.generation = generation;
@@ -168,8 +163,6 @@ impl LocationConf {
})
}
}
self.shard.stripe_size = stripe_size;
}
pub(crate) fn try_from(conf: &'_ models::LocationConfig) -> anyhow::Result<Self> {

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,6 @@
//! Helper functions to upload files to remote storage with a RemoteStorage
use std::io::{ErrorKind, SeekFrom};
use std::num::NonZeroU32;
use std::time::SystemTime;
use anyhow::{Context, bail};
@@ -229,25 +228,11 @@ pub(crate) async fn time_travel_recover_tenant(
let timelines_path = super::remote_timelines_path(tenant_shard_id);
prefixes.push(timelines_path);
}
// Limit the number of versions deletions, mostly so that we don't
// keep requesting forever if the list is too long, as we'd put the
// list in RAM.
// Building a list of 100k entries that reaches the limit roughly takes
// 40 seconds, and roughly corresponds to tenants of 2 TiB physical size.
const COMPLEXITY_LIMIT: Option<NonZeroU32> = NonZeroU32::new(100_000);
for prefix in &prefixes {
backoff::retry(
|| async {
storage
.time_travel_recover(
Some(prefix),
timestamp,
done_if_after,
cancel,
COMPLEXITY_LIMIT,
)
.time_travel_recover(Some(prefix), timestamp, done_if_after, cancel)
.await
},
|e| !matches!(e, TimeTravelError::Other(_)),

View File

@@ -1635,7 +1635,6 @@ pub(crate) mod test {
use crate::tenant::disk_btree::tests::TestDisk;
use crate::tenant::harness::{TIMELINE_ID, TenantHarness};
use crate::tenant::storage_layer::{Layer, ResidentLayer};
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
use crate::tenant::{TenantShard, Timeline};
/// Construct an index for a fictional delta layer and and then
@@ -2003,7 +2002,7 @@ pub(crate) mod test {
let initdb_layer = timeline
.layers
.read(crate::tenant::timeline::layer_manager::LayerManagerLockHolder::Testing)
.read()
.await
.likely_resident_layers()
.next()
@@ -2079,7 +2078,7 @@ pub(crate) mod test {
let new_layer = timeline
.layers
.read(LayerManagerLockHolder::Testing)
.read()
.await
.likely_resident_layers()
.find(|&x| x != &initdb_layer)

View File

@@ -10,7 +10,6 @@ use super::*;
use crate::context::DownloadBehavior;
use crate::tenant::harness::{TenantHarness, test_img};
use crate::tenant::storage_layer::{IoConcurrency, LayerVisibilityHint};
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
/// Used in tests to advance a future to wanted await point, and not futher.
const ADVANCE: std::time::Duration = std::time::Duration::from_secs(3600);
@@ -60,7 +59,7 @@ async fn smoke_test() {
// there to avoid the timeline being illegally empty
let (layer, dummy_layer) = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -216,7 +215,7 @@ async fn smoke_test() {
// Simulate GC removing our test layer.
{
let mut g = timeline.layers.write(LayerManagerLockHolder::Testing).await;
let mut g = timeline.layers.write().await;
let layers = &[layer];
g.open_mut().unwrap().finish_gc_timeline(layers);
@@ -262,7 +261,7 @@ async fn evict_and_wait_on_wanted_deleted() {
let layer = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -306,7 +305,7 @@ async fn evict_and_wait_on_wanted_deleted() {
// assert that once we remove the `layer` from the layer map and drop our reference,
// the deletion of the layer in remote_storage happens.
{
let mut layers = timeline.layers.write(LayerManagerLockHolder::Testing).await;
let mut layers = timeline.layers.write().await;
layers.open_mut().unwrap().finish_gc_timeline(&[layer]);
}
@@ -348,7 +347,7 @@ fn read_wins_pending_eviction() {
let layer = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -481,7 +480,7 @@ fn multiple_pending_evictions_scenario(name: &'static str, in_order: bool) {
let layer = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -656,7 +655,7 @@ async fn cancelled_get_or_maybe_download_does_not_cancel_eviction() {
let layer = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -742,7 +741,7 @@ async fn evict_and_wait_does_not_wait_for_download() {
let layer = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -863,7 +862,7 @@ async fn eviction_cancellation_on_drop() {
let (evicted_layer, not_evicted) = {
let mut layers = {
let mut guard = timeline.layers.write(LayerManagerLockHolder::Testing).await;
let mut guard = timeline.layers.write().await;
let layers = guard.likely_resident_layers().cloned().collect::<Vec<_>>();
// remove the layers from layermap
guard.open_mut().unwrap().finish_gc_timeline(&layers);

View File

@@ -35,11 +35,7 @@ use fail::fail_point;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use handle::ShardTimelineId;
use layer_manager::{
LayerManagerLockHolder, LayerManagerReadGuard, LayerManagerWriteGuard, LockedLayerManager,
Shutdown,
};
use layer_manager::Shutdown;
use offload::OffloadError;
use once_cell::sync::Lazy;
use pageserver_api::config::tenant_conf_defaults::DEFAULT_PITR_INTERVAL;
@@ -86,6 +82,7 @@ use wal_decoder::serialized_batch::{SerializedValueBatch, ValueMeta};
use self::delete::DeleteTimelineFlow;
pub(super) use self::eviction_task::EvictionTaskTenantState;
use self::eviction_task::EvictionTaskTimelineState;
use self::layer_manager::LayerManager;
use self::logical_size::LogicalSize;
use self::walreceiver::{WalReceiver, WalReceiverConf};
use super::remote_timeline_client::RemoteTimelineClient;
@@ -184,13 +181,13 @@ impl std::fmt::Display for ImageLayerCreationMode {
/// Temporary function for immutable storage state refactor, ensures we are dropping mutex guard instead of other things.
/// Can be removed after all refactors are done.
fn drop_layer_manager_rlock(rlock: LayerManagerReadGuard<'_>) {
fn drop_rlock<T>(rlock: tokio::sync::RwLockReadGuard<T>) {
drop(rlock)
}
/// Temporary function for immutable storage state refactor, ensures we are dropping mutex guard instead of other things.
/// Can be removed after all refactors are done.
fn drop_layer_manager_wlock(rlock: LayerManagerWriteGuard<'_>) {
fn drop_wlock<T>(rlock: tokio::sync::RwLockWriteGuard<'_, T>) {
drop(rlock)
}
@@ -244,7 +241,7 @@ pub struct Timeline {
///
/// In the future, we'll be able to split up the tuple of LayerMap and `LayerFileManager`,
/// so that e.g. on-demand-download/eviction, and layer spreading, can operate just on `LayerFileManager`.
pub(crate) layers: LockedLayerManager,
pub(crate) layers: tokio::sync::RwLock<LayerManager>,
last_freeze_at: AtomicLsn,
// Atomic would be more appropriate here.
@@ -1538,10 +1535,7 @@ impl Timeline {
/// This method makes no distinction between local and remote layers.
/// Hence, the result **does not represent local filesystem usage**.
pub(crate) async fn layer_size_sum(&self) -> u64 {
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
guard.layer_size_sum()
}
@@ -1851,7 +1845,7 @@ impl Timeline {
// time, and this was missed.
// if write_guard.is_none() { return; }
let Ok(layers_guard) = self.layers.try_read(LayerManagerLockHolder::TryFreezeLayer) else {
let Ok(layers_guard) = self.layers.try_read() else {
// Don't block if the layer lock is busy
return;
};
@@ -2164,7 +2158,7 @@ impl Timeline {
if let ShutdownMode::FreezeAndFlush = mode {
let do_flush = if let Some((open, frozen)) = self
.layers
.read(LayerManagerLockHolder::Shutdown)
.read()
.await
.layer_map()
.map(|lm| (lm.open_layer.is_some(), lm.frozen_layers.len()))
@@ -2268,10 +2262,7 @@ impl Timeline {
// Allow any remaining in-memory layers to do cleanup -- until that, they hold the gate
// open.
let mut write_guard = self.write_lock.lock().await;
self.layers
.write(LayerManagerLockHolder::Shutdown)
.await
.shutdown(&mut write_guard);
self.layers.write().await.shutdown(&mut write_guard);
}
// Finally wait until any gate-holders are complete.
@@ -2374,10 +2365,7 @@ impl Timeline {
&self,
reset: LayerAccessStatsReset,
) -> Result<LayerMapInfo, layer_manager::Shutdown> {
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
let layer_map = guard.layer_map()?;
let mut in_memory_layers = Vec::with_capacity(layer_map.frozen_layers.len() + 1);
if let Some(open_layer) = &layer_map.open_layer {
@@ -3244,7 +3232,7 @@ impl Timeline {
/// Initialize with an empty layer map. Used when creating a new timeline.
pub(super) fn init_empty_layer_map(&self, start_lsn: Lsn) {
let mut layers = self.layers.try_write(LayerManagerLockHolder::Init).expect(
let mut layers = self.layers.try_write().expect(
"in the context where we call this function, no other task has access to the object",
);
layers
@@ -3264,10 +3252,7 @@ impl Timeline {
use init::Decision::*;
use init::{Discovered, DismissedLayer};
let mut guard = self
.layers
.write(LayerManagerLockHolder::LoadLayerMap)
.await;
let mut guard = self.layers.write().await;
let timer = self.metrics.load_layer_map_histo.start_timer();
@@ -3884,10 +3869,7 @@ impl Timeline {
&self,
layer_name: &LayerName,
) -> Result<Option<Layer>, layer_manager::Shutdown> {
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
let layer = guard
.layer_map()?
.iter_historic_layers()
@@ -3920,10 +3902,7 @@ impl Timeline {
return None;
}
let guard = self
.layers
.read(LayerManagerLockHolder::GenerateHeatmap)
.await;
let guard = self.layers.read().await;
// Firstly, if there's any heatmap left over from when this location
// was a secondary, take that into account. Keep layers that are:
@@ -4021,10 +4000,7 @@ impl Timeline {
}
pub(super) async fn generate_unarchival_heatmap(&self, end_lsn: Lsn) -> PreviousHeatmap {
let guard = self
.layers
.read(LayerManagerLockHolder::GenerateHeatmap)
.await;
let guard = self.layers.read().await;
let now = SystemTime::now();
let mut heatmap_layers = Vec::default();
@@ -4366,7 +4342,7 @@ impl Timeline {
query: &VersionedKeySpaceQuery,
) -> Result<LayerFringe, GetVectoredError> {
let mut fringe = LayerFringe::new();
let guard = self.layers.read(LayerManagerLockHolder::GetPage).await;
let guard = self.layers.read().await;
match query {
VersionedKeySpaceQuery::Uniform { keyspace, lsn } => {
@@ -4469,7 +4445,7 @@ impl Timeline {
// required for correctness, but avoids visiting extra layers
// which turns out to be a perf bottleneck in some cases.
if !unmapped_keyspace.is_empty() {
let guard = timeline.layers.read(LayerManagerLockHolder::GetPage).await;
let guard = timeline.layers.read().await;
guard.update_search_fringe(&unmapped_keyspace, cont_lsn, &mut fringe)?;
// It's safe to drop the layer map lock after planning the next round of reads.
@@ -4579,10 +4555,7 @@ impl Timeline {
_guard: &tokio::sync::MutexGuard<'_, Option<TimelineWriterState>>,
ctx: &RequestContext,
) -> anyhow::Result<Arc<InMemoryLayer>> {
let mut guard = self
.layers
.write(LayerManagerLockHolder::GetLayerForWrite)
.await;
let mut guard = self.layers.write().await;
let last_record_lsn = self.get_last_record_lsn();
ensure!(
@@ -4624,10 +4597,7 @@ impl Timeline {
write_lock: &mut tokio::sync::MutexGuard<'_, Option<TimelineWriterState>>,
) -> Result<u64, FlushLayerError> {
let frozen = {
let mut guard = self
.layers
.write(LayerManagerLockHolder::TryFreezeLayer)
.await;
let mut guard = self.layers.write().await;
guard
.open_mut()?
.try_freeze_in_memory_layer(at, &self.last_freeze_at, write_lock, &self.metrics)
@@ -4668,12 +4638,7 @@ impl Timeline {
ctx: &RequestContext,
) {
// Subscribe to L0 delta layer updates, for compaction backpressure.
let mut watch_l0 = match self
.layers
.read(LayerManagerLockHolder::FlushLoop)
.await
.layer_map()
{
let mut watch_l0 = match self.layers.read().await.layer_map() {
Ok(lm) => lm.watch_level0_deltas(),
Err(Shutdown) => return,
};
@@ -4710,7 +4675,7 @@ impl Timeline {
// Fetch the next layer to flush, if any.
let (layer, l0_count, frozen_count, frozen_size) = {
let layers = self.layers.read(LayerManagerLockHolder::FlushLoop).await;
let layers = self.layers.read().await;
let Ok(lm) = layers.layer_map() else {
info!("dropping out of flush loop for timeline shutdown");
return;
@@ -5006,10 +4971,7 @@ impl Timeline {
// in-memory layer from the map now. The flushed layer is stored in
// the mapping in `create_delta_layer`.
{
let mut guard = self
.layers
.write(LayerManagerLockHolder::FlushFrozenLayer)
.await;
let mut guard = self.layers.write().await;
guard.open_mut()?.finish_flush_l0_layer(
delta_layer_to_add.as_ref(),
@@ -5224,7 +5186,7 @@ impl Timeline {
async fn time_for_new_image_layer(&self, partition: &KeySpace, lsn: Lsn) -> bool {
let threshold = self.get_image_creation_threshold();
let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
let guard = self.layers.read().await;
let Ok(layers) = guard.layer_map() else {
return false;
};
@@ -5642,7 +5604,7 @@ impl Timeline {
if let ImageLayerCreationMode::Force = mode {
// When forced to create image layers, we might try and create them where they already
// exist. This mode is only used in tests/debug.
let layers = self.layers.read(LayerManagerLockHolder::Compaction).await;
let layers = self.layers.read().await;
if layers.contains_key(&PersistentLayerKey {
key_range: img_range.clone(),
lsn_range: PersistentLayerDesc::image_layer_lsn_range(lsn),
@@ -5767,7 +5729,7 @@ impl Timeline {
let image_layers = batch_image_writer.finish(self, ctx).await?;
let mut guard = self.layers.write(LayerManagerLockHolder::Compaction).await;
let mut guard = self.layers.write().await;
// FIXME: we could add the images to be uploaded *before* returning from here, but right
// now they are being scheduled outside of write lock; current way is inconsistent with
@@ -5775,7 +5737,7 @@ impl Timeline {
guard
.open_mut()?
.track_new_image_layers(&image_layers, &self.metrics);
drop_layer_manager_wlock(guard);
drop_wlock(guard);
let duration = timer.stop_and_record();
// Creating image layers may have caused some previously visible layers to be covered
@@ -6145,7 +6107,7 @@ impl Timeline {
layers_to_remove: &[Layer],
) -> Result<(), CompactionError> {
let mut guard = tokio::select! {
guard = self.layers.write(LayerManagerLockHolder::Compaction) => guard,
guard = self.layers.write() => guard,
_ = self.cancel.cancelled() => {
return Err(CompactionError::ShuttingDown);
}
@@ -6194,7 +6156,7 @@ impl Timeline {
self.remote_client
.schedule_compaction_update(&remove_layers, new_deltas)?;
drop_layer_manager_wlock(guard);
drop_wlock(guard);
Ok(())
}
@@ -6204,7 +6166,7 @@ impl Timeline {
mut replace_layers: Vec<(Layer, ResidentLayer)>,
mut drop_layers: Vec<Layer>,
) -> Result<(), CompactionError> {
let mut guard = self.layers.write(LayerManagerLockHolder::Compaction).await;
let mut guard = self.layers.write().await;
// Trim our lists in case our caller (compaction) raced with someone else (GC) removing layers: we want
// to avoid double-removing, and avoid rewriting something that was removed.
@@ -6555,10 +6517,7 @@ impl Timeline {
// 5. newer on-disk image layers cover the layer's whole key range
//
// TODO holding a write lock is too agressive and avoidable
let mut guard = self
.layers
.write(LayerManagerLockHolder::GarbageCollection)
.await;
let mut guard = self.layers.write().await;
let layers = guard.layer_map()?;
'outer: for l in layers.iter_historic_layers() {
result.layers_total += 1;
@@ -6860,10 +6819,7 @@ impl Timeline {
use pageserver_api::models::DownloadRemoteLayersTaskState;
let remaining = {
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
let Ok(lm) = guard.layer_map() else {
// technically here we could look into iterating accessible layers, but downloading
// all layers of a shutdown timeline makes no sense regardless.
@@ -6969,7 +6925,7 @@ impl Timeline {
impl Timeline {
/// Returns non-remote layers for eviction.
pub(crate) async fn get_local_layers_for_disk_usage_eviction(&self) -> DiskUsageEvictionInfo {
let guard = self.layers.read(LayerManagerLockHolder::Eviction).await;
let guard = self.layers.read().await;
let mut max_layer_size: Option<u64> = None;
let resident_layers = guard
@@ -7070,7 +7026,7 @@ impl Timeline {
let image_layer = Layer::finish_creating(self.conf, self, desc, &path)?;
info!("force created image layer {}", image_layer.local_path());
{
let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await;
let mut guard = self.layers.write().await;
guard
.open_mut()
.unwrap()
@@ -7133,7 +7089,7 @@ impl Timeline {
let delta_layer = Layer::finish_creating(self.conf, self, desc, &path)?;
info!("force created delta layer {}", delta_layer.local_path());
{
let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await;
let mut guard = self.layers.write().await;
guard
.open_mut()
.unwrap()
@@ -7228,7 +7184,7 @@ impl Timeline {
// Link the layer to the layer map
{
let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await;
let mut guard = self.layers.write().await;
let layer_map = guard.open_mut().unwrap();
layer_map.force_insert_in_memory_layer(Arc::new(layer));
}
@@ -7245,7 +7201,7 @@ impl Timeline {
io_concurrency: IoConcurrency,
) -> anyhow::Result<Vec<(Key, Bytes)>> {
let mut all_data = Vec::new();
let guard = self.layers.read(LayerManagerLockHolder::Testing).await;
let guard = self.layers.read().await;
for layer in guard.layer_map()?.iter_historic_layers() {
if !layer.is_delta() && layer.image_layer_lsn() == lsn {
let layer = guard.get_from_desc(&layer);
@@ -7274,7 +7230,7 @@ impl Timeline {
self: &Arc<Timeline>,
) -> anyhow::Result<Vec<super::storage_layer::PersistentLayerKey>> {
let mut layers = Vec::new();
let guard = self.layers.read(LayerManagerLockHolder::Testing).await;
let guard = self.layers.read().await;
for layer in guard.layer_map()?.iter_historic_layers() {
layers.push(layer.key());
}
@@ -7386,7 +7342,7 @@ impl TimelineWriter<'_> {
let l0_count = self
.tl
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.read()
.await
.layer_map()?
.level0_deltas()
@@ -7605,7 +7561,6 @@ mod tests {
use crate::tenant::harness::{TenantHarness, test_img};
use crate::tenant::layer_map::LayerMap;
use crate::tenant::storage_layer::{Layer, LayerName, LayerVisibilityHint};
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
use crate::tenant::timeline::{DeltaLayerTestDesc, EvictionError};
use crate::tenant::{PreviousHeatmap, Timeline};
@@ -7713,7 +7668,7 @@ mod tests {
// Evict all the layers and stash the old heatmap in the timeline.
// This simulates a migration to a cold secondary location.
let guard = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let guard = timeline.layers.read().await;
let mut all_layers = Vec::new();
let forever = std::time::Duration::from_secs(120);
for layer in guard.likely_resident_layers() {
@@ -7835,7 +7790,7 @@ mod tests {
})));
// Evict all the layers in the previous heatmap
let guard = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let guard = timeline.layers.read().await;
let forever = std::time::Duration::from_secs(120);
for layer in guard.likely_resident_layers() {
layer.evict_and_wait(forever).await.unwrap();
@@ -7898,10 +7853,7 @@ mod tests {
}
async fn find_some_layer(timeline: &Timeline) -> Layer {
let layers = timeline
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let layers = timeline.layers.read().await;
let desc = layers
.layer_map()
.unwrap()

View File

@@ -4,7 +4,6 @@ use std::ops::Range;
use utils::lsn::Lsn;
use super::Timeline;
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
#[derive(serde::Serialize)]
pub(crate) struct RangeAnalysis {
@@ -25,10 +24,7 @@ impl Timeline {
let num_of_l0;
let all_layer_files = {
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
num_of_l0 = guard.layer_map().unwrap().level0_deltas().len();
guard.all_persistent_layers()
};

View File

@@ -9,7 +9,7 @@ use std::ops::{Deref, Range};
use std::sync::Arc;
use std::time::{Duration, Instant};
use super::layer_manager::{LayerManagerLockHolder, LayerManagerReadGuard};
use super::layer_manager::LayerManager;
use super::{
CompactFlags, CompactOptions, CompactionError, CreateImageLayersError, DurationRecorder,
GetVectoredError, ImageLayerCreationMode, LastImageLayerCreationStatus, RecordedDuration,
@@ -62,7 +62,7 @@ use crate::tenant::storage_layer::{
use crate::tenant::tasks::log_compaction_error;
use crate::tenant::timeline::{
DeltaLayerWriter, ImageLayerCreationOutcome, ImageLayerWriter, IoConcurrency, Layer,
ResidentLayer, drop_layer_manager_rlock,
ResidentLayer, drop_rlock,
};
use crate::tenant::{DeltaLayer, MaybeOffloaded};
use crate::virtual_file::{MaybeFatalIo, VirtualFile};
@@ -314,10 +314,7 @@ impl GcCompactionQueue {
.unwrap_or(Lsn::INVALID);
let layers = {
let guard = timeline
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = timeline.layers.read().await;
let layer_map = guard.layer_map()?;
layer_map.iter_historic_layers().collect_vec()
};
@@ -411,10 +408,7 @@ impl GcCompactionQueue {
timeline: &Arc<Timeline>,
lsn: Lsn,
) -> Result<u64, CompactionError> {
let guard = timeline
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = timeline.layers.read().await;
let layer_map = guard.layer_map()?;
let layers = layer_map.iter_historic_layers().collect_vec();
let mut size = 0;
@@ -857,7 +851,7 @@ impl KeyHistoryRetention {
}
let layer_generation;
{
let guard = tline.layers.read(LayerManagerLockHolder::Compaction).await;
let guard = tline.layers.read().await;
if !guard.contains_key(key) {
return false;
}
@@ -1288,10 +1282,7 @@ impl Timeline {
// We do the repartition on the L0-L1 boundary. All data below the boundary
// are compacted by L0 with low read amplification, thus making the `repartition`
// function run fast.
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
guard
.all_persistent_layers()
.iter()
@@ -1470,7 +1461,7 @@ impl Timeline {
let latest_gc_cutoff = self.get_applied_gc_cutoff_lsn();
let pitr_cutoff = self.gc_info.read().unwrap().cutoffs.time;
let layers = self.layers.read(LayerManagerLockHolder::Compaction).await;
let layers = self.layers.read().await;
let layers_iter = layers.layer_map()?.iter_historic_layers();
let (layers_total, mut layers_checked) = (layers_iter.len(), 0);
for layer_desc in layers_iter {
@@ -1731,10 +1722,7 @@ impl Timeline {
// are implicitly left visible, because LayerVisibilityHint's default is Visible, and we never modify it here.
// Note that L0 deltas _can_ be covered by image layers, but we consider them 'visible' because we anticipate that
// they will be subject to L0->L1 compaction in the near future.
let layer_manager = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let layer_manager = self.layers.read().await;
let layer_map = layer_manager.layer_map()?;
let readable_points = {
@@ -1787,7 +1775,7 @@ impl Timeline {
};
let begin = tokio::time::Instant::now();
let phase1_layers_locked = self.layers.read(LayerManagerLockHolder::Compaction).await;
let phase1_layers_locked = self.layers.read().await;
let now = tokio::time::Instant::now();
stats.read_lock_acquisition_micros =
DurationRecorder::Recorded(RecordedDuration(now - begin), now);
@@ -1815,7 +1803,7 @@ impl Timeline {
/// Level0 files first phase of compaction, explained in the [`Self::compact_legacy`] comment.
async fn compact_level0_phase1<'a>(
self: &'a Arc<Self>,
guard: LayerManagerReadGuard<'a>,
guard: tokio::sync::RwLockReadGuard<'a, LayerManager>,
mut stats: CompactLevel0Phase1StatsBuilder,
target_file_size: u64,
force_compaction_ignore_threshold: bool,
@@ -2041,7 +2029,7 @@ impl Timeline {
holes
};
stats.read_lock_held_compute_holes_micros = stats.read_lock_held_key_sort_micros.till_now();
drop_layer_manager_rlock(guard);
drop_rlock(guard);
if self.cancel.is_cancelled() {
return Err(CompactionError::ShuttingDown);
@@ -2481,7 +2469,7 @@ impl Timeline {
// Find the top of the historical layers
let end_lsn = {
let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
let guard = self.layers.read().await;
let layers = guard.layer_map()?;
let l0_deltas = layers.level0_deltas();
@@ -3020,7 +3008,7 @@ impl Timeline {
}
split_key_ranges.sort();
let all_layers = {
let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
let guard = self.layers.read().await;
let layer_map = guard.layer_map()?;
layer_map.iter_historic_layers().collect_vec()
};
@@ -3124,12 +3112,12 @@ impl Timeline {
.await?;
let jobs_len = jobs.len();
for (idx, job) in jobs.into_iter().enumerate() {
let sub_compaction_progress = format!("{}/{}", idx + 1, jobs_len);
info!(
"running enhanced gc bottom-most compaction, sub-compaction {}/{}",
idx + 1,
jobs_len
);
self.compact_with_gc_inner(cancel, job, ctx, yield_for_l0)
.instrument(info_span!(
"sub_compaction",
sub_compaction_progress = sub_compaction_progress
))
.await?;
}
if jobs_len == 0 {
@@ -3197,10 +3185,7 @@ impl Timeline {
// 1. If a layer is in the selection, all layers below it are in the selection.
// 2. Inferred from (1), for each key in the layer selection, the value can be reconstructed only with the layers in the layer selection.
let job_desc = {
let guard = self
.layers
.read(LayerManagerLockHolder::GarbageCollection)
.await;
let guard = self.layers.read().await;
let layers = guard.layer_map()?;
let gc_info = self.gc_info.read().unwrap();
let mut retain_lsns_below_horizon = Vec::new();
@@ -3971,10 +3956,7 @@ impl Timeline {
// First, do a sanity check to ensure the newly-created layer map does not contain overlaps.
let all_layers = {
let guard = self
.layers
.read(LayerManagerLockHolder::GarbageCollection)
.await;
let guard = self.layers.read().await;
let layer_map = guard.layer_map()?;
layer_map.iter_historic_layers().collect_vec()
};
@@ -4038,10 +4020,7 @@ impl Timeline {
let update_guard = self.gc_compaction_layer_update_lock.write().await;
// Acquiring the update guard ensures current read operations end and new read operations are blocked.
// TODO: can we use `latest_gc_cutoff` Rcu to achieve the same effect?
let mut guard = self
.layers
.write(LayerManagerLockHolder::GarbageCollection)
.await;
let mut guard = self.layers.write().await;
guard
.open_mut()?
.finish_gc_compaction(&layer_selection, &compact_to, &self.metrics);
@@ -4109,11 +4088,7 @@ impl TimelineAdaptor {
pub async fn flush_updates(&mut self) -> Result<(), CompactionError> {
let layers_to_delete = {
let guard = self
.timeline
.layers
.read(LayerManagerLockHolder::Compaction)
.await;
let guard = self.timeline.layers.read().await;
self.layers_to_delete
.iter()
.map(|x| guard.get_from_desc(x))
@@ -4158,11 +4133,7 @@ impl CompactionJobExecutor for TimelineAdaptor {
) -> anyhow::Result<Vec<OwnArc<PersistentLayerDesc>>> {
self.flush_updates().await?;
let guard = self
.timeline
.layers
.read(LayerManagerLockHolder::Compaction)
.await;
let guard = self.timeline.layers.read().await;
let layer_map = guard.layer_map()?;
let result = layer_map
@@ -4201,11 +4172,7 @@ impl CompactionJobExecutor for TimelineAdaptor {
// this is a lot more complex than a simple downcast...
if layer.is_delta() {
let l = {
let guard = self
.timeline
.layers
.read(LayerManagerLockHolder::Compaction)
.await;
let guard = self.timeline.layers.read().await;
guard.get_from_desc(layer)
};
let result = l.download_and_keep_resident(ctx).await?;

View File

@@ -19,7 +19,7 @@ use utils::id::TimelineId;
use utils::lsn::Lsn;
use utils::sync::gate::GateError;
use super::layer_manager::{LayerManager, LayerManagerLockHolder};
use super::layer_manager::LayerManager;
use super::{FlushLayerError, Timeline};
use crate::context::{DownloadBehavior, RequestContext};
use crate::task_mgr::TaskKind;
@@ -199,10 +199,7 @@ pub(crate) async fn generate_tombstone_image_layer(
let image_lsn = ancestor_lsn;
{
let layers = detached
.layers
.read(LayerManagerLockHolder::DetachAncestor)
.await;
let layers = detached.layers.read().await;
for layer in layers.all_persistent_layers() {
if !layer.is_delta
&& layer.lsn_range.start == image_lsn
@@ -426,7 +423,7 @@ pub(super) async fn prepare(
// we do not need to start from our layers, because they can only be layers that come
// *after* ancestor_lsn
let layers = tokio::select! {
guard = ancestor.layers.read(LayerManagerLockHolder::DetachAncestor) => guard,
guard = ancestor.layers.read() => guard,
_ = detached.cancel.cancelled() => {
return Err(ShuttingDown);
}
@@ -872,12 +869,7 @@ async fn remote_copy(
// Double check that the file is orphan (probably from an earlier attempt), then delete it
let key = file_name.clone().into();
if adoptee
.layers
.read(LayerManagerLockHolder::DetachAncestor)
.await
.contains_key(&key)
{
if adoptee.layers.read().await.contains_key(&key) {
// We are supposed to filter out such cases before coming to this function
return Err(Error::Prepare(anyhow::anyhow!(
"layer file {file_name} already present and inside layer map"

View File

@@ -33,7 +33,6 @@ use crate::tenant::size::CalculateSyntheticSizeError;
use crate::tenant::storage_layer::LayerVisibilityHint;
use crate::tenant::tasks::{BackgroundLoopKind, BackgroundLoopSemaphorePermit, sleep_random};
use crate::tenant::timeline::EvictionError;
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
use crate::tenant::{LogicalSizeCalculationCause, TenantShard};
#[derive(Default)]
@@ -209,7 +208,7 @@ impl Timeline {
let mut js = tokio::task::JoinSet::new();
{
let guard = self.layers.read(LayerManagerLockHolder::Eviction).await;
let guard = self.layers.read().await;
guard
.likely_resident_layers()

View File

@@ -15,7 +15,6 @@ use super::{Timeline, TimelineDeleteProgress};
use crate::context::RequestContext;
use crate::controller_upcall_client::{StorageControllerUpcallApi, StorageControllerUpcallClient};
use crate::tenant::metadata::TimelineMetadata;
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
mod flow;
mod importbucket_client;
@@ -164,10 +163,7 @@ async fn prepare_import(
info!("wipe the slate clean");
{
// TODO: do we need to hold GC lock for this?
let mut guard = timeline
.layers
.write(LayerManagerLockHolder::ImportPgData)
.await;
let mut guard = timeline.layers.write().await;
assert!(
guard.layer_map()?.open_layer.is_none(),
"while importing, there should be no in-memory layer" // this just seems like a good place to assert it

View File

@@ -56,7 +56,6 @@ use crate::pgdatadir_mapping::{
};
use crate::task_mgr::TaskKind;
use crate::tenant::storage_layer::{AsLayerDesc, ImageLayerWriter, Layer};
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
pub async fn run(
timeline: Arc<Timeline>,
@@ -985,10 +984,7 @@ impl ChunkProcessingJob {
let (desc, path) = writer.finish(ctx).await?;
{
let guard = timeline
.layers
.read(LayerManagerLockHolder::ImportPgData)
.await;
let guard = timeline.layers.read().await;
let existing_layer = guard.try_get_from_key(&desc.key());
if let Some(layer) = existing_layer {
if layer.metadata().generation == timeline.generation {
@@ -1011,10 +1007,7 @@ impl ChunkProcessingJob {
// certain that the existing layer is identical to the new one, so in that case
// we replace the old layer with the one we just generated.
let mut guard = timeline
.layers
.write(LayerManagerLockHolder::ImportPgData)
.await;
let mut guard = timeline.layers.write().await;
let existing_layer = guard
.try_get_from_key(&resident_layer.layer_desc().key())
@@ -1043,7 +1036,7 @@ impl ChunkProcessingJob {
}
}
crate::tenant::timeline::drop_layer_manager_wlock(guard);
crate::tenant::timeline::drop_wlock(guard);
timeline
.remote_client

View File

@@ -1,8 +1,5 @@
use std::collections::HashMap;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, bail, ensure};
use itertools::Itertools;
@@ -23,155 +20,6 @@ use crate::tenant::storage_layer::{
PersistentLayerKey, ReadableLayerWeak, ResidentLayer,
};
/// Warn if the lock was held for longer than this threshold.
/// It's very generous and we should bring this value down over time.
const LAYER_MANAGER_LOCK_WARN_THRESHOLD: Duration = Duration::from_secs(5);
const LAYER_MANAGER_LOCK_READ_WARN_THRESHOLD: Duration = Duration::from_secs(30);
/// Describes the operation that is holding the layer manager lock
#[derive(Debug, Clone, Copy, strum_macros::Display)]
#[strum(serialize_all = "kebab_case")]
pub(crate) enum LayerManagerLockHolder {
GetLayerMapInfo,
GenerateHeatmap,
GetPage,
Init,
LoadLayerMap,
GetLayerForWrite,
TryFreezeLayer,
FlushFrozenLayer,
FlushLoop,
Compaction,
GarbageCollection,
Shutdown,
ImportPgData,
DetachAncestor,
Eviction,
#[cfg(test)]
Testing,
}
/// Wrapper for the layer manager that tracks the amount of time during which
/// it was held under read or write lock
#[derive(Default)]
pub(crate) struct LockedLayerManager {
locked: tokio::sync::RwLock<LayerManager>,
}
pub(crate) struct LayerManagerReadGuard<'a> {
guard: ManuallyDrop<tokio::sync::RwLockReadGuard<'a, LayerManager>>,
acquired_at: std::time::Instant,
holder: LayerManagerLockHolder,
}
pub(crate) struct LayerManagerWriteGuard<'a> {
guard: ManuallyDrop<tokio::sync::RwLockWriteGuard<'a, LayerManager>>,
acquired_at: std::time::Instant,
holder: LayerManagerLockHolder,
}
impl Drop for LayerManagerReadGuard<'_> {
fn drop(&mut self) {
// Drop the lock first, before potentially warning if it was held for too long.
// SAFETY: ManuallyDrop in Drop implementation
unsafe { ManuallyDrop::drop(&mut self.guard) };
let held_for = self.acquired_at.elapsed();
if held_for >= LAYER_MANAGER_LOCK_READ_WARN_THRESHOLD {
tracing::warn!(
holder=%self.holder,
"Layer manager read lock held for {}s",
held_for.as_secs_f64(),
);
}
}
}
impl Drop for LayerManagerWriteGuard<'_> {
fn drop(&mut self) {
// Drop the lock first, before potentially warning if it was held for too long.
// SAFETY: ManuallyDrop in Drop implementation
unsafe { ManuallyDrop::drop(&mut self.guard) };
let held_for = self.acquired_at.elapsed();
if held_for >= LAYER_MANAGER_LOCK_WARN_THRESHOLD {
tracing::warn!(
holder=%self.holder,
"Layer manager write lock held for {}s",
held_for.as_secs_f64(),
);
}
}
}
impl Deref for LayerManagerReadGuard<'_> {
type Target = LayerManager;
fn deref(&self) -> &Self::Target {
self.guard.deref()
}
}
impl Deref for LayerManagerWriteGuard<'_> {
type Target = LayerManager;
fn deref(&self) -> &Self::Target {
self.guard.deref()
}
}
impl DerefMut for LayerManagerWriteGuard<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.guard.deref_mut()
}
}
impl LockedLayerManager {
pub(crate) async fn read(&self, holder: LayerManagerLockHolder) -> LayerManagerReadGuard {
let guard = ManuallyDrop::new(self.locked.read().await);
LayerManagerReadGuard {
guard,
acquired_at: std::time::Instant::now(),
holder,
}
}
pub(crate) fn try_read(
&self,
holder: LayerManagerLockHolder,
) -> Result<LayerManagerReadGuard, tokio::sync::TryLockError> {
let guard = ManuallyDrop::new(self.locked.try_read()?);
Ok(LayerManagerReadGuard {
guard,
acquired_at: std::time::Instant::now(),
holder,
})
}
pub(crate) async fn write(&self, holder: LayerManagerLockHolder) -> LayerManagerWriteGuard {
let guard = ManuallyDrop::new(self.locked.write().await);
LayerManagerWriteGuard {
guard,
acquired_at: std::time::Instant::now(),
holder,
}
}
pub(crate) fn try_write(
&self,
holder: LayerManagerLockHolder,
) -> Result<LayerManagerWriteGuard, tokio::sync::TryLockError> {
let guard = ManuallyDrop::new(self.locked.try_write()?);
Ok(LayerManagerWriteGuard {
guard,
acquired_at: std::time::Instant::now(),
holder,
})
}
}
/// Provides semantic APIs to manipulate the layer map.
pub(crate) enum LayerManager {
/// Open as in not shutdown layer manager; we still have in-memory layers and we can manipulate

View File

@@ -1092,15 +1092,13 @@ communicator_prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
MyPState->ring_last <= ring_index);
}
/* Internal version. Returns the ring index of the last block (result of this function is used only
* when nblocks==1)
*/
/* internal version. Returns the ring index */
static uint64
prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
BlockNumber nblocks, const bits8 *mask,
bool is_prefetch)
{
uint64 last_ring_index;
uint64 min_ring_index;
PrefetchRequest hashkey;
#ifdef USE_ASSERT_CHECKING
bool any_hits = false;
@@ -1124,12 +1122,13 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;
MyNeonCounters->getpage_prefetches_buffered =
MyPState->n_responses_buffered;
last_ring_index = UINT64_MAX;
min_ring_index = UINT64_MAX;
for (int i = 0; i < nblocks; i++)
{
PrefetchRequest *slot = NULL;
PrfHashEntry *entry = NULL;
uint64 ring_index;
neon_request_lsns *lsns;
if (PointerIsValid(mask) && BITMAP_ISSET(mask, i))
@@ -1153,12 +1152,12 @@ Retry:
if (entry != NULL)
{
slot = entry->slot;
last_ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(last_ring_index));
ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(ring_index));
Assert(slot->status != PRFS_UNUSED);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
Assert(MyPState->ring_last <= ring_index &&
ring_index < MyPState->ring_unused);
Assert(BufferTagsEqual(&slot->buftag, &hashkey.buftag));
/*
@@ -1170,9 +1169,9 @@ Retry:
if (!neon_prefetch_response_usable(lsns, slot))
{
/* Wait for the old request to finish and discard it */
if (!prefetch_wait_for(last_ring_index))
if (!prefetch_wait_for(ring_index))
goto Retry;
prefetch_set_unused(last_ring_index);
prefetch_set_unused(ring_index);
entry = NULL;
slot = NULL;
pgBufferUsage.prefetch.expired += 1;
@@ -1189,12 +1188,13 @@ Retry:
*/
if (slot->status == PRFS_TAG_REMAINS)
{
prefetch_set_unused(last_ring_index);
prefetch_set_unused(ring_index);
entry = NULL;
slot = NULL;
}
else
{
min_ring_index = Min(min_ring_index, ring_index);
/* The buffered request is good enough, return that index */
if (is_prefetch)
pgBufferUsage.prefetch.duplicates++;
@@ -1283,12 +1283,12 @@ Retry:
* The next buffer pointed to by `ring_unused` is now definitely empty, so
* we can insert the new request to it.
*/
last_ring_index = MyPState->ring_unused;
ring_index = MyPState->ring_unused;
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index <= MyPState->ring_unused);
Assert(MyPState->ring_last <= ring_index &&
ring_index <= MyPState->ring_unused);
slot = GetPrfSlotNoCheck(last_ring_index);
slot = GetPrfSlotNoCheck(ring_index);
Assert(slot->status == PRFS_UNUSED);
@@ -1298,9 +1298,11 @@ Retry:
*/
slot->buftag = hashkey.buftag;
slot->shard_no = get_shard_number(&tag);
slot->my_ring_index = last_ring_index;
slot->my_ring_index = ring_index;
slot->flags = 0;
min_ring_index = Min(min_ring_index, ring_index);
if (is_prefetch)
MyNeonCounters->getpage_prefetch_requests_total++;
else
@@ -1313,12 +1315,11 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;
Assert(any_hits);
Assert(last_ring_index != UINT64_MAX);
Assert(GetPrfSlot(last_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(last_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
Assert(GetPrfSlot(min_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(min_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= min_ring_index &&
min_ring_index < MyPState->ring_unused);
if (flush_every_n_requests > 0 &&
MyPState->ring_unused - MyPState->ring_flush >= flush_every_n_requests)
@@ -1334,7 +1335,7 @@ Retry:
MyPState->ring_flush = MyPState->ring_unused;
}
return last_ring_index;
return min_ring_index;
}
static bool

View File

@@ -2,6 +2,6 @@ DROP FUNCTION IF EXISTS get_prewarm_info(out total_pages integer, out prewarmed_
DROP FUNCTION IF EXISTS get_local_cache_state(max_chunks integer);
DROP FUNCTION IF EXISTS prewarm_local_cache(state bytea, n_workers integer);
DROP FUNCTION IF EXISTS prewarm_local_cache(state bytea, n_workers integer default 1);

View File

@@ -1135,7 +1135,7 @@ VotesCollectedMset(WalProposer *wp, MemberSet *mset, Safekeeper **msk, StringInf
wp->propTermStartLsn = sk->voteResponse.flushLsn;
wp->donor = sk;
}
wp->truncateLsn = Max(sk->voteResponse.truncateLsn, wp->truncateLsn);
wp->truncateLsn = Max(wp->safekeeper[i].voteResponse.truncateLsn, wp->truncateLsn);
if (n_votes > 0)
appendStringInfoString(s, ", ");

10
poetry.lock generated
View File

@@ -3051,19 +3051,19 @@ files = [
[[package]]
name = "requests"
version = "2.32.4"
version = "2.32.3"
description = "Python HTTP for Humans."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c"},
{file = "requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422"},
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
]
[package.dependencies]
certifi = ">=2017.4.17"
charset_normalizer = ">=2,<4"
charset-normalizer = ">=2,<4"
idna = ">=2.5,<4"
urllib3 = ">=1.21.1,<3"
@@ -3846,4 +3846,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = "^3.11"
content-hash = "bd93313f110110aa53b24a3ed47ba2d7f60e2c658a79cdff7320fed1bb1b57b5"
content-hash = "7ab1e7b975af34b3271b7c6018fa22a261d3f73c7c0a0403b6b2bb86b5fbd36e"

View File

@@ -89,6 +89,7 @@ tokio-postgres = { workspace = true, optional = true }
tokio-rustls.workspace = true
tokio-util.workspace = true
tokio = { workspace = true, features = ["signal"] }
toml.workspace = true
tracing-subscriber.workspace = true
tracing-utils.workspace = true
tracing.workspace = true

View File

@@ -14,9 +14,9 @@ use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
use crate::stream::PqStream;
use crate::types::RoleName;
use crate::{auth, compute, waiters};
@@ -109,7 +109,7 @@ impl ConsoleRedirectBackend {
pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
#[async_trait]
impl WakeComputeBackend for ConsoleRedirectNodeInfo {
impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
async fn wake_compute(
&self,
_ctx: &RequestContext,

View File

@@ -4,8 +4,6 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime};
use arc_swap::ArcSwapOption;
use base64::Engine as _;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use clashmap::ClashMap;
use jose_jwk::crypto::KeyInfo;
use reqwest::{Client, redirect};
@@ -349,17 +347,17 @@ impl JwkCacheEntryLock {
.split_once('.')
.ok_or(JwtEncodingError::InvalidCompactForm)?;
let header = BASE64_URL_SAFE_NO_PAD.decode(header)?;
let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)?;
let header = serde_json::from_slice::<JwtHeader<'_>>(&header)?;
let payloadb = BASE64_URL_SAFE_NO_PAD.decode(payload)?;
let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)?;
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)?;
if let Some(iss) = &payload.issuer {
ctx.set_jwt_issuer(iss.as_ref().to_owned());
}
let sig = BASE64_URL_SAFE_NO_PAD.decode(signature)?;
let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)?;
let kid = header.key_id.ok_or(JwtError::MissingKeyId)?;
@@ -798,6 +796,7 @@ mod tests {
use std::net::SocketAddr;
use std::time::SystemTime;
use base64::URL_SAFE_NO_PAD;
use bytes::Bytes;
use http::Response;
use http_body_util::Full;
@@ -872,8 +871,9 @@ mod tests {
key_id: Some(Cow::Owned(kid)),
};
let header = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
let body = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&body).unwrap());
let header =
base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
format!("{header}.{body}")
}
@@ -883,7 +883,7 @@ mod tests {
let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
format!("{payload}.{sig}")
}
@@ -893,7 +893,7 @@ mod tests {
let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
format!("{payload}.{sig}")
}
@@ -904,7 +904,7 @@ mod tests {
let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
format!("{payload}.{sig}")
}

View File

@@ -14,21 +14,20 @@ use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info};
use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
RoleAccessControl,
};
use crate::intern::EndpointIdInt;
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
@@ -231,8 +230,11 @@ async fn auth_quirks(
config.is_vpc_acccess_proxy,
)?;
access_controls.connection_attempt_rate_limit(ctx, &info.endpoint, &endpoint_rate_limiter)?;
let endpoint = EndpointIdInt::from(&info.endpoint);
let rate_limit_config = None;
if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) {
return Err(AuthError::too_many_connections());
}
let role_access = api
.get_role_access_control(ctx, &info.endpoint, &info.user)
.await?;
@@ -399,20 +401,19 @@ impl Backend<'_, ComputeUserInfo> {
allowed_ips: Arc::new(vec![]),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
}),
}
}
}
#[async_trait::async_trait]
impl WakeComputeBackend for Backend<'_, ComputeUserInfo> {
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
match self {
Self::ControlPlane(api, info) => api.wake_compute(ctx, info).await,
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}
@@ -438,7 +439,6 @@ mod tests {
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
};
@@ -477,7 +477,6 @@ mod tests {
allowed_ips: Arc::new(self.ips.clone()),
allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()),
flags: self.access_blocker_flags,
rate_limits: EndpointRateLimitConfig::default(),
})
}

View File

@@ -279,7 +279,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
},
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
handshake_timeout: Duration::from_secs(10),
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,

View File

@@ -28,9 +28,10 @@ use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
use crate::proxy::{
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
};
use crate::stream::{PqStream, Stream};
use crate::util::run_until_cancelled;
project_git_version!(GIT_VERSION);
@@ -236,7 +237,6 @@ pub(super) async fn task_main(
extra: None,
},
crate::metrics::Protocol::SniRouter,
"sni",
);
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
}

View File

@@ -8,16 +8,15 @@ use std::time::Duration;
#[cfg(any(test, feature = "testing"))]
use anyhow::Context;
use anyhow::{bail, ensure};
use anyhow::{bail, anyhow};
use arc_swap::ArcSwapOption;
use futures::future::Either;
use itertools::{Itertools, Position};
use rand::{Rng, thread_rng};
use remote_storage::RemoteStorageConfig;
use serde::Deserialize;
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, warn};
use tracing::{Instrument, info};
use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
@@ -41,7 +40,7 @@ use crate::serverless::cancel_set::CancelSet;
use crate::tls::client_config::compute_client_config_with_root_certs;
#[cfg(any(test, feature = "testing"))]
use crate::url::ApiUrl;
use crate::{auth, control_plane, http, serverless, usage_metrics};
use crate::{auth, control_plane, http, pglb, serverless, usage_metrics};
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
@@ -61,6 +60,262 @@ enum AuthBackendType {
Postgres,
}
#[derive(Deserialize)]
struct Root {
#[serde(flatten)]
legacy: LegacyModes,
introspection: Introspection,
}
#[derive(Deserialize)]
#[serde(untagged)]
enum LegacyModes {
Proxy {
pglb: Pglb,
neonkeeper: NeonKeeper,
http: Option<Http>,
pg_sni_router: Option<PgSniRouter>,
},
AuthBroker {
neonkeeper: NeonKeeper,
http: Http,
},
ConsoleRedirect {
console_redirect: ConsoleRedirect,
},
}
#[derive(Deserialize)]
struct Pglb {
listener: Listener,
}
#[derive(Deserialize)]
struct Listener {
/// address to bind to
addr: SocketAddr,
/// which header should we expect to see on this socket
/// from the load balancer
header: Option<ProxyHeader>,
/// certificates used for TLS.
/// first cert is the default.
/// TLS not used if no certs provided.
certs: Vec<KeyPair>,
/// Timeout to use for TLS handshake
timeout: Option<Duration>,
}
#[derive(Deserialize)]
enum ProxyHeader {
/// Accept the PROXY! protocol V2.
ProxyProtocolV2(ProxyProtocolV2Kind),
}
#[derive(Deserialize)]
enum ProxyProtocolV2Kind {
/// Expect AWS TLVs in the header.
Aws,
/// Expect Azure TLVs in the header.
Azure,
}
#[derive(Deserialize)]
struct KeyPair {
key: PathBuf,
cert: PathBuf,
}
#[derive(Deserialize)]
/// The service that authenticates all incoming connection attempts,
/// provides monitoring and also wakes computes.
struct NeonKeeper {
cplane: ControlPlaneBackend,
redis: Option<Redis>,
auth: Vec<AuthMechanism>,
/// map of endpoint->computeinfo
compute: Cache,
/// cache for GetEndpointAccessControls.
project_info_cache: config::ProjectInfoCacheOptions,
/// cache for all valid endpoints
endpoint_cache_config: config::EndpointCacheConfig,
request_log_export: Option<RequestLogExport>,
data_transfer_export: Option<DataTransferExport>,
}
#[derive(Deserialize)]
struct Redis {
/// Cancellation channel size (max queue size for redis kv client)
cancellation_ch_size: usize,
/// Cancellation ops batch size for redis
cancellation_batch_size: usize,
auth: RedisAuthentication,
}
#[derive(Deserialize)]
enum RedisAuthentication {
/// i don't remember what this stands for.
/// IAM roles for service accounts?
Irsa {
host: String,
port: u16,
cluster_name: Option<String>,
user_id: Option<String>,
aws_region: String,
},
Basic {
url: url::Url,
},
}
#[derive(Deserialize)]
struct PgSniRouter {
/// The listener to use to proxy connections to compute,
/// assuming the compute does not support TLS.
listener: Listener,
/// The listener to use to proxy connections to compute,
/// assuming the compute requires TLS.
listener_tls: Listener,
/// append this domain zone to the SNI hostname to get the destination address
dest: String,
}
#[derive(Deserialize)]
/// `psql -h pg.neon.tech`.
struct ConsoleRedirect {
/// Connection requests from clients.
listener: Listener,
/// Messages from control plane to accept the connection.
cplane: Listener,
/// The base url to use for redirects.
console: url::Url,
timeout: Duration,
}
#[derive(Deserialize)]
enum ControlPlaneBackend {
/// Use the HTTP API to access the control plane.
Http(url::Url),
/// Stub the control plane with a postgres instance.
#[cfg(feature = "testing")]
PostgresMock(url::Url),
}
#[derive(Deserialize)]
struct Http {
listener: Listener,
sql_over_http: SqlOverHttp,
// todo: move into Pglb.
websockets: Option<Websockets>,
}
#[derive(Deserialize)]
struct SqlOverHttp {
pool_max_conns_per_endpoint: usize,
pool_max_total_conns: usize,
pool_idle_timeout: Duration,
pool_gc_epoch: Duration,
pool_shards: usize,
client_conn_threshold: u64,
cancel_set_shards: usize,
timeout: Duration,
max_request_size_bytes: usize,
max_response_size_bytes: usize,
auth: Vec<AuthMechanism>,
}
#[derive(Deserialize)]
enum AuthMechanism {
Sasl {
/// timeout for SASL handshake
timeout: Duration,
},
CleartextPassword {
/// number of threads for the thread pool
threads: usize,
},
// add something about the jwks cache i guess.
Jwt {},
}
#[derive(Deserialize)]
struct Websockets {
auth: Vec<AuthMechanism>,
}
#[derive(Deserialize)]
/// The HTTP API used for internal monitoring.
struct Introspection {
listener: Listener,
}
#[derive(Deserialize)]
enum RequestLogExport {
Parquet {
location: RemoteStorageConfig,
disconnect: RemoteStorageConfig,
/// The region identifier to tag the entries with.
region: String,
/// How many rows to include in a row group
row_group_size: usize,
/// How large each column page should be in bytes
page_size: usize,
/// How large the total parquet file should be in bytes
size: i64,
/// How long to wait before forcing a file upload
maximum_duration: tokio::time::Duration,
// /// What level of compression to use
// compression: Compression,
},
}
#[derive(Deserialize)]
enum Cache {
/// Expire by LRU or by idle.
/// Note: "live" in "time-to-live" actually means idle here.
LruTtl {
/// Max number of entries.
size: usize,
/// Entry's time-to-live.
ttl: Duration,
},
}
#[derive(Deserialize)]
struct DataTransferExport {
/// http endpoint to receive periodic metric updates
endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
interval: Option<String>,
/// interval for backup metric collection
backup_interval: std::time::Duration,
/// remote storage configuration for backup metric collection
/// Encoded as toml (same format as pageservers), eg
/// `{bucket_name='the-bucket',bucket_region='us-east-1',prefix_in_bucket='proxy',endpoint='http://minio:9000'}`
backup_remote_storage: Option<RemoteStorageConfig>,
/// chunk size for backup metric collection
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
backup_chunk_size: usize,
}
/// Neon proxy/router
#[derive(Parser)]
#[command(version = GIT_VERSION, about)]
@@ -122,12 +377,6 @@ struct ProxyCliArgs {
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// cache for `wake_compute` api method (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
wake_compute_cache: String,
@@ -154,40 +403,31 @@ struct ProxyCliArgs {
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Cancellation channel size (max queue size for redis kv client)
#[clap(long, default_value_t = 1024)]
cancellation_ch_size: usize,
/// Cancellation ops batch size for redis
#[clap(long, default_value_t = 8)]
cancellation_batch_size: usize,
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
allowed_ips_cache: String,
/// cache for `role_secret` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
/// redis url for plain authentication
#[clap(long, alias("redis-notifications"))]
redis_plain: Option<String>,
/// what from the available authentications type to use for redis. Supported are "irsa" and "plain".
#[clap(long)]
redis_notifications: Option<String>,
/// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain".
#[clap(long, default_value = "irsa")]
redis_auth_type: String,
/// redis host for streaming connections (might be different from the notifications host)
redis_auth_type: Option<String>,
/// redis host for irsa authentication
#[clap(long)]
redis_host: Option<String>,
/// redis port for streaming connections (might be different from the notifications host)
/// redis port for irsa authentication
#[clap(long)]
redis_port: Option<u16>,
/// redis cluster name, used in aws elasticache
/// redis cluster name for irsa authentication
#[clap(long)]
redis_cluster_name: Option<String>,
/// redis user_id, used in aws elasticache
/// redis user_id for irsa authentication
#[clap(long)]
redis_user_id: Option<String>,
/// aws region to retrieve credentials
/// aws region for irsa authentication
#[clap(long, default_value_t = String::new())]
aws_region: String,
/// cache for `project_info` (use `size=0` to disable)
@@ -199,6 +439,12 @@ struct ProxyCliArgs {
#[clap(flatten)]
parquet_upload: ParquetUploadArgs,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// interval for backup metric collection
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
metric_backup_collection_interval: std::time::Duration,
@@ -211,6 +457,7 @@ struct ProxyCliArgs {
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
#[clap(long, default_value = "4194304")]
metric_backup_collection_chunk_size: usize,
/// Whether to retry the connection to the compute node
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
connect_to_compute_retry: String,
@@ -316,81 +563,252 @@ pub async fn run() -> anyhow::Result<()> {
let jemalloc = match crate::jemalloc::MetricRecorder::new() {
Ok(t) => Some(t),
Err(e) => {
error!(error = ?e, "could not start jemalloc metrics loop");
tracing::error!(error = ?e, "could not start jemalloc metrics loop");
None
}
};
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
let config: Root = toml::from_str(&tokio::fs::read_to_string("proxy.toml").await?)?;
match auth_backend {
Either::Left(auth_backend) => info!("Authentication backend: {auth_backend}"),
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
}
info!("Using region: {}", args.aws_region);
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
// Check that we can bind to address before further initialization
info!("Starting http on {}", args.http);
let http_listener = TcpListener::bind(args.http).await?.into_std()?;
info!("Starting mgmt on {}", args.mgmt);
let mgmt_listener = TcpListener::bind(args.mgmt).await?;
let proxy_listener = if args.is_auth_broker {
None
} else {
info!("Starting proxy on {}", args.proxy);
Some(TcpListener::bind(args.proxy).await?)
};
let sni_router_listeners = {
let args = &args.pg_sni_router;
if args.dest.is_some() {
ensure!(
args.tls_key.is_some(),
"sni-router-tls-key must be provided"
);
ensure!(
args.tls_cert.is_some(),
"sni-router-tls-cert must be provided"
);
info!(
"Starting pg-sni-router on {} and {}",
args.listen, args.listen_tls
);
Some((
TcpListener::bind(args.listen).await?,
TcpListener::bind(args.listen_tls).await?,
))
} else {
None
}
};
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
let serverless_listener = if let Some(serverless_address) = args.wss {
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {
None
};
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
let cancellation_token = CancellationToken::new();
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
RateBucketInfo::validate(redis_rps_limit)?;
match config.legacy {
LegacyModes::Proxy {
pglb,
neonkeeper,
http,
pg_sni_router,
} => {
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
// todo: use neonkeeper config.
EndpointRateLimiter::DEFAULT,
64,
));
let redis_kv_client = regional_redis_client
.as_ref()
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
info!("Starting proxy on {}", pglb.listener.addr);
let proxy_listener = TcpListener::bind(pglb.listener.addr).await?;
client_tasks.spawn(crate::proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
if let Some(http) = http {
info!("Starting wss on {}", http.listener.addr);
let http_listener = TcpListener::bind(http.listener.addr).await?;
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
http_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
};
if let Some(redis) = neonkeeper.redis {
let client = configure_redis(redis.auth);
}
if let Some(sni_router) = pg_sni_router {
let listen = TcpListener::bind(sni_router.listener.addr).await?;
let listen_tls = TcpListener::bind(sni_router.listener_tls.addr).await?;
let [KeyPair { key, cert }] = sni_router
.listener
.certs
.try_into()
.map_err(|_| anyhow!("only 1 keypair is supported for pg-sni-router"))?;
let tls_config = super::pg_sni_router::parse_tls(&key, &cert)?;
let dest = Arc::new(sni_router.dest);
client_tasks.spawn(super::pg_sni_router::task_main(
dest.clone(),
tls_config.clone(),
None,
listen,
cancellation_token.clone(),
));
client_tasks.spawn(super::pg_sni_router::task_main(
dest,
tls_config,
Some(config.connect_to_compute.tls.clone()),
listen_tls,
cancellation_token.clone(),
));
}
match neonkeeper.request_log_export {
Some(RequestLogExport::Parquet {
location,
disconnect,
region,
row_group_size,
page_size,
size,
maximum_duration,
}) => {
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
args.region,
));
}
None => {}
}
if let (ControlPlaneBackend::Http(api), Some(redis)) =
(neonkeeper.cplane, neonkeeper.redis)
{
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
maintenance_tasks.spawn(async move {
redis_kv_client.try_connect().await?;
handle_cancel_messages(
&mut redis_kv_client,
rx_cancel,
args.cancellation_batch_size,
)
.await?;
drop(redis_kv_client);
// `handle_cancel_messages` was terminated due to the tx_cancel
// being dropped. this is not worthy of an error, and this task can only return `Err`,
// so let's wait forever instead.
std::future::pending().await
});
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
}
LegacyModes::AuthBroker { neonkeeper, http } => {
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
// todo: use neonkeeper config.
EndpointRateLimiter::DEFAULT,
64,
));
info!("Starting wss on {}", http.listener.addr);
let http_listener = TcpListener::bind(http.listener.addr).await?;
if let Some(redis) = neonkeeper.redis {
let client = configure_redis(redis.auth);
}
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
match neonkeeper.request_log_export {
Some(RequestLogExport::Parquet {
location,
disconnect,
region,
row_group_size,
page_size,
size,
maximum_duration,
}) => {
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
args.region,
));
}
None => {}
}
if let (ControlPlaneBackend::Http(api), Some(redis)) =
(neonkeeper.cplane, neonkeeper.redis)
{
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
maintenance_tasks.spawn(async move {
redis_kv_client.try_connect().await?;
handle_cancel_messages(
&mut redis_kv_client,
rx_cancel,
args.cancellation_batch_size,
)
.await?;
drop(redis_kv_client);
// `handle_cancel_messages` was terminated due to the tx_cancel
// being dropped. this is not worthy of an error, and this task can only return `Err`,
// so let's wait forever instead.
std::future::pending().await
});
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
}
LegacyModes::ConsoleRedirect { console_redirect } => {
info!("Starting proxy on {}", console_redirect.listener.addr);
let proxy_listener = TcpListener::bind(console_redirect.listener.addr).await?;
info!("Starting mgmt on {}", console_redirect.listener.addr);
let mgmt_listener = TcpListener::bind(console_redirect.listener.addr).await?;
client_tasks.spawn(crate::console_redirect_proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
));
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
}
}
// Check that we can bind to address before further initialization
info!("Starting http on {}", config.introspection.listener.addr);
let http_listener = TcpListener::bind(config.introspection.listener.addr)
.await?
.into_std()?;
// channel size should be higher than redis client limit to avoid blocking
let cancel_ch_size = args.cancellation_ch_size;
@@ -400,87 +818,6 @@ pub async fn run() -> anyhow::Result<()> {
Some(tx_cancel),
));
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
.unwrap_or(EndpointRateLimiter::DEFAULT),
64,
));
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
match auth_backend {
Either::Left(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(crate::proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
}
Either::Right(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(crate::console_redirect_proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
));
}
}
}
// spawn pg-sni-router mode.
if let Some((listen, listen_tls)) = sni_router_listeners {
let args = args.pg_sni_router;
let dest = args.dest.expect("already asserted it is set");
let key_path = args.tls_key.expect("already asserted it is set");
let cert_path = args.tls_cert.expect("already asserted it is set");
let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
let dest = Arc::new(dest);
client_tasks.spawn(super::pg_sni_router::task_main(
dest.clone(),
tls_config.clone(),
None,
listen,
cancellation_token.clone(),
));
client_tasks.spawn(super::pg_sni_router::task_main(
dest,
tls_config,
Some(config.connect_to_compute.tls.clone()),
listen_tls,
cancellation_token.clone(),
));
}
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
));
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
@@ -490,90 +827,12 @@ pub async fn run() -> anyhow::Result<()> {
proxy: crate::metrics::Metrics::get(),
},
));
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
if let Some(metrics_config) = &config.metric_collection {
// TODO: Add gc regardles of the metric collection being enabled.
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
}
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
}
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
if let Some(mut redis_kv_client) = redis_kv_client {
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
info!("Connected to Redis KV client");
maintenance_tasks.spawn(async move {
handle_cancel_messages(
&mut redis_kv_client,
rx_cancel,
args.cancellation_batch_size,
)
.await?;
drop(redis_kv_client);
// `handle_cancel_messages` was terminated due to the tx_cancel
// being dropped. this is not worthy of an error, and this task can only return `Err`,
// so let's wait forever instead.
std::future::pending().await
});
break;
}
Err(e) => {
error!("Failed to connect to Redis KV client: {e}");
if matches!(attempt, Position::Last(_)) {
bail!(
"Failed to connect to Redis KV client after {} attempts",
attempt.into_inner()
);
}
let jitter = thread_rng().gen_range(0..100);
tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
}
}
}
}
if let Some(regional_redis_client) = regional_redis_client {
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
.instrument(span),
);
}
}
}
let maintenance = loop {
// get one complete task
match futures::future::select(
@@ -696,7 +955,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
@@ -856,58 +1114,45 @@ fn build_auth_backend(
}
}
async fn configure_redis(
args: &ProxyCliArgs,
) -> anyhow::Result<(
Option<ConnectionWithCredentialsProvider>,
Option<ConnectionWithCredentialsProvider>,
)> {
// TODO: untangle the config args
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
("plain", redis_url) => match redis_url {
None => {
bail!("plain auth requires redis_notifications to be set");
}
Some(url) => {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
}
},
("irsa", _) => match (&args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host.clone(),
port,
elasticache::CredentialsProvider::new(
args.aws_region.clone(),
args.redis_cluster_name.clone(),
args.redis_user_id.clone(),
)
.await,
),
),
(None, None) => {
// todo: upgrade to error?
warn!(
"irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client"
);
None
}
_ => {
bail!("redis-host and redis-port must be specified together");
}
},
_ => {
bail!("unknown auth type given");
async fn configure_redis(auth: RedisAuthentication) -> ConnectionWithCredentialsProvider {
match auth {
RedisAuthentication::Irsa {
host,
port,
cluster_name,
user_id,
aws_region,
} => ConnectionWithCredentialsProvider::new_with_credentials_provider(
host,
port,
elasticache::CredentialsProvider::new(aws_region, cluster_name, user_id).await,
),
RedisAuthentication::Basic { url } => {
ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone())
}
}
}
None => None,
};
let redis_notifications_client = if let Some(url) = &args.redis_notifications {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
} else {
regional_redis_client.clone()
// let redis_notifications_client = if let Some(url) = &args.redis_notifications {
// Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
// } else {
// regional_redis_client.clone()
// };
Ok(redis_client)
}
None => None,
};
Ok((regional_redis_client, redis_notifications_client))
// let redis_notifications_client = if let Some(url) = &args.redis_notifications {
// Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
// } else {
// regional_redis_client.clone()
// };
Ok(redis_client)
}
#[cfg(test)]

View File

@@ -364,7 +364,6 @@ mod tests {
use std::sync::Arc;
use super::*;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret;
use crate::types::ProjectId;
@@ -400,7 +399,6 @@ mod tests {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret1.clone(),
@@ -416,7 +414,6 @@ mod tests {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret2.clone(),
@@ -442,7 +439,6 @@ mod tests {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret3.clone(),

View File

@@ -136,11 +136,11 @@ impl AuthInfo {
}
}
pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
pub(crate) fn with_auth_keys(keys: &ComputeCredentialKeys) -> Self {
Self {
auth: match keys {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(auth_keys)))
Some(Auth::Scram(Box::new(*auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
},

View File

@@ -22,7 +22,6 @@ pub struct ProxyConfig {
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
pub proxy_protocol_v2: ProxyProtocolV2,
pub region: String,
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
@@ -70,7 +69,7 @@ pub struct AuthenticationConfig {
pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
#[derive(Debug)]
#[derive(Debug, serde::Deserialize)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.
pub initial_batch_size: usize,
@@ -206,7 +205,7 @@ impl FromStr for CacheOptions {
}
/// Helper for cmdline cache options parsing.
#[derive(Debug)]
#[derive(Debug, serde::Deserialize)]
pub struct ProjectInfoCacheOptions {
/// Max number of entries.
pub size: usize,

View File

@@ -11,12 +11,13 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection};
use crate::util::run_until_cancelled;
use crate::proxy::{
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
};
pub async fn task_main(
config: &'static ProxyConfig,
@@ -89,12 +90,7 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_client(
config,

View File

@@ -46,7 +46,6 @@ struct RequestContextInner {
pub(crate) session_id: Uuid,
pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>,
region: &'static str,
pub(crate) span: Span,
// filled in as they are discovered
@@ -94,7 +93,6 @@ impl Clone for RequestContext {
session_id: inner.session_id,
protocol: inner.protocol,
first_packet: inner.first_packet,
region: inner.region,
span: info_span!("background_task"),
project: inner.project,
@@ -124,12 +122,7 @@ impl Clone for RequestContext {
}
impl RequestContext {
pub fn new(
session_id: Uuid,
conn_info: ConnectionInfo,
protocol: Protocol,
region: &'static str,
) -> Self {
pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
// TODO: be careful with long lived spans
let span = info_span!(
"connect_request",
@@ -145,7 +138,6 @@ impl RequestContext {
session_id,
protocol,
first_packet: Utc::now(),
region,
span,
project: None,
@@ -179,7 +171,7 @@ impl RequestContext {
let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None };
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
}
pub(crate) fn console_application_name(&self) -> String {

View File

@@ -74,7 +74,7 @@ pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
#[derive(parquet_derive::ParquetRecordWriter)]
pub(crate) struct RequestData {
region: &'static str,
region: String,
protocol: &'static str,
/// Must be UTC. The derive macro doesn't like the timezones
timestamp: chrono::NaiveDateTime,
@@ -147,7 +147,7 @@ impl From<&RequestContextInner> for RequestData {
}),
jwt_issuer: value.jwt_issuer.clone(),
protocol: value.protocol.as_str(),
region: value.region,
region: String::new(),
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
success: value.success,
cold_start_info: value.cold_start_info.as_str(),
@@ -167,6 +167,7 @@ impl From<&RequestContextInner> for RequestData {
pub async fn worker(
cancellation_token: CancellationToken,
config: ParquetUploadArgs,
region: String,
) -> anyhow::Result<()> {
let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
tracing::warn!("parquet request upload: no s3 bucket configured");
@@ -232,12 +233,17 @@ pub async fn worker(
.context("remote storage for disconnect events init")?;
let parquet_config_disconnect = parquet_config.clone();
tokio::try_join!(
worker_inner(storage, rx, parquet_config),
worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect)
worker_inner(storage, rx, parquet_config, &region),
worker_inner(
storage_disconnect,
rx_disconnect,
parquet_config_disconnect,
&region
)
)
.map(|_| ())
} else {
worker_inner(storage, rx, parquet_config).await
worker_inner(storage, rx, parquet_config, &region).await
}
}
@@ -257,6 +263,7 @@ async fn worker_inner(
storage: GenericRemoteStorage,
rx: impl Stream<Item = RequestData>,
config: ParquetConfig,
region: &str,
) -> anyhow::Result<()> {
#[cfg(any(test, feature = "testing"))]
let storage = if config.test_remote_failures > 0 {
@@ -277,7 +284,8 @@ async fn worker_inner(
let mut last_upload = time::Instant::now();
let mut len = 0;
while let Some(row) = rx.next().await {
while let Some(mut row) = rx.next().await {
region.clone_into(&mut row.region);
rows.push(row);
let force = last_upload.elapsed() > config.max_duration;
if rows.len() == config.rows_per_group || force {
@@ -533,7 +541,7 @@ mod tests {
auth_method: None,
jwt_issuer: None,
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
region: "us-east-1",
region: String::new(),
error: None,
success: rng.r#gen(),
cold_start_info: "no",
@@ -565,7 +573,9 @@ mod tests {
.await
.unwrap();
worker_inner(storage, rx, config).await.unwrap();
worker_inner(storage, rx, config, "us-east-1")
.await
.unwrap();
let mut files = WalkDir::new(tmpdir.as_std_path())
.into_iter()

View File

@@ -146,7 +146,6 @@ impl NeonControlPlaneClient {
public_access_blocked: block_public_connections,
vpc_access_blocked: block_vpc_connections,
},
rate_limits: body.rate_limits,
})
}
.inspect_err(|e| tracing::debug!(error = ?e))
@@ -313,7 +312,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
@@ -359,7 +357,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,

View File

@@ -20,7 +20,7 @@ use crate::context::RequestContext;
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
};
use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo};
use crate::control_plane::messages::MetricsAuxInfo;
use crate::control_plane::{
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
RoleAccessControl,
@@ -130,7 +130,6 @@ impl MockControlPlane {
project_id: None,
account_id: None,
access_blocker_flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
})
}
@@ -234,7 +233,6 @@ impl super::ControlPlaneApi for MockControlPlane {
allowed_ips: Arc::new(info.allowed_ips),
allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
flags: info.access_blocker_flags,
rate_limits: info.rate_limits,
})
}

View File

@@ -10,7 +10,6 @@ use clashmap::ClashMap;
use tokio::time::Instant;
use tracing::{debug, info};
use super::{EndpointAccessControl, RoleAccessControl};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
use crate::cache::endpoints::EndpointsCache;
@@ -23,6 +22,8 @@ use crate::metrics::ApiLockMetrics;
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
use crate::types::EndpointId;
use super::{EndpointAccessControl, RoleAccessControl};
#[non_exhaustive]
#[derive(Clone)]
pub enum ControlPlaneClient {

View File

@@ -227,35 +227,12 @@ pub(crate) struct UserFacingMessage {
#[derive(Deserialize)]
pub(crate) struct GetEndpointAccessControl {
pub(crate) role_secret: Box<str>,
pub(crate) project_id: Option<ProjectIdInt>,
pub(crate) account_id: Option<AccountIdInt>,
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
pub(crate) project_id: Option<ProjectIdInt>,
pub(crate) account_id: Option<AccountIdInt>,
pub(crate) block_public_connections: Option<bool>,
pub(crate) block_vpc_connections: Option<bool>,
#[serde(default)]
pub(crate) rate_limits: EndpointRateLimitConfig,
}
#[derive(Copy, Clone, Deserialize, Default)]
pub struct EndpointRateLimitConfig {
pub connection_attempts: ConnectionAttemptsLimit,
}
#[derive(Copy, Clone, Deserialize, Default)]
pub struct ConnectionAttemptsLimit {
pub tcp: Option<LeakyBucketSetting>,
pub ws: Option<LeakyBucketSetting>,
pub http: Option<LeakyBucketSetting>,
}
#[derive(Copy, Clone, Deserialize)]
pub struct LeakyBucketSetting {
pub rps: f64,
pub burst: f64,
}
/// Response which holds compute node's `host:port` pair.

View File

@@ -11,8 +11,6 @@ pub(crate) mod errors;
use std::sync::Arc;
use messages::EndpointRateLimitConfig;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
@@ -20,9 +18,8 @@ use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt};
use crate::intern::{AccountIdInt, ProjectIdInt};
use crate::protocol2::ConnectionInfoExtra;
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig};
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{compute, scram};
@@ -59,8 +56,6 @@ pub(crate) struct AuthInfo {
pub(crate) account_id: Option<AccountIdInt>,
/// Are public connections or VPC connections blocked?
pub(crate) access_blocker_flags: AccessBlockerFlags,
/// The rate limits for this endpoint.
pub(crate) rate_limits: EndpointRateLimitConfig,
}
/// Info for establishing a connection to a compute node.
@@ -106,8 +101,6 @@ pub struct EndpointAccessControl {
pub allowed_ips: Arc<Vec<IpPattern>>,
pub allowed_vpce: Arc<Vec<String>>,
pub flags: AccessBlockerFlags,
pub rate_limits: EndpointRateLimitConfig,
}
impl EndpointAccessControl {
@@ -146,36 +139,6 @@ impl EndpointAccessControl {
Ok(())
}
pub fn connection_attempt_rate_limit(
&self,
ctx: &RequestContext,
endpoint: &EndpointId,
rate_limiter: &EndpointRateLimiter,
) -> Result<(), AuthError> {
let endpoint = EndpointIdInt::from(endpoint);
let limits = &self.rate_limits.connection_attempts;
let config = match ctx.protocol() {
crate::metrics::Protocol::Http => limits.http,
crate::metrics::Protocol::Ws => limits.ws,
crate::metrics::Protocol::Tcp => limits.tcp,
crate::metrics::Protocol::SniRouter => return Ok(()),
};
let config = config.and_then(|config| {
if config.rps <= 0.0 || config.burst <= 0.0 {
return None;
}
Some(LeakyBucketConfig::new(config.rps, config.burst))
});
if !rate_limiter.check(endpoint, config, 1) {
return Err(AuthError::too_many_connections());
}
Ok(())
}
}
/// This will allocate per each call, but the http requests alone

View File

@@ -106,5 +106,4 @@ mod tls;
mod types;
mod url;
mod usage_metrics;
mod util;
mod waiters;

View File

@@ -8,19 +8,19 @@ use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, NodeInfo};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;
/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(skip_all)]
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo {
let is_cached = node_info.cached();
if is_cached {
@@ -49,6 +49,14 @@ pub(crate) trait ConnectMechanism {
) -> Result<Self::Connection, Self::ConnectError>;
}
#[async_trait]
pub(crate) trait ComputeConnectBackend {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
}
pub(crate) struct TcpMechanism {
pub(crate) auth: AuthInfo,
/// connect_to_compute concurrency lock
@@ -83,7 +91,7 @@ impl ConnectMechanism for TcpMechanism {
/// Try to connect to the compute node, retrying if necessary.
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBackend>(
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
ctx: &RequestContext,
mechanism: &M,
user_info: &B,

View File

@@ -1,3 +1,4 @@
pub mod connect_compute;
pub mod copy_bidirectional;
pub mod handshake;
pub mod inprocess;

View File

@@ -1,10 +1,8 @@
#[cfg(test)]
mod tests;
pub(crate) mod connect_compute;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use std::sync::Arc;
use futures::FutureExt;
@@ -23,16 +21,15 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::types::EndpointCacheKey;
use crate::util::run_until_cancelled;
use crate::{auth, compute};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
@@ -49,6 +46,21 @@ impl ReportableError for TlsRequired {
impl UserFacingError for TlsRequired {}
pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
match futures::future::select(
std::pin::pin!(f),
std::pin::pin!(cancellation_token.cancelled()),
)
.await
{
futures::future::Either::Left((f, _)) => Some(f),
futures::future::Either::Right(((), _)) => None,
}
}
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
@@ -122,12 +134,7 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_client(
config,
@@ -346,12 +353,12 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
}
};
let (cplane, creds) = match user_info {
auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
let creds = match &user_info {
auth::Backend::ControlPlane(_, creds) => creds,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys);
auth_info.set_startup_params(&params, params_compat);
let res = connect_to_compute(
@@ -361,7 +368,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
auth: auth_info,
locks: &config.connect_compute_locks,
},
&auth::Backend::ControlPlane(cplane, creds.info),
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)

View File

@@ -8,7 +8,7 @@ use std::time::Duration;
use anyhow::{Context, bail};
use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::config::{AuthKeys, ScramKeys, SslMode};
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
@@ -19,13 +19,15 @@ use tracing_test::traced_test;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
@@ -573,13 +575,19 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> auth::Backend<'static, ComputeUserInfo> {
) -> auth::Backend<'static, ComputeCredentials> {
auth::Backend::ControlPlane(
MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(ScramKeys {
client_key: [0; 32],
server_key: [0; 32],
})),
},
)
}

View File

@@ -1,4 +1,3 @@
use async_trait::async_trait;
use tracing::{error, info};
use crate::config::RetryConfig;
@@ -9,6 +8,7 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::proxy::retry::{retry_after, should_retry};
// Use macro to retain original callsite.
@@ -23,12 +23,7 @@ macro_rules! log_wake_compute_error {
};
}
#[async_trait]
pub(crate) trait WakeComputeBackend {
async fn wake_compute(&self, ctx: &RequestContext) -> Result<CachedNodeInfo, WakeComputeError>;
}
pub(crate) async fn wake_compute<B: WakeComputeBackend>(
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
num_retries: &mut u32,
ctx: &RequestContext,
api: &B,

View File

@@ -69,8 +69,9 @@ pub struct LeakyBucketConfig {
pub max: f64,
}
#[cfg(test)]
impl LeakyBucketConfig {
pub fn new(rps: f64, max: f64) -> Self {
pub(crate) fn new(rps: f64, max: f64) -> Self {
assert!(rps > 0.0, "rps must be positive");
assert!(max > 0.0, "max must be positive");
Self { rps, max }

View File

@@ -12,10 +12,11 @@ use rand::{Rng, SeedableRng};
use tokio::time::{Duration, Instant};
use tracing::info;
use super::LeakyBucketConfig;
use crate::ext::LockExt;
use crate::intern::EndpointIdInt;
use super::LeakyBucketConfig;
pub struct GlobalRateLimiter {
data: Vec<RateBucket>,
info: Vec<RateBucketInfo>,
@@ -139,12 +140,6 @@ impl RateBucketInfo {
Self::new(200, Duration::from_secs(600)),
];
// For all the sessions will be cancel key. So this limit is essentially global proxy limit.
pub const DEFAULT_REDIS_SET: [Self; 2] = [
Self::new(100_000, Duration::from_secs(1)),
Self::new(50_000, Duration::from_secs(10)),
];
pub fn rps(&self) -> f64 {
(self.max_rpi as f64) / self.interval.as_secs_f64()
}

View File

@@ -2,11 +2,9 @@ use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
pub struct RedisKVClient {
client: ConnectionWithCredentialsProvider,
limiter: GlobalRateLimiter,
}
#[allow(async_fn_in_trait)]
@@ -27,11 +25,8 @@ impl Queryable for Cmd {
}
impl RedisKVClient {
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
Self {
client,
limiter: GlobalRateLimiter::new(info.into()),
}
pub fn new(client: ConnectionWithCredentialsProvider) -> Self {
Self { client }
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
@@ -49,11 +44,6 @@ impl RedisKVClient {
&mut self,
q: &impl Queryable,
) -> anyhow::Result<T> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping query");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => {

View File

@@ -141,29 +141,19 @@ where
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
cache: Arc<C>,
region_id: String,
}
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
region_id: self.region_id.clone(),
}
}
}
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
Self { cache, region_id }
}
pub(crate) async fn increment_active_listeners(&self) {
self.cache.increment_active_listeners().await;
}
pub(crate) async fn decrement_active_listeners(&self) {
self.cache.decrement_active_listeners().await;
pub(crate) fn new(cache: Arc<C>) -> Self {
Self { cache }
}
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
@@ -276,7 +266,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
let mut conn = match try_connect(&redis).await {
Ok(conn) => {
handler.increment_active_listeners().await;
handler.cache.increment_active_listeners().await;
conn
}
Err(e) => {
@@ -297,11 +287,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
}
if cancellation_token.is_cancelled() {
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
return Ok(());
}
}
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
}
}
@@ -310,12 +300,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
pub async fn task_main<C>(
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
region_id: String,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
let handler = MessageHandler::new(cache, region_id);
let handler = MessageHandler::new(cache);
// 6h - 1m.
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));

View File

@@ -1,8 +1,5 @@
//! Definition and parser for channel binding flag (a part of the `GS2` header).
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
/// Channel binding flag (possibly with params).
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum ChannelBinding<T> {
@@ -58,7 +55,7 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,",).unwrap();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
BASE64_STANDARD.encode(&cbind_input).into()
base64::encode(&cbind_input).into()
}
})
}
@@ -73,9 +70,9 @@ mod tests {
use ChannelBinding::*;
let cases = [
(NotSupportedClient, BASE64_STANDARD.encode("n,,")),
(NotSupportedServer, BASE64_STANDARD.encode("y,,")),
(Required("foo"), BASE64_STANDARD.encode("p=foo,,bar")),
(NotSupportedClient, base64::encode("n,,")),
(NotSupportedServer, base64::encode("y,,")),
(Required("foo"), base64::encode("p=foo,,bar")),
];
for (cb, input) in cases {

View File

@@ -2,8 +2,6 @@
use std::convert::Infallible;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use hmac::{Hmac, Mac};
use sha2::Sha256;
@@ -107,7 +105,7 @@ pub(crate) async fn exchange(
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = BASE64_STANDARD.decode(&secret.salt_base64)?;
let salt = base64::decode(&secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
if secret.is_password_invalid(&client_key).into() {

View File

@@ -3,9 +3,6 @@
use std::fmt;
use std::ops::Range;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use super::base64_decode_array;
use super::key::{SCRAM_KEY_LEN, ScramKey};
use super::signature::SignatureBuilder;
@@ -91,7 +88,7 @@ impl<'a> ClientFirstMessage<'a> {
let mut message = String::new();
write!(&mut message, "r={}", self.nonce).unwrap();
BASE64_STANDARD.encode_string(nonce, &mut message);
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
let combined_nonce = 2..message.len();
write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
@@ -145,7 +142,11 @@ impl<'a> ClientFinalMessage<'a> {
server_key: &ScramKey,
) -> String {
let mut buf = String::from("v=");
BASE64_STANDARD.encode_string(signature_builder.build(server_key), &mut buf);
base64::encode_config_buf(
signature_builder.build(server_key),
base64::STANDARD,
&mut buf,
);
buf
}
@@ -250,7 +251,7 @@ mod tests {
"iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
);
assert_eq!(
BASE64_STANDARD.encode(msg.proof),
base64::encode(msg.proof),
"SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
);
}

View File

@@ -15,8 +15,6 @@ mod secret;
mod signature;
pub mod threadpool;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
pub(crate) use exchange::{Exchange, exchange};
use hmac::{Hmac, Mac};
pub(crate) use key::ScramKey;
@@ -34,7 +32,7 @@ pub(crate) const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
let mut bytes = [0u8; N];
let size = BASE64_STANDARD.decode_slice(input, &mut bytes).ok()?;
let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?;
if size != N {
return None;
}

View File

@@ -1,7 +1,5 @@
//! Tools for SCRAM server secret management.
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use subtle::{Choice, ConstantTimeEq};
use super::base64_decode_array;
@@ -58,7 +56,7 @@ impl ServerSecret {
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.
iterations: 1,
salt_base64: BASE64_STANDARD.encode(nonce),
salt_base64: base64::encode(nonce),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
doomed: true,
@@ -90,7 +88,7 @@ mod tests {
assert_eq!(parsed.iterations, iterations);
assert_eq!(parsed.salt_base64, salt);
assert_eq!(BASE64_STANDARD.encode(parsed.stored_key), stored_key);
assert_eq!(BASE64_STANDARD.encode(parsed.server_key), server_key);
assert_eq!(base64::encode(parsed.stored_key), stored_key);
assert_eq!(base64::encode(parsed.server_key), server_key);
}
}

View File

@@ -21,7 +21,7 @@ use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
@@ -34,7 +34,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
@@ -68,20 +68,17 @@ impl PoolingBackend {
self.config.authentication_config.is_vpc_acccess_proxy,
)?;
access_control.connection_attempt_rate_limit(
ctx,
&user_info.endpoint,
&self.endpoint_rate_limiter,
)?;
let ep = EndpointIdInt::from(&user_info.endpoint);
let rate_limit_config = None;
if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) {
return Err(AuthError::too_many_connections());
}
let role_access = backend.get_role_secret(ctx).await?;
let Some(secret) = role_access.secret else {
// If we don't have an authentication secret, for the http flow we can just return an error.
info!("authentication info not found");
return Err(AuthError::password_failed(&*user_info.user));
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
ep,
@@ -183,15 +180,14 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys.info);
crate::proxy::connect_compute::connect_to_compute(
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::pglb::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
keys: keys.keys,
},
&backend,
self.config.wake_compute_retry_config,
@@ -218,15 +214,18 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
info: ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
},
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
crate::pglb::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
@@ -496,7 +495,6 @@ struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
keys: ComputeCredentialKeys,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
@@ -522,10 +520,6 @@ impl ConnectMechanism for TokioMechanism {
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
drop(pause);

View File

@@ -16,8 +16,6 @@ use std::sync::atomic::AtomicUsize;
use std::task::{Poll, ready};
use std::time::Duration;
use base64::Engine as _;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use ed25519_dalek::{Signature, Signer, SigningKey};
use futures::Future;
use futures::future::poll_fn;
@@ -348,7 +346,7 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
jwt.push_str("eyJhbGciOiJFZERTQSJ9.");
// encode the jwt payload in-place
BASE64_URL_SAFE_NO_PAD.encode_string(payload, &mut jwt);
base64::encode_config_buf(payload, base64::URL_SAFE_NO_PAD, &mut jwt);
// create the signature from the encoded header || payload
let sig: Signature = sk.sign(jwt.as_bytes());
@@ -356,7 +354,7 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
jwt.push('.');
// encode the jwt signature in-place
BASE64_URL_SAFE_NO_PAD.encode_string(sig.to_bytes(), &mut jwt);
base64::encode_config_buf(sig.to_bytes(), base64::URL_SAFE_NO_PAD, &mut jwt);
debug_assert_eq!(
jwt.len(),

View File

@@ -50,10 +50,10 @@ use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use crate::util::run_until_cancelled;
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub(crate) const AUTH_BROKER_SNI: &str = "apiauth";
@@ -417,12 +417,7 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Ws,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
ctx.set_user_agent(
request
@@ -462,12 +457,7 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Http,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
let span = ctx.span();
let testodrome_id = request

View File

@@ -41,11 +41,10 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::proxy::{NeonOptions, run_until_cancelled};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::util::run_until_cancelled;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]

View File

@@ -3,8 +3,6 @@ pub mod postgres_rustls;
pub mod server_config;
use anyhow::Context;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use rustls::pki_types::CertificateDer;
use sha2::{Digest, Sha256};
use tracing::{error, info};
@@ -60,7 +58,7 @@ impl TlsServerEndPoint {
let oid = certificate.signature_algorithm.oid;
if SHA256_OIDS.contains(&oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(%subject, tls_server_end_point = %BASE64_STANDARD.encode(tls_server_end_point), "determined channel binding");
info!(%subject, tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(%subject, "unknown channel binding");

View File

@@ -1,14 +0,0 @@
use std::pin::pin;
use futures::future::{Either, select};
use tokio_util::sync::CancellationToken;
pub async fn run_until_cancelled<F: Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
match select(pin!(f), pin!(cancellation_token.cancelled())).await {
Either::Left((f, _)) => Some(f),
Either::Right(((), _)) => None,
}
}

View File

@@ -9,7 +9,7 @@ pytest = "^7.4.4"
psycopg2-binary = "^2.9.10"
typing-extensions = "^4.12.2"
PyJWT = {version = "^2.1.0", extras = ["crypto"]}
requests = "^2.32.4"
requests = "^2.32.3"
pytest-xdist = "^3.3.1"
asyncpg = "^0.30.0"
aiopg = "^1.4.0"

View File

@@ -1398,31 +1398,6 @@ async fn handle_timeline_import(req: Request<Body>) -> Result<Response<Body>, Ap
)
}
async fn handle_tenant_timeline_locate(
service: Arc<Service>,
req: Request<Body>,
) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
let timeline_id: TimelineId = parse_request_param(&req, "timeline_id")?;
check_permissions(&req, Scope::Admin)?;
maybe_rate_limit(&req, tenant_id).await;
match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
return res;
}
ForwardOutcome::NotForwarded(_req) => {}
};
json_response(
StatusCode::OK,
service
.tenant_timeline_locate(tenant_id, timeline_id)
.await?,
)
}
async fn handle_tenants_dump(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
@@ -2070,10 +2045,10 @@ pub fn make_router(
router
.data(Arc::new(HttpState::new(service, auth, build_info)))
// Non-prefixed generic endpoints (status, metrics, profiling)
.get("/metrics", |r| {
named_request_span(r, measured_metrics_handler, RequestName("metrics"))
})
// Non-prefixed generic endpoints (status, metrics, profiling)
.get("/status", |r| {
named_request_span(r, handle_status, RequestName("status"))
})
@@ -2164,16 +2139,6 @@ pub fn make_router(
)
},
)
.get(
"/debug/v1/tenant/:tenant_id/timeline/:timeline_id/locate",
|r| {
tenant_service_handler(
r,
handle_tenant_timeline_locate,
RequestName("v1_tenant_timeline_locate"),
)
},
)
.get("/debug/v1/scheduler", |r| {
named_request_span(r, handle_scheduler_dump, RequestName("debug_v1_scheduler"))
})

View File

@@ -141,11 +141,11 @@ pub(crate) struct StorageControllerMetricGroup {
measured::CounterVec<ReconcileLongRunningLabelGroupSet>,
/// Indicator of safekeeper reconciler queue depth, broken down by safekeeper, excluding ongoing reconciles.
pub(crate) storage_controller_safekeeper_reconciles_queued:
pub(crate) storage_controller_safkeeper_reconciles_queued:
measured::GaugeVec<SafekeeperReconcilerLabelGroupSet>,
/// Indicator of completed safekeeper reconciles, broken down by safekeeper.
pub(crate) storage_controller_safekeeper_reconciles_complete:
pub(crate) storage_controller_safkeeper_reconciles_complete:
measured::CounterVec<SafekeeperReconcilerLabelGroupSet>,
}

View File

@@ -2225,19 +2225,8 @@ impl Service {
&self,
reattach_req: ReAttachRequest,
) -> Result<ReAttachResponse, ApiError> {
let mut _node_lock: Option<TracingExclusiveGuard<NodeOperations>> = None;
if let Some(register_req) = reattach_req.register {
_node_lock = Some(
trace_exclusive_lock(
&self.node_op_locks,
register_req.node_id,
NodeOperations::Register,
)
.await,
);
self.node_register_with_lock(register_req, _node_lock.as_ref().unwrap())
.await?;
self.node_register(register_req).await?;
}
// Ordering: we must persist generation number updates before making them visible in the in-memory state
@@ -2278,7 +2267,6 @@ impl Service {
// fail, and start from scratch, so it doesn't make sense for us to try and preserve
// the stale/multi states at this point.
mode: LocationConfigMode::AttachedSingle,
stripe_size: shard.shard.stripe_size,
});
shard.generation = std::cmp::max(shard.generation, Some(new_gen));
@@ -2312,7 +2300,6 @@ impl Service {
id: *tenant_shard_id,
r#gen: None,
mode: LocationConfigMode::Secondary,
stripe_size: shard.shard.stripe_size,
});
// We must not update observed, because we have no guarantee that our
@@ -6662,8 +6649,6 @@ impl Service {
/// This is for debug/support only: assuming tenant data is already present in S3, we "create" a
/// tenant with a very high generation number so that it will see the existing data.
/// It does not create timelines on safekeepers, because they might already exist on some
/// safekeeper set. So, the timelines are not storcon-managed after the import.
pub(crate) async fn tenant_import(
&self,
tenant_id: TenantId,
@@ -7173,21 +7158,13 @@ impl Service {
&self,
register_req: NodeRegisterRequest,
) -> Result<(), ApiError> {
let node_lock = trace_exclusive_lock(
let _node_lock = trace_exclusive_lock(
&self.node_op_locks,
register_req.node_id,
NodeOperations::Register,
)
.await;
self.node_register_with_lock(register_req, &node_lock).await
}
async fn node_register_with_lock(
&self,
register_req: NodeRegisterRequest,
_node_lock: &TracingExclusiveGuard<NodeOperations>,
) -> Result<(), ApiError> {
#[derive(PartialEq)]
enum RegistrationStatus {
UpToDate,

View File

@@ -107,7 +107,7 @@ impl ChaosInjector {
// - Skip shards doing a graceful migration already, so that we allow these to run to
// completion rather than only exercising the first part and then cancelling with
// some other chaos.
matches!(shard.get_scheduling_policy(), ShardSchedulingPolicy::Active)
!matches!(shard.get_scheduling_policy(), ShardSchedulingPolicy::Active)
&& shard.get_preferred_node().is_none()
}

View File

@@ -230,7 +230,7 @@ impl ReconcilerHandle {
// increase it before putting into the queue.
let queued_gauge = &METRICS_REGISTRY
.metrics_group
.storage_controller_safekeeper_reconciles_queued;
.storage_controller_safkeeper_reconciles_queued;
let label_group = SafekeeperReconcilerLabelGroup {
sk_az: &sk_az,
sk_node_id: &sk_node_id,
@@ -306,7 +306,7 @@ impl SafekeeperReconciler {
let queued_gauge = &METRICS_REGISTRY
.metrics_group
.storage_controller_safekeeper_reconciles_queued;
.storage_controller_safkeeper_reconciles_queued;
queued_gauge.set(
SafekeeperReconcilerLabelGroup {
sk_az: &req.safekeeper.skp.availability_zone_id,
@@ -547,7 +547,7 @@ impl SafekeeperReconcilerInner {
let complete_counter = &METRICS_REGISTRY
.metrics_group
.storage_controller_safekeeper_reconciles_complete;
.storage_controller_safkeeper_reconciles_complete;
complete_counter.inc(SafekeeperReconcilerLabelGroup {
sk_az: &req.safekeeper.skp.availability_zone_id,
sk_node_id: &req.safekeeper.get_id().to_string(),

View File

@@ -17,7 +17,7 @@ use pageserver_api::controller_api::{
SafekeeperDescribeResponse, SkSchedulingPolicy, TimelineImportRequest,
};
use pageserver_api::models::{SafekeeperInfo, SafekeepersInfo, TimelineInfo};
use safekeeper_api::membership::{MemberSet, SafekeeperGeneration, SafekeeperId};
use safekeeper_api::membership::{MemberSet, SafekeeperId};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use utils::id::{NodeId, TenantId, TimelineId};
@@ -26,13 +26,6 @@ use utils::lsn::Lsn;
use super::Service;
#[derive(serde::Serialize, serde::Deserialize, Clone)]
pub struct TimelineLocateResponse {
pub generation: SafekeeperGeneration,
pub sk_set: Vec<NodeId>,
pub new_sk_set: Option<Vec<NodeId>>,
}
impl Service {
/// Timeline creation on safekeepers
///
@@ -403,38 +396,6 @@ impl Service {
Ok(())
}
/// Locate safekeepers for a timeline.
/// Return the generation, sk_set and new_sk_set if present.
/// If the timeline is not storcon-managed, return NotFound.
pub(crate) async fn tenant_timeline_locate(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<TimelineLocateResponse, ApiError> {
let timeline = self
.persistence
.get_timeline(tenant_id, timeline_id)
.await?;
let Some(timeline) = timeline else {
return Err(ApiError::NotFound(
anyhow::anyhow!("Timeline {}/{} not found", tenant_id, timeline_id).into(),
));
};
Ok(TimelineLocateResponse {
generation: SafekeeperGeneration::new(timeline.generation as u32),
sk_set: timeline
.sk_set
.iter()
.map(|id| NodeId(*id as u64))
.collect(),
new_sk_set: timeline
.new_sk_set
.map(|sk_set| sk_set.iter().map(|id| NodeId(*id as u64)).collect()),
})
}
/// Perform timeline deletion on safekeepers. Will return success: we persist the deletion into the reconciler.
pub(super) async fn tenant_timeline_delete_safekeepers(
self: &Arc<Self>,

View File

@@ -69,17 +69,15 @@ class EndpointHttpClient(requests.Session):
json: dict[str, str] = res.json()
return json
def prewarm_lfc(self, from_endpoint_id: str | None = None):
url: str = f"http://localhost:{self.external_port}/lfc/prewarm"
params = {"from_endpoint": from_endpoint_id} if from_endpoint_id else dict()
self.post(url, params=params).raise_for_status()
def prewarm_lfc(self):
self.post(f"http://localhost:{self.external_port}/lfc/prewarm").raise_for_status()
def prewarmed():
json = self.prewarm_lfc_status()
status, err = json["status"], json.get("error")
assert status == "completed", f"{status}, error {err}"
wait_until(prewarmed, timeout=60)
wait_until(prewarmed)
def offload_lfc(self):
url = f"http://localhost:{self.external_port}/lfc/offload"

View File

@@ -129,18 +129,6 @@ class NeonAPI:
return cast("dict[str, Any]", resp.json())
def get_project_limits(self, project_id: str) -> dict[str, Any]:
resp = self.__request(
"GET",
f"/projects/{project_id}/limits",
headers={
"Accept": "application/json",
"Content-Type": "application/json",
},
)
return cast("dict[str, Any]", resp.json())
def delete_project(
self,
project_id: str,

View File

@@ -2223,17 +2223,6 @@ class NeonStorageController(MetricsGetter, LogUtils):
shards: list[dict[str, Any]] = body["shards"]
return shards
def timeline_locate(self, tenant_id: TenantId, timeline_id: TimelineId):
"""
:return: dict {"generation": int, "sk_set": [int], "new_sk_set": [int]}
"""
response = self.request(
"GET",
f"{self.api}/debug/v1/tenant/{tenant_id}/timeline/{timeline_id}/locate",
headers=self.headers(TokenScope.ADMIN),
)
return response.json()
def tenant_describe(self, tenant_id: TenantId):
"""
:return: list of {"shard_id": "", "node_id": int, "listen_pg_addr": str, "listen_pg_port": int, "listen_http_addr: str, "listen_http_port: int, preferred_az_id: str}
@@ -4057,16 +4046,6 @@ def static_proxy(
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
)
vanilla_pg.stop()
vanilla_pg.edit_hba(
[
"local all all trust",
"host all all 127.0.0.1/32 scram-sha-256",
"host all all ::1/128 scram-sha-256",
]
)
vanilla_pg.start()
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()

View File

@@ -45,8 +45,6 @@ class NeonEndpoint:
if self.branch.connect_env:
self.connect_env = self.branch.connect_env.copy()
self.connect_env["PGHOST"] = self.host
if self.type == "read_only":
self.project.read_only_endpoints_total += 1
def delete(self):
self.project.delete_endpoint(self.id)
@@ -230,13 +228,8 @@ class NeonProject:
self.benchmarks: dict[str, subprocess.Popen[Any]] = {}
self.restore_num: int = 0
self.restart_pgbench_on_console_errors: bool = False
self.limits: dict[str, Any] = self.get_limits()["limits"]
self.read_only_endpoints_total: int = 0
def get_limits(self) -> dict[str, Any]:
return self.neon_api.get_project_limits(self.id)
def delete(self) -> None:
def delete(self):
self.neon_api.delete_project(self.id)
def create_branch(self, parent_id: str | None = None) -> NeonBranch | None:
@@ -289,7 +282,6 @@ class NeonProject:
self.neon_api.delete_endpoint(self.id, endpoint_id)
self.endpoints[endpoint_id].branch.endpoints.pop(endpoint_id)
self.endpoints.pop(endpoint_id)
self.read_only_endpoints_total -= 1
self.wait()
def start_benchmark(self, target: str, clients: int = 10) -> subprocess.Popen[Any]:
@@ -377,64 +369,49 @@ def setup_class(
print(f"::warning::Retried on 524 error {neon_api.retries524} times")
if neon_api.retries4xx > 0:
print(f"::warning::Retried on 4xx error {neon_api.retries4xx} times")
log.info("Removing the project %s", project.id)
log.info("Removing the project")
project.delete()
def do_action(project: NeonProject, action: str) -> bool:
def do_action(project: NeonProject, action: str) -> None:
"""
Runs the action
"""
log.info("Action: %s", action)
if action == "new_branch":
log.info("Trying to create a new branch")
if 0 <= project.limits["max_branches"] <= len(project.branches):
log.info(
"Maximum branch limit exceeded (%s of %s)",
len(project.branches),
project.limits["max_branches"],
)
return False
parent = project.branches[
random.choice(list(set(project.branches.keys()) - project.reset_branches))
]
log.info("Parent: %s", parent)
child = parent.create_child_branch()
if child is None:
return False
return
log.info("Created branch %s", child)
child.start_benchmark()
elif action == "delete_branch":
if project.leaf_branches:
target: NeonBranch = random.choice(list(project.leaf_branches.values()))
target = random.choice(list(project.leaf_branches.values()))
log.info("Trying to delete branch %s", target)
target.delete()
else:
log.info("Leaf branches not found, skipping")
return False
elif action == "new_ro_endpoint":
if 0 <= project.limits["max_read_only_endpoints"] <= project.read_only_endpoints_total:
log.info(
"Maximum read only endpoint limit exceeded (%s of %s)",
project.read_only_endpoints_total,
project.limits["max_read_only_endpoints"],
)
return False
ep = random.choice(
[br for br in project.branches.values() if br.id not in project.reset_branches]
).create_ro_endpoint()
log.info("Created the RO endpoint with id %s branch: %s", ep.id, ep.branch.id)
ep.start_benchmark()
elif action == "delete_ro_endpoint":
if project.read_only_endpoints_total == 0:
log.info("no read_only endpoints present, skipping")
return False
ro_endpoints: list[NeonEndpoint] = [
endpoint for endpoint in project.endpoints.values() if endpoint.type == "read_only"
]
target_ep: NeonEndpoint = random.choice(ro_endpoints)
target_ep.delete()
log.info("endpoint %s deleted", target_ep.id)
if ro_endpoints:
target_ep: NeonEndpoint = random.choice(ro_endpoints)
target_ep.delete()
log.info("endpoint %s deleted", target_ep.id)
else:
log.info("no read_only endpoints present, skipping")
elif action == "restore_random_time":
if project.leaf_branches:
br: NeonBranch = random.choice(list(project.leaf_branches.values()))
@@ -442,10 +419,8 @@ def do_action(project: NeonProject, action: str) -> bool:
br.restore_random_time()
else:
log.info("No leaf branches found")
return False
else:
raise ValueError(f"The action {action} is unknown")
return True
@pytest.mark.timeout(7200)
@@ -482,9 +457,8 @@ def test_api_random(
pg_bin.run(["pgbench", "-i", "-I", "dtGvp", "-s100"], env=project.main_branch.connect_env)
for _ in range(num_operations):
log.info("Starting action #%s", _ + 1)
while not do_action(
do_action(
project, random.choices([a[0] for a in ACTIONS], weights=[w[1] for w in ACTIONS])[0]
):
log.info("Retrying...")
)
project.check_all_benchmarks()
assert True

View File

@@ -18,7 +18,6 @@ from fixtures.neon_fixtures import (
NeonEnv,
NeonEnvBuilder,
PgBin,
Safekeeper,
flush_ep_to_pageserver,
)
from fixtures.pageserver.http import PageserverApiException
@@ -27,7 +26,6 @@ from fixtures.pageserver.utils import (
)
from fixtures.pg_version import PgVersion
from fixtures.remote_storage import RemoteStorageKind, S3Storage, s3_storage
from fixtures.safekeeper.http import MembershipConfiguration
from fixtures.workload import Workload
if TYPE_CHECKING:
@@ -544,24 +542,6 @@ def test_historic_storage_formats(
# All our artifacts should contain at least one timeline
assert len(timelines) > 0
# Import tenant does not create the timeline on safekeepers,
# because it is a debug handler and the timeline may have already been
# created on some set of safekeepers.
# Create the timeline on safekeepers manually.
# TODO(diko): when we have the script/storcon handler to migrate
# the timeline to storcon, we can replace this code with it.
mconf = MembershipConfiguration(
generation=1,
members=Safekeeper.sks_to_safekeeper_ids([env.safekeepers[0]]),
new_members=None,
)
members_sks = Safekeeper.mconf_sks(env, mconf)
for timeline in timelines:
Safekeeper.create_timeline(
dataset.tenant_id, timeline["timeline_id"], env.pageserver, mconf, members_sks
)
# TODO: ensure that the snapshots we're importing contain a sensible variety of content, at the very
# least they should include a mixture of deltas and image layers. Preferably they should also
# contain some "exotic" stuff like aux files from logical replication.

View File

@@ -188,8 +188,7 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, query: LfcQueryMet
pg_cur.execute("select pg_reload_conf()")
if query is LfcQueryMethod.COMPUTE_CTL:
# Same thing as prewarm_lfc(), testing other method
http_client.prewarm_lfc(endpoint.endpoint_id)
http_client.prewarm_lfc()
else:
pg_cur.execute("select prewarm_local_cache(%s)", (lfc_state,))

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
# Test restarting page server, while safekeeper and compute node keep
# running.
def test_pageserver_restarts_under_workload(neon_simple_env: NeonEnv, pg_bin: PgBin):
def test_pageserver_restarts_under_worload(neon_simple_env: NeonEnv, pg_bin: PgBin):
env = neon_simple_env
env.create_branch("test_pageserver_restarts")
endpoint = env.endpoints.create_start("test_pageserver_restarts")
@@ -28,11 +28,7 @@ def test_pageserver_restarts_under_workload(neon_simple_env: NeonEnv, pg_bin: Pg
pg_bin.run_capture(["pgbench", "-i", "-I", "dtGvp", f"-s{scale}", connstr])
pg_bin.run_capture(["pgbench", f"-T{n_restarts}", connstr])
thread = threading.Thread(
target=run_pgbench,
args=(endpoint.connstr(options="-cstatement_timeout=360s"),),
daemon=True,
)
thread = threading.Thread(target=run_pgbench, args=(endpoint.connstr(),), daemon=True)
thread.start()
for _ in range(n_restarts):

View File

@@ -19,15 +19,11 @@ TABLE_NAME = "neon_control_plane.endpoints"
async def test_proxy_psql_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')"
)
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')"
)
def check_cannot_connect(**kwargs):
@@ -64,9 +60,7 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')"
)
def query(status: int, query: str, *args):
@@ -81,8 +75,6 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
query(400, "select 1;") # ip address is not allowed
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'",
user="proxy",
password="password",
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'"
)
query(200, "select 1;") # should work now

View File

@@ -430,7 +430,6 @@ def test_tenant_delete_stale_shards(neon_env_builder: NeonEnvBuilder, pg_bin: Pg
workload.init()
workload.write_rows(256)
workload.validate()
workload.stop()
assert_prefix_not_empty(
neon_env_builder.pageserver_remote_storage,

View File

@@ -20,7 +20,8 @@ anstream = { version = "0.6" }
anyhow = { version = "1", features = ["backtrace"] }
axum = { version = "0.8", features = ["ws"] }
axum-core = { version = "0.5", default-features = false, features = ["tracing"] }
base64 = { version = "0.21" }
base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] }
base64-647d43efb71741da = { package = "base64", version = "0.21" }
base64ct = { version = "1", default-features = false, features = ["std"] }
bytes = { version = "1", features = ["serde"] }
camino = { version = "1", default-features = false, features = ["serde1"] }