Compare commits

..

8 Commits

Author SHA1 Message Date
Konstantin Knizhnik
de649f856c Fix documentation format issues 2024-04-13 22:37:39 +03:00
Konstantin Knizhnik
de3fdf9860 Add more comments 2024-04-13 21:47:01 +03:00
Konstantin Knizhnik
1b2cfc0259 Proivide comment for NeonRequest struct 2024-04-11 17:24:39 +03:00
Konstantin Knizhnik
165a1d7bf1 Make ruff happy 2024-04-11 09:15:35 +03:00
Konstantin Knizhnik
f07c33186a Add neon.protocol_version GUC 2024-04-11 09:15:35 +03:00
Konstantin Knizhnik
15c0e1351a Fix messages tags in PS serialize 2024-04-11 09:15:35 +03:00
Konstantin Knizhnik
ccbf95e9dc Use tags starting from 10 for command of new protocol 2024-04-11 09:15:34 +03:00
Konstantin Knizhnik
93e6046005 Send LSN range in getpage request 2024-04-11 09:15:31 +03:00
101 changed files with 3335 additions and 3506 deletions

View File

@@ -150,7 +150,7 @@ runs:
# Use aws s3 cp (instead of aws s3 sync) to keep files from previous runs to make old URLs work,
# and to keep files on the host to upload them to the database
time s5cmd --log error cp "${WORKDIR}/report/*" "s3://${BUCKET}/${REPORT_PREFIX}/${GITHUB_RUN_ID}/"
time aws s3 cp --recursive --only-show-errors "${WORKDIR}/report" "s3://${BUCKET}/${REPORT_PREFIX}/${GITHUB_RUN_ID}"
# Generate redirect
cat <<EOF > ${WORKDIR}/index.html

View File

@@ -18,7 +18,6 @@ on:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
cancel-in-progress: false
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -21,7 +21,6 @@ defaults:
concurrency:
group: build-build-tools-image-${{ inputs.image-tag }}
cancel-in-progress: false
# No permission for GITHUB_TOKEN by default; the **minimal required** set of permissions should be granted in each job.
permissions: {}

View File

@@ -1133,6 +1133,8 @@ jobs:
-f deployPreprodRegion=true
gh workflow --repo neondatabase/aws run deploy-prod.yml --ref main \
-f deployPgSniRouter=false \
-f deployProxy=false \
-f deployStorage=true \
-f deployStorageBroker=true \
-f deployStorageController=true \

View File

@@ -28,9 +28,7 @@ jobs:
- name: Get build-tools image tag for the current commit
id: get-build-tools-tag
env:
# Usually, for COMMIT_SHA, we use `github.event.pull_request.head.sha || github.sha`, but here, even for PRs,
# we want to use `github.sha` i.e. point to a phantom merge commit to determine the image tag correctly.
COMMIT_SHA: ${{ github.sha }}
COMMIT_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
LAST_BUILD_TOOLS_SHA=$(

View File

@@ -20,7 +20,6 @@ defaults:
concurrency:
group: pin-build-tools-image-${{ inputs.from-tag }}
cancel-in-progress: false
permissions: {}

13
Cargo.lock generated
View File

@@ -2932,9 +2932,9 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]]
name = "measured"
version = "0.0.21"
version = "0.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "652bc741286361c06de8cb4d89b21a6437f120c508c51713663589eeb9928ac5"
checksum = "3cbf033874bea03565f2449572c8640ca37ec26300455faf36001f24755da452"
dependencies = [
"bytes",
"crossbeam-utils",
@@ -2950,9 +2950,9 @@ dependencies = [
[[package]]
name = "measured-derive"
version = "0.0.21"
version = "0.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ea497f33e1e856a376c32ad916f69a0bd3c597db1f912a399f842b01a4a685d"
checksum = "be9e29b682b38f8af2a89f960455054ab1a9f5a06822f6f3500637ad9fa57def"
dependencies = [
"heck 0.5.0",
"proc-macro2",
@@ -2962,9 +2962,9 @@ dependencies = [
[[package]]
name = "measured-process"
version = "0.0.21"
version = "0.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b364ccb66937a814b6b2ad751d1a2f7a9d5a78c761144036825fb36bb0771000"
checksum = "a20849acdd04c5d6a88f565559044546904648a1842a2937cfff0b48b4ca7ef2"
dependencies = [
"libc",
"measured",
@@ -4322,7 +4322,6 @@ dependencies = [
"itertools",
"lasso",
"md5",
"measured",
"metrics",
"native-tls",
"once_cell",

View File

@@ -107,8 +107,8 @@ lasso = "0.7"
leaky-bucket = "1.0.1"
libc = "0.2"
md5 = "0.7.0"
measured = { version = "0.0.21", features=["lasso"] }
measured-process = { version = "0.0.21" }
measured = { version = "0.0.20", features=["lasso"] }
measured-process = { version = "0.0.20" }
memoffset = "0.8"
native-tls = "0.2"
nix = { version = "0.27", features = ["fs", "process", "socket", "signal", "poll"] }

View File

@@ -58,12 +58,6 @@ RUN curl -fsSL "https://github.com/protocolbuffers/protobuf/releases/download/v$
&& mv protoc/include/google /usr/local/include/google \
&& rm -rf protoc.zip protoc
# s5cmd
ENV S5CMD_VERSION=2.2.2
RUN curl -sL "https://github.com/peak/s5cmd/releases/download/v${S5CMD_VERSION}/s5cmd_${S5CMD_VERSION}_Linux-$(uname -m | sed 's/x86_64/64bit/g' | sed 's/aarch64/arm64/g').tar.gz" | tar zxvf - s5cmd \
&& chmod +x s5cmd \
&& mv s5cmd /usr/local/bin/s5cmd
# LLVM
ENV LLVM_VERSION=17
RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \

View File

@@ -818,15 +818,9 @@ impl ComputeNode {
Client::connect(zenith_admin_connstr.as_str(), NoTls)
.context("broken cloud_admin credential: tried connecting with cloud_admin but could not authenticate, and zenith_admin does not work either")?;
// Disable forwarding so that users don't get a cloud_admin role
let mut func = || {
client.simple_query("SET neon.forward_ddl = false")?;
client.simple_query("CREATE USER cloud_admin WITH SUPERUSER")?;
client.simple_query("GRANT zenith_admin TO cloud_admin")?;
Ok::<_, anyhow::Error>(())
};
func().context("apply_config setup cloud_admin")?;
client.simple_query("SET neon.forward_ddl = false")?;
client.simple_query("CREATE USER cloud_admin WITH SUPERUSER")?;
client.simple_query("GRANT zenith_admin TO cloud_admin")?;
drop(client);
// reconnect with connstring with expected name
@@ -838,29 +832,24 @@ impl ComputeNode {
};
// Disable DDL forwarding because control plane already knows about these roles/databases.
client
.simple_query("SET neon.forward_ddl = false")
.context("apply_config SET neon.forward_ddl = false")?;
client.simple_query("SET neon.forward_ddl = false")?;
// Proceed with post-startup configuration. Note, that order of operations is important.
let spec = &compute_state.pspec.as_ref().expect("spec must be set").spec;
create_neon_superuser(spec, &mut client).context("apply_config create_neon_superuser")?;
cleanup_instance(&mut client).context("apply_config cleanup_instance")?;
handle_roles(spec, &mut client).context("apply_config handle_roles")?;
handle_databases(spec, &mut client).context("apply_config handle_databases")?;
handle_role_deletions(spec, connstr.as_str(), &mut client)
.context("apply_config handle_role_deletions")?;
create_neon_superuser(spec, &mut client)?;
cleanup_instance(&mut client)?;
handle_roles(spec, &mut client)?;
handle_databases(spec, &mut client)?;
handle_role_deletions(spec, connstr.as_str(), &mut client)?;
handle_grants(
spec,
&mut client,
connstr.as_str(),
self.has_feature(ComputeFeature::AnonExtension),
)
.context("apply_config handle_grants")?;
handle_extensions(spec, &mut client).context("apply_config handle_extensions")?;
handle_extension_neon(&mut client).context("apply_config handle_extension_neon")?;
create_availability_check_data(&mut client)
.context("apply_config create_availability_check_data")?;
)?;
handle_extensions(spec, &mut client)?;
handle_extension_neon(&mut client)?;
create_availability_check_data(&mut client)?;
// 'Close' connection
drop(client);
@@ -868,7 +857,7 @@ impl ComputeNode {
// Run migrations separately to not hold up cold starts
thread::spawn(move || {
let mut client = Client::connect(connstr.as_str(), NoTls)?;
handle_migrations(&mut client).context("apply_config handle_migrations")
handle_migrations(&mut client)
});
Ok(())
}

View File

@@ -2,7 +2,7 @@ use std::fs::File;
use std::path::Path;
use std::str::FromStr;
use anyhow::{anyhow, bail, Context, Result};
use anyhow::{anyhow, bail, Result};
use postgres::config::Config;
use postgres::{Client, NoTls};
use reqwest::StatusCode;
@@ -698,8 +698,7 @@ pub fn handle_grants(
// it is important to run this after all grants
if enable_anon_extension {
handle_extension_anon(spec, &db.owner, &mut db_client, false)
.context("handle_grants handle_extension_anon")?;
handle_extension_anon(spec, &db.owner, &mut db_client, false)?;
}
}
@@ -814,36 +813,28 @@ $$;"#,
// Add new migrations below.
];
let mut func = || {
let query = "CREATE SCHEMA IF NOT EXISTS neon_migration";
client.simple_query(query)?;
let mut query = "CREATE SCHEMA IF NOT EXISTS neon_migration";
client.simple_query(query)?;
let query = "CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)";
client.simple_query(query)?;
query = "CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)";
client.simple_query(query)?;
let query = "INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING";
client.simple_query(query)?;
query = "INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING";
client.simple_query(query)?;
let query = "ALTER SCHEMA neon_migration OWNER TO cloud_admin";
client.simple_query(query)?;
query = "ALTER SCHEMA neon_migration OWNER TO cloud_admin";
client.simple_query(query)?;
let query = "REVOKE ALL ON SCHEMA neon_migration FROM PUBLIC";
client.simple_query(query)?;
Ok::<_, anyhow::Error>(())
};
func().context("handle_migrations prepare")?;
query = "REVOKE ALL ON SCHEMA neon_migration FROM PUBLIC";
client.simple_query(query)?;
let query = "SELECT id FROM neon_migration.migration_id";
let row = client
.query_one(query, &[])
.context("handle_migrations get migration_id")?;
query = "SELECT id FROM neon_migration.migration_id";
let row = client.query_one(query, &[])?;
let mut current_migration: usize = row.get::<&str, i64>("id") as usize;
let starting_migration_id = current_migration;
let query = "BEGIN";
client
.simple_query(query)
.context("handle_migrations begin")?;
query = "BEGIN";
client.simple_query(query)?;
while current_migration < migrations.len() {
let migration = &migrations[current_migration];
@@ -851,9 +842,7 @@ $$;"#,
info!("Skip migration id={}", current_migration);
} else {
info!("Running migration:\n{}\n", migration);
client.simple_query(migration).with_context(|| {
format!("handle_migrations current_migration={}", current_migration)
})?;
client.simple_query(migration)?;
}
current_migration += 1;
}
@@ -861,14 +850,10 @@ $$;"#,
"UPDATE neon_migration.migration_id SET id={}",
migrations.len()
);
client
.simple_query(&setval)
.context("handle_migrations update id")?;
client.simple_query(&setval)?;
let query = "COMMIT";
client
.simple_query(query)
.context("handle_migrations commit")?;
query = "COMMIT";
client.simple_query(query)?;
info!(
"Ran {} migrations",

View File

@@ -1231,7 +1231,7 @@ async fn try_stop_all(env: &local_env::LocalEnv, immediate: bool) {
match ComputeControlPlane::load(env.clone()) {
Ok(cplane) => {
for (_k, node) in cplane.endpoints {
if let Err(e) = node.stop(if immediate { "immediate" } else { "fast" }, false) {
if let Err(e) = node.stop(if immediate { "immediate" } else { "fast " }, false) {
eprintln!("postgres stop failed: {e:#}");
}
}
@@ -1417,7 +1417,6 @@ fn cli() -> Command {
.subcommand(
Command::new("timeline")
.about("Manage timelines")
.arg_required_else_help(true)
.subcommand(Command::new("list")
.about("List all timelines, available to this pageserver")
.arg(tenant_id_arg.clone()))

View File

@@ -156,7 +156,6 @@ pub struct SafekeeperConf {
pub remote_storage: Option<String>,
pub backup_threads: Option<u32>,
pub auth_enabled: bool,
pub listen_addr: Option<String>,
}
impl Default for SafekeeperConf {
@@ -170,7 +169,6 @@ impl Default for SafekeeperConf {
remote_storage: None,
backup_threads: None,
auth_enabled: false,
listen_addr: None,
}
}
}

View File

@@ -70,31 +70,24 @@ pub struct SafekeeperNode {
pub pg_connection_config: PgConnectionConfig,
pub env: LocalEnv,
pub http_client: reqwest::Client,
pub listen_addr: String,
pub http_base_url: String,
}
impl SafekeeperNode {
pub fn from_env(env: &LocalEnv, conf: &SafekeeperConf) -> SafekeeperNode {
let listen_addr = if let Some(ref listen_addr) = conf.listen_addr {
listen_addr.clone()
} else {
"127.0.0.1".to_string()
};
SafekeeperNode {
id: conf.id,
conf: conf.clone(),
pg_connection_config: Self::safekeeper_connection_config(&listen_addr, conf.pg_port),
pg_connection_config: Self::safekeeper_connection_config(conf.pg_port),
env: env.clone(),
http_client: reqwest::Client::new(),
http_base_url: format!("http://{}:{}/v1", listen_addr, conf.http_port),
listen_addr,
http_base_url: format!("http://127.0.0.1:{}/v1", conf.http_port),
}
}
/// Construct libpq connection string for connecting to this safekeeper.
fn safekeeper_connection_config(addr: &str, port: u16) -> PgConnectionConfig {
PgConnectionConfig::new_host_port(url::Host::parse(addr).unwrap(), port)
fn safekeeper_connection_config(port: u16) -> PgConnectionConfig {
PgConnectionConfig::new_host_port(url::Host::parse("127.0.0.1").unwrap(), port)
}
pub fn datadir_path_by_id(env: &LocalEnv, sk_id: NodeId) -> PathBuf {
@@ -118,8 +111,8 @@ impl SafekeeperNode {
);
io::stdout().flush().unwrap();
let listen_pg = format!("{}:{}", self.listen_addr, self.conf.pg_port);
let listen_http = format!("{}:{}", self.listen_addr, self.conf.http_port);
let listen_pg = format!("127.0.0.1:{}", self.conf.pg_port);
let listen_http = format!("127.0.0.1:{}", self.conf.http_port);
let id = self.id;
let datadir = self.datadir_path();
@@ -146,7 +139,7 @@ impl SafekeeperNode {
availability_zone,
];
if let Some(pg_tenant_only_port) = self.conf.pg_tenant_only_port {
let listen_pg_tenant_only = format!("{}:{}", self.listen_addr, pg_tenant_only_port);
let listen_pg_tenant_only = format!("127.0.0.1:{}", pg_tenant_only_port);
args.extend(["--listen-pg-tenant-only".to_owned(), listen_pg_tenant_only]);
}
if !self.conf.sync {

View File

@@ -7,19 +7,14 @@
//! use significantly less memory than this, but can only approximate the cardinality.
use std::{
hash::{BuildHasher, BuildHasherDefault, Hash},
sync::atomic::AtomicU8,
collections::HashMap,
hash::{BuildHasher, BuildHasherDefault, Hash, Hasher},
sync::{atomic::AtomicU8, Arc, RwLock},
};
use measured::{
label::{LabelGroupVisitor, LabelName, LabelValue, LabelVisitor},
metric::{
group::{Encoding, MetricValue},
name::MetricNameEncoder,
Metric, MetricType, MetricVec,
},
text::TextEncoder,
LabelGroup,
use prometheus::{
core::{self, Describer},
proto, Opts,
};
use twox_hash::xxh3;
@@ -98,25 +93,203 @@ macro_rules! register_hll {
/// ```
///
/// See <https://en.wikipedia.org/wiki/HyperLogLog#Practical_considerations> for estimates on alpha
pub type HyperLogLogVec<L, const N: usize> = MetricVec<HyperLogLogState<N>, L>;
pub type HyperLogLog<const N: usize> = Metric<HyperLogLogState<N>>;
pub struct HyperLogLogState<const N: usize> {
shards: [AtomicU8; N],
#[derive(Clone)]
pub struct HyperLogLogVec<const N: usize> {
core: Arc<HyperLogLogVecCore<N>>,
}
impl<const N: usize> Default for HyperLogLogState<N> {
fn default() -> Self {
#[allow(clippy::declare_interior_mutable_const)]
const ZERO: AtomicU8 = AtomicU8::new(0);
Self { shards: [ZERO; N] }
struct HyperLogLogVecCore<const N: usize> {
pub children: RwLock<HashMap<u64, HyperLogLog<N>, BuildHasherDefault<xxh3::Hash64>>>,
pub desc: core::Desc,
pub opts: Opts,
}
impl<const N: usize> core::Collector for HyperLogLogVec<N> {
fn desc(&self) -> Vec<&core::Desc> {
vec![&self.core.desc]
}
fn collect(&self) -> Vec<proto::MetricFamily> {
let mut m = proto::MetricFamily::default();
m.set_name(self.core.desc.fq_name.clone());
m.set_help(self.core.desc.help.clone());
m.set_field_type(proto::MetricType::GAUGE);
let mut metrics = Vec::new();
for child in self.core.children.read().unwrap().values() {
child.core.collect_into(&mut metrics);
}
m.set_metric(metrics);
vec![m]
}
}
impl<const N: usize> MetricType for HyperLogLogState<N> {
type Metadata = ();
impl<const N: usize> HyperLogLogVec<N> {
/// Create a new [`HyperLogLogVec`] based on the provided
/// [`Opts`] and partitioned by the given label names. At least one label name must be
/// provided.
pub fn new(opts: Opts, label_names: &[&str]) -> prometheus::Result<Self> {
assert!(N.is_power_of_two());
let variable_names = label_names.iter().map(|s| (*s).to_owned()).collect();
let opts = opts.variable_labels(variable_names);
let desc = opts.describe()?;
let v = HyperLogLogVecCore {
children: RwLock::new(HashMap::default()),
desc,
opts,
};
Ok(Self { core: Arc::new(v) })
}
/// `get_metric_with_label_values` returns the [`HyperLogLog<P>`] for the given slice
/// of label values (same order as the VariableLabels in Desc). If that combination of
/// label values is accessed for the first time, a new [`HyperLogLog<P>`] is created.
///
/// An error is returned if the number of label values is not the same as the
/// number of VariableLabels in Desc.
pub fn get_metric_with_label_values(
&self,
vals: &[&str],
) -> prometheus::Result<HyperLogLog<N>> {
self.core.get_metric_with_label_values(vals)
}
/// `with_label_values` works as `get_metric_with_label_values`, but panics if an error
/// occurs.
pub fn with_label_values(&self, vals: &[&str]) -> HyperLogLog<N> {
self.get_metric_with_label_values(vals).unwrap()
}
}
impl<const N: usize> HyperLogLogState<N> {
impl<const N: usize> HyperLogLogVecCore<N> {
pub fn get_metric_with_label_values(
&self,
vals: &[&str],
) -> prometheus::Result<HyperLogLog<N>> {
let h = self.hash_label_values(vals)?;
if let Some(metric) = self.children.read().unwrap().get(&h).cloned() {
return Ok(metric);
}
self.get_or_create_metric(h, vals)
}
pub(crate) fn hash_label_values(&self, vals: &[&str]) -> prometheus::Result<u64> {
if vals.len() != self.desc.variable_labels.len() {
return Err(prometheus::Error::InconsistentCardinality {
expect: self.desc.variable_labels.len(),
got: vals.len(),
});
}
let mut h = xxh3::Hash64::default();
for val in vals {
h.write(val.as_bytes());
}
Ok(h.finish())
}
fn get_or_create_metric(
&self,
hash: u64,
label_values: &[&str],
) -> prometheus::Result<HyperLogLog<N>> {
let mut children = self.children.write().unwrap();
// Check exist first.
if let Some(metric) = children.get(&hash).cloned() {
return Ok(metric);
}
let metric = HyperLogLog::with_opts_and_label_values(&self.opts, label_values)?;
children.insert(hash, metric.clone());
Ok(metric)
}
}
/// HLL is a probabilistic cardinality measure.
///
/// How to use this time-series for a metric name `my_metrics_total_hll`:
///
/// ```promql
/// # harmonic mean
/// 1 / (
/// sum (
/// 2 ^ -(
/// # HLL merge operation
/// max (my_metrics_total_hll{}) by (hll_shard, other_labels...)
/// )
/// ) without (hll_shard)
/// )
/// * alpha
/// * shards_count
/// * shards_count
/// ```
///
/// If you want an estimate over time, you can use the following query:
///
/// ```promql
/// # harmonic mean
/// 1 / (
/// sum (
/// 2 ^ -(
/// # HLL merge operation
/// max (
/// max_over_time(my_metrics_total_hll{}[$__rate_interval])
/// ) by (hll_shard, other_labels...)
/// )
/// ) without (hll_shard)
/// )
/// * alpha
/// * shards_count
/// * shards_count
/// ```
///
/// In the case of low cardinality, you might want to use the linear counting approximation:
///
/// ```promql
/// # LinearCounting(m, V) = m log (m / V)
/// shards_count * ln(shards_count /
/// # calculate V = how many shards contain a 0
/// count(max (proxy_connecting_endpoints{}) by (hll_shard, protocol) == 0) without (hll_shard)
/// )
/// ```
///
/// See <https://en.wikipedia.org/wiki/HyperLogLog#Practical_considerations> for estimates on alpha
#[derive(Clone)]
pub struct HyperLogLog<const N: usize> {
core: Arc<HyperLogLogCore<N>>,
}
impl<const N: usize> HyperLogLog<N> {
/// Create a [`HyperLogLog`] with the `name` and `help` arguments.
pub fn new<S1: Into<String>, S2: Into<String>>(name: S1, help: S2) -> prometheus::Result<Self> {
assert!(N.is_power_of_two());
let opts = Opts::new(name, help);
Self::with_opts(opts)
}
/// Create a [`HyperLogLog`] with the `opts` options.
pub fn with_opts(opts: Opts) -> prometheus::Result<Self> {
Self::with_opts_and_label_values(&opts, &[])
}
fn with_opts_and_label_values(opts: &Opts, label_values: &[&str]) -> prometheus::Result<Self> {
let desc = opts.describe()?;
let labels = make_label_pairs(&desc, label_values)?;
let v = HyperLogLogCore {
shards: [0; N].map(AtomicU8::new),
desc,
labels,
};
Ok(Self { core: Arc::new(v) })
}
pub fn measure(&self, item: &impl Hash) {
// changing the hasher will break compatibility with previous measurements.
self.record(BuildHasherDefault::<xxh3::Hash64>::default().hash_one(item));
@@ -126,11 +299,42 @@ impl<const N: usize> HyperLogLogState<N> {
let p = N.ilog2() as u8;
let j = hash & (N as u64 - 1);
let rho = (hash >> p).leading_zeros() as u8 + 1 - p;
self.shards[j as usize].fetch_max(rho, std::sync::atomic::Ordering::Relaxed);
self.core.shards[j as usize].fetch_max(rho, std::sync::atomic::Ordering::Relaxed);
}
}
struct HyperLogLogCore<const N: usize> {
shards: [AtomicU8; N],
desc: core::Desc,
labels: Vec<proto::LabelPair>,
}
impl<const N: usize> core::Collector for HyperLogLog<N> {
fn desc(&self) -> Vec<&core::Desc> {
vec![&self.core.desc]
}
fn take_sample(&self) -> [u8; N] {
self.shards.each_ref().map(|x| {
fn collect(&self) -> Vec<proto::MetricFamily> {
let mut m = proto::MetricFamily::default();
m.set_name(self.core.desc.fq_name.clone());
m.set_help(self.core.desc.help.clone());
m.set_field_type(proto::MetricType::GAUGE);
let mut metrics = Vec::new();
self.core.collect_into(&mut metrics);
m.set_metric(metrics);
vec![m]
}
}
impl<const N: usize> HyperLogLogCore<N> {
fn collect_into(&self, metrics: &mut Vec<proto::Metric>) {
self.shards.iter().enumerate().for_each(|(i, x)| {
let mut shard_label = proto::LabelPair::default();
shard_label.set_name("hll_shard".to_owned());
shard_label.set_value(format!("{i}"));
// We reset the counter to 0 so we can perform a cardinality measure over any time slice in prometheus.
// This seems like it would be a race condition,
@@ -140,90 +344,85 @@ impl<const N: usize> HyperLogLogState<N> {
// TODO: maybe we shouldn't reset this on every collect, instead, only after a time window.
// this would mean that a dev port-forwarding the metrics url won't break the sampling.
x.swap(0, std::sync::atomic::Ordering::Relaxed)
let v = x.swap(0, std::sync::atomic::Ordering::Relaxed);
let mut m = proto::Metric::default();
let mut c = proto::Gauge::default();
c.set_value(v as f64);
m.set_gauge(c);
let mut labels = Vec::with_capacity(self.labels.len() + 1);
labels.extend_from_slice(&self.labels);
labels.push(shard_label);
m.set_label(labels);
metrics.push(m);
})
}
}
impl<W: std::io::Write, const N: usize> measured::metric::MetricEncoding<TextEncoder<W>>
for HyperLogLogState<N>
{
fn write_type(
name: impl MetricNameEncoder,
enc: &mut TextEncoder<W>,
) -> Result<(), std::io::Error> {
enc.write_type(&name, measured::text::MetricType::Gauge)
fn make_label_pairs(
desc: &core::Desc,
label_values: &[&str],
) -> prometheus::Result<Vec<proto::LabelPair>> {
if desc.variable_labels.len() != label_values.len() {
return Err(prometheus::Error::InconsistentCardinality {
expect: desc.variable_labels.len(),
got: label_values.len(),
});
}
fn collect_into(
&self,
_: &(),
labels: impl LabelGroup,
name: impl MetricNameEncoder,
enc: &mut TextEncoder<W>,
) -> Result<(), std::io::Error> {
struct I64(i64);
impl LabelValue for I64 {
fn visit<V: LabelVisitor>(&self, v: V) -> V::Output {
v.write_int(self.0)
}
}
struct HllShardLabel {
hll_shard: i64,
}
impl LabelGroup for HllShardLabel {
fn visit_values(&self, v: &mut impl LabelGroupVisitor) {
const LE: &LabelName = LabelName::from_str("hll_shard");
v.write_value(LE, &I64(self.hll_shard));
}
}
self.take_sample()
.into_iter()
.enumerate()
.try_for_each(|(hll_shard, val)| {
enc.write_metric_value(
name.by_ref(),
labels.by_ref().compose_with(HllShardLabel {
hll_shard: hll_shard as i64,
}),
MetricValue::Int(val as i64),
)
})
let total_len = desc.variable_labels.len() + desc.const_label_pairs.len();
if total_len == 0 {
return Ok(vec![]);
}
if desc.variable_labels.is_empty() {
return Ok(desc.const_label_pairs.clone());
}
let mut label_pairs = Vec::with_capacity(total_len);
for (i, n) in desc.variable_labels.iter().enumerate() {
let mut label_pair = proto::LabelPair::default();
label_pair.set_name(n.clone());
label_pair.set_value(label_values[i].to_owned());
label_pairs.push(label_pair);
}
for label_pair in &desc.const_label_pairs {
label_pairs.push(label_pair.clone());
}
label_pairs.sort();
Ok(label_pairs)
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use measured::{label::StaticLabelSet, FixedCardinalityLabel};
use prometheus::{proto, Opts};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::{Distribution, Zipf};
use crate::HyperLogLogVec;
#[derive(FixedCardinalityLabel, Clone, Copy)]
#[label(singleton = "x")]
enum Label {
A,
B,
fn collect(hll: &HyperLogLogVec<32>) -> Vec<proto::Metric> {
let mut metrics = vec![];
hll.core
.children
.read()
.unwrap()
.values()
.for_each(|c| c.core.collect_into(&mut metrics));
metrics
}
fn collect(hll: &HyperLogLogVec<StaticLabelSet<Label>, 32>) -> ([u8; 32], [u8; 32]) {
// cannot go through the `hll.collect_family_into` interface yet...
// need to see if I can fix the conflicting impls problem in measured.
(
hll.get_metric(hll.with_labels(Label::A)).take_sample(),
hll.get_metric(hll.with_labels(Label::B)).take_sample(),
)
}
fn get_cardinality(samples: &[[u8; 32]]) -> f64 {
fn get_cardinality(metrics: &[proto::Metric], filter: impl Fn(&proto::Metric) -> bool) -> f64 {
let mut buckets = [0.0; 32];
for &sample in samples {
for (i, m) in sample.into_iter().enumerate() {
buckets[i] = f64::max(buckets[i], m as f64);
for metric in metrics.chunks_exact(32) {
if filter(&metric[0]) {
for (i, m) in metric.iter().enumerate() {
buckets[i] = f64::max(buckets[i], m.get_gauge().get_value());
}
}
}
@@ -238,7 +437,7 @@ mod tests {
}
fn test_cardinality(n: usize, dist: impl Distribution<f64>) -> ([usize; 3], [f64; 3]) {
let hll = HyperLogLogVec::<StaticLabelSet<Label>, 32>::new();
let hll = HyperLogLogVec::<32>::new(Opts::new("foo", "bar"), &["x"]).unwrap();
let mut iter = StdRng::seed_from_u64(0x2024_0112).sample_iter(dist);
let mut set_a = HashSet::new();
@@ -246,20 +445,18 @@ mod tests {
for x in iter.by_ref().take(n) {
set_a.insert(x.to_bits());
hll.get_metric(hll.with_labels(Label::A))
.measure(&x.to_bits());
hll.with_label_values(&["a"]).measure(&x.to_bits());
}
for x in iter.by_ref().take(n) {
set_b.insert(x.to_bits());
hll.get_metric(hll.with_labels(Label::B))
.measure(&x.to_bits());
hll.with_label_values(&["b"]).measure(&x.to_bits());
}
let merge = &set_a | &set_b;
let (a, b) = collect(&hll);
let len = get_cardinality(&[a, b]);
let len_a = get_cardinality(&[a]);
let len_b = get_cardinality(&[b]);
let metrics = collect(&hll);
let len = get_cardinality(&metrics, |_| true);
let len_a = get_cardinality(&metrics, |l| l.get_label()[0].get_value() == "a");
let len_b = get_cardinality(&metrics, |l| l.get_label()[0].get_value() == "b");
([merge.len(), set_a.len(), set_b.len()], [len, len_a, len_b])
}

View File

@@ -5,7 +5,7 @@
#![deny(clippy::undocumented_unsafe_blocks)]
use measured::{
label::{LabelGroupSet, LabelGroupVisitor, LabelName, NoLabels},
label::{LabelGroupVisitor, LabelName, NoLabels},
metric::{
counter::CounterState,
gauge::GaugeState,
@@ -40,7 +40,7 @@ pub mod launch_timestamp;
mod wrappers;
pub use wrappers::{CountedReader, CountedWriter};
mod hll;
pub use hll::{HyperLogLog, HyperLogLogState, HyperLogLogVec};
pub use hll::{HyperLogLog, HyperLogLogVec};
#[cfg(target_os = "linux")]
pub mod more_process_metrics;
@@ -421,171 +421,3 @@ pub type IntCounterPair = GenericCounterPair<AtomicU64>;
/// A guard for [`IntCounterPair`] that will decrement the gauge on drop
pub type IntCounterPairGuard = GenericCounterPairGuard<AtomicU64>;
pub trait CounterPairAssoc {
const INC_NAME: &'static MetricName;
const DEC_NAME: &'static MetricName;
const INC_HELP: &'static str;
const DEC_HELP: &'static str;
type LabelGroupSet: LabelGroupSet;
}
pub struct CounterPairVec<A: CounterPairAssoc> {
vec: measured::metric::MetricVec<MeasuredCounterPairState, A::LabelGroupSet>,
}
impl<A: CounterPairAssoc> Default for CounterPairVec<A>
where
A::LabelGroupSet: Default,
{
fn default() -> Self {
Self {
vec: Default::default(),
}
}
}
impl<A: CounterPairAssoc> CounterPairVec<A> {
pub fn guard(
&self,
labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>,
) -> MeasuredCounterPairGuard<'_, A> {
let id = self.vec.with_labels(labels);
self.vec.get_metric(id).inc.inc();
MeasuredCounterPairGuard { vec: &self.vec, id }
}
pub fn inc(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>) {
let id = self.vec.with_labels(labels);
self.vec.get_metric(id).inc.inc();
}
pub fn dec(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>) {
let id = self.vec.with_labels(labels);
self.vec.get_metric(id).dec.inc();
}
pub fn remove_metric(
&self,
labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>,
) -> Option<MeasuredCounterPairState> {
let id = self.vec.with_labels(labels);
self.vec.remove_metric(id)
}
}
impl<T, A> ::measured::metric::group::MetricGroup<T> for CounterPairVec<A>
where
T: ::measured::metric::group::Encoding,
A: CounterPairAssoc,
::measured::metric::counter::CounterState: ::measured::metric::MetricEncoding<T>,
{
fn collect_group_into(&self, enc: &mut T) -> Result<(), T::Err> {
// write decrement first to avoid a race condition where inc - dec < 0
T::write_help(enc, A::DEC_NAME, A::DEC_HELP)?;
self.vec
.collect_family_into(A::DEC_NAME, &mut Dec(&mut *enc))?;
T::write_help(enc, A::INC_NAME, A::INC_HELP)?;
self.vec
.collect_family_into(A::INC_NAME, &mut Inc(&mut *enc))?;
Ok(())
}
}
#[derive(MetricGroup, Default)]
pub struct MeasuredCounterPairState {
pub inc: CounterState,
pub dec: CounterState,
}
impl measured::metric::MetricType for MeasuredCounterPairState {
type Metadata = ();
}
pub struct MeasuredCounterPairGuard<'a, A: CounterPairAssoc> {
vec: &'a measured::metric::MetricVec<MeasuredCounterPairState, A::LabelGroupSet>,
id: measured::metric::LabelId<A::LabelGroupSet>,
}
impl<A: CounterPairAssoc> Drop for MeasuredCounterPairGuard<'_, A> {
fn drop(&mut self) {
self.vec.get_metric(self.id).dec.inc();
}
}
/// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the inc counter to the inner encoder.
struct Inc<T>(T);
/// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the dec counter to the inner encoder.
struct Dec<T>(T);
impl<T: Encoding> Encoding for Inc<T> {
type Err = T::Err;
fn write_help(&mut self, name: impl MetricNameEncoder, help: &str) -> Result<(), Self::Err> {
self.0.write_help(name, help)
}
fn write_metric_value(
&mut self,
name: impl MetricNameEncoder,
labels: impl LabelGroup,
value: MetricValue,
) -> Result<(), Self::Err> {
self.0.write_metric_value(name, labels, value)
}
}
impl<T: Encoding> MetricEncoding<Inc<T>> for MeasuredCounterPairState
where
CounterState: MetricEncoding<T>,
{
fn write_type(name: impl MetricNameEncoder, enc: &mut Inc<T>) -> Result<(), T::Err> {
CounterState::write_type(name, &mut enc.0)
}
fn collect_into(
&self,
metadata: &(),
labels: impl LabelGroup,
name: impl MetricNameEncoder,
enc: &mut Inc<T>,
) -> Result<(), T::Err> {
self.inc.collect_into(metadata, labels, name, &mut enc.0)
}
}
impl<T: Encoding> Encoding for Dec<T> {
type Err = T::Err;
fn write_help(&mut self, name: impl MetricNameEncoder, help: &str) -> Result<(), Self::Err> {
self.0.write_help(name, help)
}
fn write_metric_value(
&mut self,
name: impl MetricNameEncoder,
labels: impl LabelGroup,
value: MetricValue,
) -> Result<(), Self::Err> {
self.0.write_metric_value(name, labels, value)
}
}
/// Write the dec counter to the encoder
impl<T: Encoding> MetricEncoding<Dec<T>> for MeasuredCounterPairState
where
CounterState: MetricEncoding<T>,
{
fn write_type(name: impl MetricNameEncoder, enc: &mut Dec<T>) -> Result<(), T::Err> {
CounterState::write_type(name, &mut enc.0)
}
fn collect_into(
&self,
metadata: &(),
labels: impl LabelGroup,
name: impl MetricNameEncoder,
enc: &mut Dec<T>,
) -> Result<(), T::Err> {
self.dec.collect_into(metadata, labels, name, &mut enc.0)
}
}

View File

@@ -747,18 +747,10 @@ pub struct TimelineGcRequest {
pub gc_horizon: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalRedoManagerProcessStatus {
pub pid: u32,
/// The strum-generated `into::<&'static str>()` for `pageserver::walredo::ProcessKind`.
/// `ProcessKind` are a transitory thing, so, they have no enum representation in `pageserver_api`.
pub kind: Cow<'static, str>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalRedoManagerStatus {
pub last_redo_at: Option<chrono::DateTime<chrono::Utc>>,
pub process: Option<WalRedoManagerProcessStatus>,
pub pid: Option<u32>,
}
/// The progress of a secondary tenant is mostly useful when doing a long running download: e.g. initiating
@@ -849,21 +841,21 @@ impl TryFrom<u8> for PagestreamBeMessageTag {
#[derive(Debug, PartialEq, Eq)]
pub struct PagestreamExistsRequest {
pub latest: bool,
pub horizon: Lsn,
pub lsn: Lsn,
pub rel: RelTag,
}
#[derive(Debug, PartialEq, Eq)]
pub struct PagestreamNblocksRequest {
pub latest: bool,
pub horizon: Lsn,
pub lsn: Lsn,
pub rel: RelTag,
}
#[derive(Debug, PartialEq, Eq)]
pub struct PagestreamGetPageRequest {
pub latest: bool,
pub horizon: Lsn,
pub lsn: Lsn,
pub rel: RelTag,
pub blkno: u32,
@@ -871,14 +863,14 @@ pub struct PagestreamGetPageRequest {
#[derive(Debug, PartialEq, Eq)]
pub struct PagestreamDbSizeRequest {
pub latest: bool,
pub horizon: Lsn,
pub lsn: Lsn,
pub dbnode: u32,
}
#[derive(Debug, PartialEq, Eq)]
pub struct PagestreamGetSlruSegmentRequest {
pub latest: bool,
pub horizon: Lsn,
pub lsn: Lsn,
pub kind: u8,
pub segno: u32,
@@ -931,8 +923,8 @@ impl PagestreamFeMessage {
match self {
Self::Exists(req) => {
bytes.put_u8(0);
bytes.put_u8(u8::from(req.latest));
bytes.put_u8(10);
bytes.put_u64(req.horizon.0);
bytes.put_u64(req.lsn.0);
bytes.put_u32(req.rel.spcnode);
bytes.put_u32(req.rel.dbnode);
@@ -941,8 +933,8 @@ impl PagestreamFeMessage {
}
Self::Nblocks(req) => {
bytes.put_u8(1);
bytes.put_u8(u8::from(req.latest));
bytes.put_u8(11);
bytes.put_u64(req.horizon.0);
bytes.put_u64(req.lsn.0);
bytes.put_u32(req.rel.spcnode);
bytes.put_u32(req.rel.dbnode);
@@ -951,8 +943,8 @@ impl PagestreamFeMessage {
}
Self::GetPage(req) => {
bytes.put_u8(2);
bytes.put_u8(u8::from(req.latest));
bytes.put_u8(12);
bytes.put_u64(req.horizon.0);
bytes.put_u64(req.lsn.0);
bytes.put_u32(req.rel.spcnode);
bytes.put_u32(req.rel.dbnode);
@@ -962,15 +954,15 @@ impl PagestreamFeMessage {
}
Self::DbSize(req) => {
bytes.put_u8(3);
bytes.put_u8(u8::from(req.latest));
bytes.put_u8(13);
bytes.put_u64(req.horizon.0);
bytes.put_u64(req.lsn.0);
bytes.put_u32(req.dbnode);
}
Self::GetSlruSegment(req) => {
bytes.put_u8(4);
bytes.put_u8(u8::from(req.latest));
bytes.put_u8(14);
bytes.put_u64(req.horizon.0);
bytes.put_u64(req.lsn.0);
bytes.put_u8(req.kind);
bytes.put_u32(req.segno);
@@ -987,11 +979,32 @@ impl PagestreamFeMessage {
//
// TODO: consider using protobuf or serde bincode for less error prone
// serialization.
let msg_tag = body.read_u8()?;
let mut msg_tag = body.read_u8()?;
//
// Old version of protocol use commands with tags started with 0 and containing `latest` flag.
// New version of protocol shift command tags by 10 and pass LSN range instead of `latest` flag.
// Server should be able to handle both protocol version. As far as we are not passing no=w,
// protocol version from client to server, we make a decision based on tag range.
// So this code actually provides backward compatibility.
//
let horizon = if msg_tag >= 10 {
// new protocol
msg_tag -= 10; // commands tags in new protocol starts with 10
Lsn::from(body.read_u64::<BigEndian>()?)
} else {
// old_protocol
let latest = body.read_u8()? != 0;
if latest {
Lsn::MAX // get latest version
} else {
Lsn::INVALID // get version on specified LSN
}
};
let lsn = Lsn::from(body.read_u64::<BigEndian>()?);
match msg_tag {
0 => Ok(PagestreamFeMessage::Exists(PagestreamExistsRequest {
latest: body.read_u8()? != 0,
lsn: Lsn::from(body.read_u64::<BigEndian>()?),
horizon,
lsn,
rel: RelTag {
spcnode: body.read_u32::<BigEndian>()?,
dbnode: body.read_u32::<BigEndian>()?,
@@ -1000,8 +1013,8 @@ impl PagestreamFeMessage {
},
})),
1 => Ok(PagestreamFeMessage::Nblocks(PagestreamNblocksRequest {
latest: body.read_u8()? != 0,
lsn: Lsn::from(body.read_u64::<BigEndian>()?),
horizon,
lsn,
rel: RelTag {
spcnode: body.read_u32::<BigEndian>()?,
dbnode: body.read_u32::<BigEndian>()?,
@@ -1010,8 +1023,8 @@ impl PagestreamFeMessage {
},
})),
2 => Ok(PagestreamFeMessage::GetPage(PagestreamGetPageRequest {
latest: body.read_u8()? != 0,
lsn: Lsn::from(body.read_u64::<BigEndian>()?),
horizon,
lsn,
rel: RelTag {
spcnode: body.read_u32::<BigEndian>()?,
dbnode: body.read_u32::<BigEndian>()?,
@@ -1021,14 +1034,14 @@ impl PagestreamFeMessage {
blkno: body.read_u32::<BigEndian>()?,
})),
3 => Ok(PagestreamFeMessage::DbSize(PagestreamDbSizeRequest {
latest: body.read_u8()? != 0,
lsn: Lsn::from(body.read_u64::<BigEndian>()?),
horizon,
lsn,
dbnode: body.read_u32::<BigEndian>()?,
})),
4 => Ok(PagestreamFeMessage::GetSlruSegment(
PagestreamGetSlruSegmentRequest {
latest: body.read_u8()? != 0,
lsn: Lsn::from(body.read_u64::<BigEndian>()?),
horizon,
lsn,
kind: body.read_u8()?,
segno: body.read_u32::<BigEndian>()?,
},
@@ -1156,7 +1169,7 @@ mod tests {
// Test serialization/deserialization of PagestreamFeMessage
let messages = vec![
PagestreamFeMessage::Exists(PagestreamExistsRequest {
latest: true,
horizon: Lsn::MAX,
lsn: Lsn(4),
rel: RelTag {
forknum: 1,
@@ -1166,7 +1179,7 @@ mod tests {
},
}),
PagestreamFeMessage::Nblocks(PagestreamNblocksRequest {
latest: false,
horizon: Lsn::INVALID,
lsn: Lsn(4),
rel: RelTag {
forknum: 1,
@@ -1176,8 +1189,8 @@ mod tests {
},
}),
PagestreamFeMessage::GetPage(PagestreamGetPageRequest {
latest: true,
lsn: Lsn(4),
horizon: Lsn::MAX,
lsn: Lsn::INVALID,
rel: RelTag {
forknum: 1,
spcnode: 2,
@@ -1187,7 +1200,7 @@ mod tests {
blkno: 7,
}),
PagestreamFeMessage::DbSize(PagestreamDbSizeRequest {
latest: true,
horizon: Lsn::MAX,
lsn: Lsn(4),
dbnode: 7,
}),

View File

@@ -8,89 +8,12 @@ use hex::FromHex;
use serde::{Deserialize, Serialize};
use utils::id::TenantId;
/// See docs/rfcs/031-sharding-static.md for an overview of sharding.
///
/// This module contains a variety of types used to represent the concept of sharding
/// a Neon tenant across multiple physical shards. Since there are quite a few of these,
/// we provide an summary here.
///
/// Types used to describe shards:
/// - [`ShardCount`] describes how many shards make up a tenant, plus the magic `unsharded` value
/// which identifies a tenant which is not shard-aware. This means its storage paths do not include
/// a shard suffix.
/// - [`ShardNumber`] is simply the zero-based index of a shard within a tenant.
/// - [`ShardIndex`] is the 2-tuple of `ShardCount` and `ShardNumber`, it's just like a `TenantShardId`
/// without the tenant ID. This is useful for things that are implicitly scoped to a particular
/// tenant, such as layer files.
/// - [`ShardIdentity`]` is the full description of a particular shard's parameters, in sufficient
/// detail to convert a [`Key`] to a [`ShardNumber`] when deciding where to write/read.
/// - The [`ShardSlug`] is a terse formatter for ShardCount and ShardNumber, written as
/// four hex digits. An unsharded tenant is `0000`.
/// - [`TenantShardId`] is the unique ID of a particular shard within a particular tenant
///
/// Types used to describe the parameters for data distribution in a sharded tenant:
/// - [`ShardStripeSize`] controls how long contiguous runs of [`Key`]s (stripes) are when distributed across
/// multiple shards. Its value is given in 8kiB pages.
/// - [`ShardLayout`] describes the data distribution scheme, and at time of writing is
/// always zero: this is provided for future upgrades that might introduce different
/// data distribution schemes.
///
/// Examples:
/// - A legacy unsharded tenant has one shard with ShardCount(0), ShardNumber(0), and its slug is 0000
/// - A single sharded tenant has one shard with ShardCount(1), ShardNumber(0), and its slug is 0001
/// - In a tenant with 4 shards, each shard has ShardCount(N), ShardNumber(i) where i in 0..N-1 (inclusive),
/// and their slugs are 0004, 0104, 0204, and 0304.
#[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Debug, Hash)]
pub struct ShardNumber(pub u8);
#[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Debug, Hash)]
pub struct ShardCount(u8);
/// Combination of ShardNumber and ShardCount. For use within the context of a particular tenant,
/// when we need to know which shard we're dealing with, but do not need to know the full
/// ShardIdentity (because we won't be doing any page->shard mapping), and do not need to know
/// the fully qualified TenantShardId.
#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct ShardIndex {
pub shard_number: ShardNumber,
pub shard_count: ShardCount,
}
/// The ShardIdentity contains enough information to map a [`Key`] to a [`ShardNumber`],
/// and to check whether that [`ShardNumber`] is the same as the current shard.
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
pub struct ShardIdentity {
pub number: ShardNumber,
pub count: ShardCount,
pub stripe_size: ShardStripeSize,
layout: ShardLayout,
}
/// Formatting helper, for generating the `shard_id` label in traces.
struct ShardSlug<'a>(&'a TenantShardId);
/// TenantShardId globally identifies a particular shard in a particular tenant.
///
/// These are written as `<TenantId>-<ShardSlug>`, for example:
/// # The second shard in a two-shard tenant
/// 072f1291a5310026820b2fe4b2968934-0102
///
/// If the `ShardCount` is _unsharded_, the `TenantShardId` is written without
/// a shard suffix and is equivalent to the encoding of a `TenantId`: this enables
/// an unsharded [`TenantShardId`] to be used interchangably with a [`TenantId`].
///
/// The human-readable encoding of an unsharded TenantShardId, such as used in API URLs,
/// is both forward and backward compatible with TenantId: a legacy TenantId can be
/// decoded as a TenantShardId, and when re-encoded it will be parseable
/// as a TenantId.
#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct TenantShardId {
pub tenant_id: TenantId,
pub shard_number: ShardNumber,
pub shard_count: ShardCount,
}
impl ShardCount {
pub const MAX: Self = Self(u8::MAX);
@@ -115,7 +38,6 @@ impl ShardCount {
self.0
}
///
pub fn is_unsharded(&self) -> bool {
self.0 == 0
}
@@ -131,6 +53,33 @@ impl ShardNumber {
pub const MAX: Self = Self(u8::MAX);
}
/// TenantShardId identify the units of work for the Pageserver.
///
/// These are written as `<tenant_id>-<shard number><shard-count>`, for example:
///
/// # The second shard in a two-shard tenant
/// 072f1291a5310026820b2fe4b2968934-0102
///
/// Historically, tenants could not have multiple shards, and were identified
/// by TenantId. To support this, TenantShardId has a special legacy
/// mode where `shard_count` is equal to zero: this represents a single-sharded
/// tenant which should be written as a TenantId with no suffix.
///
/// The human-readable encoding of TenantShardId, such as used in API URLs,
/// is both forward and backward compatible: a legacy TenantId can be
/// decoded as a TenantShardId, and when re-encoded it will be parseable
/// as a TenantId.
///
/// Note that the binary encoding is _not_ backward compatible, because
/// at the time sharding is introduced, there are no existing binary structures
/// containing TenantId that we need to handle.
#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct TenantShardId {
pub tenant_id: TenantId,
pub shard_number: ShardNumber,
pub shard_count: ShardCount,
}
impl TenantShardId {
pub fn unsharded(tenant_id: TenantId) -> Self {
Self {
@@ -162,13 +111,10 @@ impl TenantShardId {
}
/// Convenience for code that has special behavior on the 0th shard.
pub fn is_shard_zero(&self) -> bool {
pub fn is_zero(&self) -> bool {
self.shard_number == ShardNumber(0)
}
/// The "unsharded" value is distinct from simply having a single shard: it represents
/// a tenant which is not shard-aware at all, and whose storage paths will not include
/// a shard suffix.
pub fn is_unsharded(&self) -> bool {
self.shard_number == ShardNumber(0) && self.shard_count.is_unsharded()
}
@@ -204,6 +150,9 @@ impl TenantShardId {
}
}
/// Formatting helper
struct ShardSlug<'a>(&'a TenantShardId);
impl<'a> std::fmt::Display for ShardSlug<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
@@ -273,6 +222,16 @@ impl From<[u8; 18]> for TenantShardId {
}
}
/// For use within the context of a particular tenant, when we need to know which
/// shard we're dealing with, but do not need to know the full ShardIdentity (because
/// we won't be doing any page->shard mapping), and do not need to know the fully qualified
/// TenantShardId.
#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct ShardIndex {
pub shard_number: ShardNumber,
pub shard_count: ShardCount,
}
impl ShardIndex {
pub fn new(number: ShardNumber, count: ShardCount) -> Self {
Self {
@@ -287,9 +246,6 @@ impl ShardIndex {
}
}
/// The "unsharded" value is distinct from simply having a single shard: it represents
/// a tenant which is not shard-aware at all, and whose storage paths will not include
/// a shard suffix.
pub fn is_unsharded(&self) -> bool {
self.shard_number == ShardNumber(0) && self.shard_count == ShardCount(0)
}
@@ -357,8 +313,6 @@ impl Serialize for TenantShardId {
if serializer.is_human_readable() {
serializer.collect_str(self)
} else {
// Note: while human encoding of [`TenantShardId`] is backward and forward
// compatible, this binary encoding is not.
let mut packed: [u8; 18] = [0; 18];
packed[0..16].clone_from_slice(&self.tenant_id.as_arr());
packed[16] = self.shard_number.0;
@@ -436,6 +390,16 @@ const LAYOUT_BROKEN: ShardLayout = ShardLayout(255);
/// Default stripe size in pages: 256MiB divided by 8kiB page size.
const DEFAULT_STRIPE_SIZE: ShardStripeSize = ShardStripeSize(256 * 1024 / 8);
/// The ShardIdentity contains the information needed for one member of map
/// to resolve a key to a shard, and then check whether that shard is ==self.
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
pub struct ShardIdentity {
pub number: ShardNumber,
pub count: ShardCount,
pub stripe_size: ShardStripeSize,
layout: ShardLayout,
}
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum ShardConfigError {
#[error("Invalid shard count")]
@@ -475,9 +439,6 @@ impl ShardIdentity {
}
}
/// The "unsharded" value is distinct from simply having a single shard: it represents
/// a tenant which is not shard-aware at all, and whose storage paths will not include
/// a shard suffix.
pub fn is_unsharded(&self) -> bool {
self.number == ShardNumber(0) && self.count == ShardCount(0)
}
@@ -526,8 +487,6 @@ impl ShardIdentity {
}
/// Return true if the key should be ingested by this shard
///
/// Shards must ingest _at least_ keys which return true from this check.
pub fn is_key_local(&self, key: &Key) -> bool {
assert!(!self.is_broken());
if self.count < ShardCount(2) || (key_is_shard0(key) && self.number == ShardNumber(0)) {
@@ -538,9 +497,7 @@ impl ShardIdentity {
}
/// Return true if the key should be discarded if found in this shard's
/// data store, e.g. during compaction after a split.
///
/// Shards _may_ drop keys which return false here, but are not obliged to.
/// data store, e.g. during compaction after a split
pub fn is_key_disposable(&self, key: &Key) -> bool {
if key_is_shard0(key) {
// Q: Why can't we dispose of shard0 content if we're not shard 0?
@@ -566,7 +523,7 @@ impl ShardIdentity {
/// Convenience for checking if this identity is the 0th shard in a tenant,
/// for special cases on shard 0 such as ingesting relation sizes.
pub fn is_shard_zero(&self) -> bool {
pub fn is_zero(&self) -> bool {
self.number == ShardNumber(0)
}
}

View File

@@ -92,8 +92,6 @@ pub mod zstd;
pub mod env;
pub mod poison;
/// This is a shortcut to embed git sha into binaries and avoid copying the same build script to all packages
///
/// we have several cases:

View File

@@ -1,121 +0,0 @@
//! Protect a piece of state from reuse after it is left in an inconsistent state.
//!
//! # Example
//!
//! ```
//! # tokio_test::block_on(async {
//! use utils::poison::Poison;
//! use std::time::Duration;
//!
//! struct State {
//! clean: bool,
//! }
//! let state = tokio::sync::Mutex::new(Poison::new("mystate", State { clean: true }));
//!
//! let mut mutex_guard = state.lock().await;
//! let mut poison_guard = mutex_guard.check_and_arm()?;
//! let state = poison_guard.data_mut();
//! state.clean = false;
//! // If we get cancelled at this await point, subsequent check_and_arm() calls will fail.
//! tokio::time::sleep(Duration::from_secs(10)).await;
//! state.clean = true;
//! poison_guard.disarm();
//! # Ok::<(), utils::poison::Error>(())
//! # });
//! ```
use tracing::warn;
pub struct Poison<T> {
what: &'static str,
state: State,
data: T,
}
#[derive(Clone, Copy)]
enum State {
Clean,
Armed,
Poisoned { at: chrono::DateTime<chrono::Utc> },
}
impl<T> Poison<T> {
/// We log `what` `warning!` level if the [`Guard`] gets dropped without being [`Guard::disarm`]ed.
pub fn new(what: &'static str, data: T) -> Self {
Self {
what,
state: State::Clean,
data,
}
}
/// Check for poisoning and return a [`Guard`] that provides access to the wrapped state.
pub fn check_and_arm(&mut self) -> Result<Guard<T>, Error> {
match self.state {
State::Clean => {
self.state = State::Armed;
Ok(Guard(self))
}
State::Armed => unreachable!("transient state"),
State::Poisoned { at } => Err(Error::Poisoned {
what: self.what,
at,
}),
}
}
}
/// Use [`Self::data`] and [`Self::data_mut`] to access the wrapped state.
/// Once modifications are done, use [`Self::disarm`].
/// If [`Guard`] gets dropped instead of calling [`Self::disarm`], the state is poisoned
/// and subsequent calls to [`Poison::check_and_arm`] will fail with an error.
pub struct Guard<'a, T>(&'a mut Poison<T>);
impl<'a, T> Guard<'a, T> {
pub fn data(&self) -> &T {
&self.0.data
}
pub fn data_mut(&mut self) -> &mut T {
&mut self.0.data
}
pub fn disarm(self) {
match self.0.state {
State::Clean => unreachable!("we set it to Armed in check_and_arm()"),
State::Armed => {
self.0.state = State::Clean;
}
State::Poisoned { at } => {
unreachable!("we fail check_and_arm() if it's in that state: {at}")
}
}
}
}
impl<'a, T> Drop for Guard<'a, T> {
fn drop(&mut self) {
match self.0.state {
State::Clean => {
// set by disarm()
}
State::Armed => {
// still armed => poison it
let at = chrono::Utc::now();
self.0.state = State::Poisoned { at };
warn!(at=?at, "poisoning {}", self.0.what);
}
State::Poisoned { at } => {
unreachable!("we fail check_and_arm() if it's in that state: {at}")
}
}
}
}
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("poisoned at {at}: {what}")]
Poisoned {
what: &'static str,
at: chrono::DateTime<chrono::Utc>,
},
}

View File

@@ -27,50 +27,30 @@
//!
//! # Reference Numbers
//!
//! 2024-04-15 on i3en.3xlarge
//! 2024-04-04 on i3en.3xlarge
//!
//! ```text
//! async-short/1 time: [24.584 µs 24.737 µs 24.922 µs]
//! async-short/2 time: [33.479 µs 33.660 µs 33.888 µs]
//! async-short/4 time: [42.713 µs 43.046 µs 43.440 µs]
//! async-short/8 time: [71.814 µs 72.478 µs 73.240 µs]
//! async-short/16 time: [132.73 µs 134.45 µs 136.22 µs]
//! async-short/32 time: [258.31 µs 260.73 µs 263.27 µs]
//! async-short/64 time: [511.61 µs 514.44 µs 517.51 µs]
//! async-short/128 time: [992.64 µs 998.23 µs 1.0042 ms]
//! async-medium/1 time: [110.11 µs 110.50 µs 110.96 µs]
//! async-medium/2 time: [153.06 µs 153.85 µs 154.99 µs]
//! async-medium/4 time: [317.51 µs 319.92 µs 322.85 µs]
//! async-medium/8 time: [638.30 µs 644.68 µs 652.12 µs]
//! async-medium/16 time: [1.2651 ms 1.2773 ms 1.2914 ms]
//! async-medium/32 time: [2.5117 ms 2.5410 ms 2.5720 ms]
//! async-medium/64 time: [4.8088 ms 4.8555 ms 4.9047 ms]
//! async-medium/128 time: [8.8311 ms 8.9849 ms 9.1263 ms]
//! sync-short/1 time: [25.503 µs 25.626 µs 25.771 µs]
//! sync-short/2 time: [30.850 µs 31.013 µs 31.208 µs]
//! sync-short/4 time: [45.543 µs 45.856 µs 46.193 µs]
//! sync-short/8 time: [84.114 µs 84.639 µs 85.220 µs]
//! sync-short/16 time: [185.22 µs 186.15 µs 187.13 µs]
//! sync-short/32 time: [377.43 µs 378.87 µs 380.46 µs]
//! sync-short/64 time: [756.49 µs 759.04 µs 761.70 µs]
//! sync-short/128 time: [1.4825 ms 1.4874 ms 1.4923 ms]
//! sync-medium/1 time: [105.66 µs 106.01 µs 106.43 µs]
//! sync-medium/2 time: [153.10 µs 153.84 µs 154.72 µs]
//! sync-medium/4 time: [327.13 µs 329.44 µs 332.27 µs]
//! sync-medium/8 time: [654.26 µs 658.73 µs 663.63 µs]
//! sync-medium/16 time: [1.2682 ms 1.2748 ms 1.2816 ms]
//! sync-medium/32 time: [2.4456 ms 2.4595 ms 2.4731 ms]
//! sync-medium/64 time: [4.6523 ms 4.6890 ms 4.7256 ms]
//! sync-medium/128 time: [8.7215 ms 8.8323 ms 8.9344 ms]
//! short/1 time: [25.925 µs 26.060 µs 26.209 µs]
//! short/2 time: [31.277 µs 31.483 µs 31.722 µs]
//! short/4 time: [45.496 µs 45.831 µs 46.182 µs]
//! short/8 time: [84.298 µs 84.920 µs 85.566 µs]
//! short/16 time: [185.04 µs 186.41 µs 187.88 µs]
//! short/32 time: [385.01 µs 386.77 µs 388.70 µs]
//! short/64 time: [770.24 µs 773.04 µs 776.04 µs]
//! short/128 time: [1.5017 ms 1.5064 ms 1.5113 ms]
//! medium/1 time: [106.65 µs 107.20 µs 107.85 µs]
//! medium/2 time: [153.28 µs 154.24 µs 155.56 µs]
//! medium/4 time: [325.67 µs 327.01 µs 328.71 µs]
//! medium/8 time: [646.82 µs 650.17 µs 653.91 µs]
//! medium/16 time: [1.2645 ms 1.2701 ms 1.2762 ms]
//! medium/32 time: [2.4409 ms 2.4550 ms 2.4692 ms]
//! medium/64 time: [4.6814 ms 4.7114 ms 4.7408 ms]
//! medium/128 time: [8.7790 ms 8.9037 ms 9.0282 ms]
//! ```
use bytes::{Buf, Bytes};
use criterion::{BenchmarkId, Criterion};
use pageserver::{
config::PageServerConf,
walrecord::NeonWalRecord,
walredo::{PostgresRedoManager, ProcessKind},
};
use pageserver::{config::PageServerConf, walrecord::NeonWalRecord, walredo::PostgresRedoManager};
use pageserver_api::{key::Key, shard::TenantShardId};
use std::{
sync::Arc,
@@ -80,39 +60,33 @@ use tokio::{sync::Barrier, task::JoinSet};
use utils::{id::TenantId, lsn::Lsn};
fn bench(c: &mut Criterion) {
for process_kind in &[ProcessKind::Async, ProcessKind::Sync] {
{
let nclients = [1, 2, 4, 8, 16, 32, 64, 128];
for nclients in nclients {
let mut group = c.benchmark_group(format!("{process_kind}-short"));
group.bench_with_input(
BenchmarkId::from_parameter(nclients),
&nclients,
|b, nclients| {
let redo_work = Arc::new(Request::short_input());
b.iter_custom(|iters| {
bench_impl(*process_kind, Arc::clone(&redo_work), iters, *nclients)
});
},
);
}
{
let nclients = [1, 2, 4, 8, 16, 32, 64, 128];
for nclients in nclients {
let mut group = c.benchmark_group("short");
group.bench_with_input(
BenchmarkId::from_parameter(nclients),
&nclients,
|b, nclients| {
let redo_work = Arc::new(Request::short_input());
b.iter_custom(|iters| bench_impl(Arc::clone(&redo_work), iters, *nclients));
},
);
}
}
{
let nclients = [1, 2, 4, 8, 16, 32, 64, 128];
for nclients in nclients {
let mut group = c.benchmark_group(format!("{process_kind}-medium"));
group.bench_with_input(
BenchmarkId::from_parameter(nclients),
&nclients,
|b, nclients| {
let redo_work = Arc::new(Request::medium_input());
b.iter_custom(|iters| {
bench_impl(*process_kind, Arc::clone(&redo_work), iters, *nclients)
});
},
);
}
{
let nclients = [1, 2, 4, 8, 16, 32, 64, 128];
for nclients in nclients {
let mut group = c.benchmark_group("medium");
group.bench_with_input(
BenchmarkId::from_parameter(nclients),
&nclients,
|b, nclients| {
let redo_work = Arc::new(Request::medium_input());
b.iter_custom(|iters| bench_impl(Arc::clone(&redo_work), iters, *nclients));
},
);
}
}
}
@@ -120,16 +94,10 @@ criterion::criterion_group!(benches, bench);
criterion::criterion_main!(benches);
// Returns the sum of each client's wall-clock time spent executing their share of the n_redos.
fn bench_impl(
process_kind: ProcessKind,
redo_work: Arc<Request>,
n_redos: u64,
nclients: u64,
) -> Duration {
fn bench_impl(redo_work: Arc<Request>, n_redos: u64, nclients: u64) -> Duration {
let repo_dir = camino_tempfile::tempdir_in(env!("CARGO_TARGET_TMPDIR")).unwrap();
let mut conf = PageServerConf::dummy_conf(repo_dir.path().to_path_buf());
conf.walredo_process_kind = process_kind;
let conf = PageServerConf::dummy_conf(repo_dir.path().to_path_buf());
let conf = Box::leak(Box::new(conf));
let tenant_shard_id = TenantShardId::unsharded(TenantId::generate());
@@ -145,40 +113,25 @@ fn bench_impl(
let manager = PostgresRedoManager::new(conf, tenant_shard_id);
let manager = Arc::new(manager);
// divide the amount of work equally among the clients.
let nredos_per_client = n_redos / nclients;
for _ in 0..nclients {
rt.block_on(async {
tasks.spawn(client(
Arc::clone(&manager),
Arc::clone(&start),
Arc::clone(&redo_work),
nredos_per_client,
// divide the amount of work equally among the clients
n_redos / nclients,
))
});
}
let elapsed = rt.block_on(async move {
let mut total_wallclock_time = Duration::ZERO;
rt.block_on(async move {
let mut total_wallclock_time = std::time::Duration::from_millis(0);
while let Some(res) = tasks.join_next().await {
total_wallclock_time += res.unwrap();
}
total_wallclock_time
});
// consistency check to ensure process kind setting worked
if nredos_per_client > 0 {
assert_eq!(
manager
.status()
.process
.map(|p| p.kind)
.expect("the benchmark work causes a walredo process to be spawned"),
std::borrow::Cow::Borrowed(process_kind.into())
);
}
elapsed
})
}
async fn client(

View File

@@ -312,7 +312,11 @@ async fn main_impl(
let (rel_tag, block_no) =
key_to_rel_block(key).expect("we filter non-rel-block keys out above");
PagestreamGetPageRequest {
latest: rng.gen_bool(args.req_latest_probability),
horizon: if rng.gen_bool(args.req_latest_probability) {
Lsn::MAX
} else {
r.timeline_lsn
},
lsn: r.timeline_lsn,
rel: rel_tag,
blkno: block_no,

View File

@@ -361,9 +361,10 @@ where
/// Add contents of relfilenode `src`, naming it as `dst`.
async fn add_rel(&mut self, src: RelTag, dst: RelTag) -> anyhow::Result<()> {
let horizon = self.lsn; // we do not need latest version
let nblocks = self
.timeline
.get_rel_size(src, Version::Lsn(self.lsn), false, self.ctx)
.get_rel_size(src, Version::Lsn(self.lsn), horizon, self.ctx)
.await?;
// If the relation is empty, create an empty file
@@ -384,7 +385,7 @@ where
for blknum in startblk..endblk {
let img = self
.timeline
.get_rel_page_at_lsn(src, blknum, Version::Lsn(self.lsn), false, self.ctx)
.get_rel_page_at_lsn(src, blknum, Version::Lsn(self.lsn), horizon, self.ctx)
.await?;
segment_data.extend_from_slice(&img[..]);
}

View File

@@ -285,7 +285,6 @@ fn start_pageserver(
))
.unwrap();
pageserver::preinitialize_metrics();
pageserver::metrics::wal_redo::set_process_kind_metric(conf.walredo_process_kind);
// If any failpoints were set from FAILPOINTS environment variable,
// print them to the log for debugging purposes

View File

@@ -97,8 +97,6 @@ pub mod defaults {
pub const DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB: usize = 0;
pub const DEFAULT_WALREDO_PROCESS_KIND: &str = "sync";
///
/// Default built-in configuration file.
///
@@ -142,8 +140,6 @@ pub mod defaults {
#validate_vectored_get = '{DEFAULT_VALIDATE_VECTORED_GET}'
#walredo_process_kind = '{DEFAULT_WALREDO_PROCESS_KIND}'
[tenant_config]
#checkpoint_distance = {DEFAULT_CHECKPOINT_DISTANCE} # in bytes
#checkpoint_timeout = {DEFAULT_CHECKPOINT_TIMEOUT}
@@ -294,8 +290,6 @@ pub struct PageServerConf {
///
/// Setting this to zero disables limits on total ephemeral layer size.
pub ephemeral_bytes_per_memory_kb: usize,
pub walredo_process_kind: crate::walredo::ProcessKind,
}
/// We do not want to store this in a PageServerConf because the latter may be logged
@@ -419,8 +413,6 @@ struct PageServerConfigBuilder {
validate_vectored_get: BuilderValue<bool>,
ephemeral_bytes_per_memory_kb: BuilderValue<usize>,
walredo_process_kind: BuilderValue<crate::walredo::ProcessKind>,
}
impl PageServerConfigBuilder {
@@ -508,8 +500,6 @@ impl PageServerConfigBuilder {
)),
validate_vectored_get: Set(DEFAULT_VALIDATE_VECTORED_GET),
ephemeral_bytes_per_memory_kb: Set(DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB),
walredo_process_kind: Set(DEFAULT_WALREDO_PROCESS_KIND.parse().unwrap()),
}
}
}
@@ -693,10 +683,6 @@ impl PageServerConfigBuilder {
self.ephemeral_bytes_per_memory_kb = BuilderValue::Set(value);
}
pub fn get_walredo_process_kind(&mut self, value: crate::walredo::ProcessKind) {
self.walredo_process_kind = BuilderValue::Set(value);
}
pub fn build(self) -> anyhow::Result<PageServerConf> {
let default = Self::default_values();
@@ -753,7 +739,6 @@ impl PageServerConfigBuilder {
max_vectored_read_bytes,
validate_vectored_get,
ephemeral_bytes_per_memory_kb,
walredo_process_kind,
}
CUSTOM LOGIC
{
@@ -1047,9 +1032,6 @@ impl PageServerConf {
"ephemeral_bytes_per_memory_kb" => {
builder.get_ephemeral_bytes_per_memory_kb(parse_toml_u64("ephemeral_bytes_per_memory_kb", item)? as usize)
}
"walredo_process_kind" => {
builder.get_walredo_process_kind(parse_toml_from_str("walredo_process_kind", item)?)
}
_ => bail!("unrecognized pageserver option '{key}'"),
}
}
@@ -1132,7 +1114,6 @@ impl PageServerConf {
),
validate_vectored_get: defaults::DEFAULT_VALIDATE_VECTORED_GET,
ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB,
walredo_process_kind: defaults::DEFAULT_WALREDO_PROCESS_KIND.parse().unwrap(),
}
}
}
@@ -1370,8 +1351,7 @@ background_task_maximum_delay = '334 s'
.expect("Invalid default constant")
),
validate_vectored_get: defaults::DEFAULT_VALIDATE_VECTORED_GET,
ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB,
walredo_process_kind: defaults::DEFAULT_WALREDO_PROCESS_KIND.parse().unwrap(),
ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB
},
"Correct defaults should be used when no config values are provided"
);
@@ -1443,8 +1423,7 @@ background_task_maximum_delay = '334 s'
.expect("Invalid default constant")
),
validate_vectored_get: defaults::DEFAULT_VALIDATE_VECTORED_GET,
ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB,
walredo_process_kind: defaults::DEFAULT_WALREDO_PROCESS_KIND.parse().unwrap(),
ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB
},
"Should be able to parse all basic config values correctly"
);

View File

@@ -304,7 +304,7 @@ async fn calculate_synthetic_size_worker(
continue;
}
if !tenant_shard_id.is_shard_zero() {
if !tenant_shard_id.is_zero() {
// We only send consumption metrics from shard 0, so don't waste time calculating
// synthetic size on other shards.
continue;

View File

@@ -199,7 +199,7 @@ pub(super) async fn collect_all_metrics(
};
let tenants = futures::stream::iter(tenants).filter_map(|(id, state, _)| async move {
if state != TenantState::Active || !id.is_shard_zero() {
if state != TenantState::Active || !id.is_zero() {
None
} else {
tenant_manager

View File

@@ -58,6 +58,24 @@ paths:
responses:
"200":
description: The reload completed successfully.
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error (also hits if no keys were found)
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}:
parameters:
@@ -75,14 +93,62 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/TenantInfo"
"400":
description: Error when no tenant id found in path or no timeline id
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
delete:
description: |
Attempts to delete specified tenant. 500, 503 and 409 errors should be retried until 404 is retrieved.
404 means that deletion successfully finished"
responses:
"400":
description: Error when no tenant id found in path
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"404":
description: Tenant not found. This is the success path.
description: Tenant not found
content:
application/json:
schema:
@@ -99,6 +165,18 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/PreconditionFailedError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_id}/time_travel_remote_storage:
parameters:
@@ -128,6 +206,36 @@ paths:
application/json:
schema:
type: string
"400":
description: Error when no tenant id found in path or invalid timestamp
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_id}/timeline:
parameters:
@@ -147,6 +255,36 @@ paths:
type: array
items:
$ref: "#/components/schemas/TimelineInfo"
"400":
description: Error when no tenant id found in path
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_id}/timeline/{timeline_id}:
@@ -171,12 +309,60 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/TimelineInfo"
"400":
description: Error when no tenant id found in path or no timeline id
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
delete:
description: "Attempts to delete specified timeline. 500 and 409 errors should be retried"
responses:
"400":
description: Error when no tenant id found in path or no timeline id
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"404":
description: Timeline not found. This is the success path.
description: Timeline not found
content:
application/json:
schema:
@@ -193,6 +379,18 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/PreconditionFailedError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_id}/timeline/{timeline_id}/get_timestamp_of_lsn:
parameters:
@@ -225,6 +423,36 @@ paths:
schema:
type: string
format: date-time
"400":
description: Error when no tenant id found in path, no timeline id or invalid timestamp
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"404":
description: Timeline not found, or there is no timestamp information for the given lsn
content:
application/json:
schema:
$ref: "#/components/schemas/NotFoundError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/timeline/{timeline_id}/get_lsn_by_timestamp:
parameters:
@@ -256,6 +484,36 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/LsnByTimestampResponse"
"400":
description: Error when no tenant id found in path, no timeline id or invalid timestamp
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_id}/timeline/{timeline_id}/do_gc:
parameters:
@@ -279,6 +537,36 @@ paths:
application/json:
schema:
type: string
"400":
description: Error when no tenant id found in path, no timeline id or invalid timestamp
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_shard_id}/location_config:
parameters:
- name: tenant_shard_id
@@ -340,6 +628,24 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/TenantLocationConfigResponse"
"503":
description: Tenant's state cannot be changed right now. Wait a few seconds and retry.
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"409":
description: |
The tenant is already known to Pageserver in some way,
@@ -356,6 +662,12 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/ConflictError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/ignore:
parameters:
- name: tenant_id
@@ -372,6 +684,36 @@ paths:
responses:
"200":
description: Tenant ignored
"400":
description: Error when no tenant id found in path parameters
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_id}/load:
@@ -398,6 +740,36 @@ paths:
responses:
"202":
description: Tenant scheduled to load successfully
"400":
description: Error when no tenant id found in path parameters
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_id}/{timeline_id}/preserve_initdb_archive:
parameters:
@@ -418,6 +790,37 @@ paths:
responses:
"202":
description: Tenant scheduled to load successfully
"404":
description: No tenant or timeline found for the specified ids
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_id}/synthetic_size:
parameters:
@@ -436,8 +839,31 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/SyntheticSizeResponse"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
# This route has no handler. TODO: remove?
/v1/tenant/{tenant_id}/size:
parameters:
- name: tenant_id
@@ -519,6 +945,18 @@ paths:
responses:
"200":
description: Success
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_shard_id}/secondary/download:
parameters:
@@ -549,6 +987,20 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/SecondaryProgress"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_id}/timeline/:
parameters:
@@ -591,6 +1043,24 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/TimelineInfo"
"400":
description: Malformed timeline create request
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"406":
description: Permanently unsatisfiable request, don't retry.
content:
@@ -609,6 +1079,18 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/Error"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/:
get:
@@ -622,6 +1104,30 @@ paths:
type: array
items:
$ref: "#/components/schemas/TenantInfo"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
post:
description: |
@@ -642,12 +1148,43 @@ paths:
application/json:
schema:
type: string
"400":
description: Malformed tenant create request
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"409":
description: Tenant already exists, creation skipped
content:
application/json:
schema:
$ref: "#/components/schemas/ConflictError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/config:
put:
@@ -669,6 +1206,36 @@ paths:
type: array
items:
$ref: "#/components/schemas/TenantInfo"
"400":
description: Malformed tenant config request
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_id}/config/:
parameters:
@@ -688,6 +1255,42 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/TenantConfigResponse"
"400":
description: Malformed get tenanant config request
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"404":
description: Tenand or timeline were not found
content:
application/json:
schema:
$ref: "#/components/schemas/NotFoundError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"503":
description: Temporarily unavailable, please retry.
content:
application/json:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/utilization:
get:
@@ -701,6 +1304,12 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/PageserverUtilization"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
components:
securitySchemes:

View File

@@ -457,12 +457,8 @@ async fn reload_auth_validation_keys_handler(
json_response(StatusCode::OK, ())
}
Err(e) => {
let err_msg = "Error reloading public keys";
warn!("Error reloading public keys from {key_path:?}: {e:}");
json_response(
StatusCode::INTERNAL_SERVER_ERROR,
HttpErrorBody::from_msg(err_msg.to_string()),
)
json_response(StatusCode::INTERNAL_SERVER_ERROR, ())
}
}
}
@@ -700,7 +696,7 @@ async fn get_lsn_by_timestamp_handler(
check_permission(&request, Some(tenant_shard_id.tenant_id))?;
let state = get_state(&request);
if !tenant_shard_id.is_shard_zero() {
if !tenant_shard_id.is_zero() {
// Requires SLRU contents, which are only stored on shard zero
return Err(ApiError::BadRequest(anyhow!(
"Size calculations are only available on shard zero"
@@ -751,7 +747,7 @@ async fn get_timestamp_of_lsn_handler(
check_permission(&request, Some(tenant_shard_id.tenant_id))?;
let state = get_state(&request);
if !tenant_shard_id.is_shard_zero() {
if !tenant_shard_id.is_zero() {
// Requires SLRU contents, which are only stored on shard zero
return Err(ApiError::BadRequest(anyhow!(
"Size calculations are only available on shard zero"
@@ -776,9 +772,7 @@ async fn get_timestamp_of_lsn_handler(
let time = format_rfc3339(postgres_ffi::from_pg_timestamp(time)).to_string();
json_response(StatusCode::OK, time)
}
None => Err(ApiError::NotFound(
anyhow::anyhow!("Timestamp for lsn {} not found", lsn).into(),
)),
None => json_response(StatusCode::NOT_FOUND, ()),
}
}
@@ -1092,7 +1086,7 @@ async fn tenant_size_handler(
let headers = request.headers();
let state = get_state(&request);
if !tenant_shard_id.is_shard_zero() {
if !tenant_shard_id.is_zero() {
return Err(ApiError::BadRequest(anyhow!(
"Size calculations are only available on shard zero"
)));

View File

@@ -1819,29 +1819,6 @@ impl Default for WalRedoProcessCounters {
pub(crate) static WAL_REDO_PROCESS_COUNTERS: Lazy<WalRedoProcessCounters> =
Lazy::new(WalRedoProcessCounters::default);
#[cfg(not(test))]
pub mod wal_redo {
use super::*;
static PROCESS_KIND: Lazy<std::sync::Mutex<UIntGaugeVec>> = Lazy::new(|| {
std::sync::Mutex::new(
register_uint_gauge_vec!(
"pageserver_wal_redo_process_kind",
"The configured process kind for walredo",
&["kind"],
)
.unwrap(),
)
});
pub fn set_process_kind_metric(kind: crate::walredo::ProcessKind) {
// use guard to avoid races around the next two steps
let guard = PROCESS_KIND.lock().unwrap();
guard.reset();
guard.with_label_values(&[&format!("{kind}")]).set(1);
}
}
/// Similar to `prometheus::HistogramTimer` but does not record on drop.
pub(crate) struct StorageTimeMetricsTimer {
metrics: StorageTimeMetrics,
@@ -2112,7 +2089,7 @@ impl TimelineMetrics {
pub(crate) fn remove_tenant_metrics(tenant_shard_id: &TenantShardId) {
// Only shard zero deals in synthetic sizes
if tenant_shard_id.is_shard_zero() {
if tenant_shard_id.is_zero() {
let tid = tenant_shard_id.tenant_id.to_string();
let _ = TENANT_SYNTHETIC_SIZE_METRIC.remove_label_values(&[&tid]);
}

View File

@@ -847,69 +847,66 @@ impl PageServerHandler {
/// In either case, if the page server hasn't received the WAL up to the
/// requested LSN yet, we will wait for it to arrive. The return value is
/// the LSN that should be used to look up the page versions.
///
/// Compute needs to specify:
/// 1. "desired" LSN - which LSN compute expects to be acceptable
/// 2. Upper boundary LSN - PS should not send page with greater LSN to preserver consistency
///
/// In case of primary node then upper boundary is always +inf: nobody except this node can produce more recent version of the page.
/// In case of replica it is not true: replica can lag from primary node and PS and should not receive pages newer than its last_replay_lsn.
/// But it is not good always to request pages at `last_replay_lsn` because replica can be ahead PS and so it has to wait
/// until PS caught up (while for this particular page it is not needed).
///
/// We actually need to handle just three cases:
/// \[page_last_written_lsn, +inf\] - primary node
/// \[page_last_written_lsn, last_replay_lsn\] - hot-standby replica (receiving WAL from primary)
/// \[snapshot_lsn, snapshot_lsn\] - static RO replica (not receiving WAL fro primary)
///
/// Case \[0, lsn\] is not actually needed and added mostly for convenience as alias for \[lsn,lsn\]
async fn wait_or_get_last_lsn(
timeline: &Timeline,
mut lsn: Lsn,
latest: bool,
lsn: Lsn,
horizon: Lsn,
latest_gc_cutoff_lsn: &RcuReadGuard<Lsn>,
ctx: &RequestContext,
) -> Result<Lsn, PageStreamError> {
if latest {
// Latest page version was requested. If LSN is given, it is a hint
// to the page server that there have been no modifications to the
// page after that LSN. If we haven't received WAL up to that point,
// wait until it arrives.
let last_record_lsn = timeline.get_last_record_lsn();
// Note: this covers the special case that lsn == Lsn(0). That
// special case means "return the latest version whatever it is",
// and it's used for bootstrapping purposes, when the page server is
// connected directly to the compute node. That is needed because
// when you connect to the compute node, to receive the WAL, the
// walsender process will do a look up in the pg_authid catalog
// table for authentication. That poses a deadlock problem: the
// catalog table lookup will send a GetPage request, but the GetPage
// request will block in the page server because the recent WAL
// hasn't been received yet, and it cannot be received until the
// walsender completes the authentication and starts streaming the
// WAL.
if lsn <= last_record_lsn {
lsn = last_record_lsn;
} else {
timeline
.wait_lsn(
lsn,
crate::tenant::timeline::WaitLsnWaiter::PageService,
ctx,
)
.await?;
// Since we waited for 'lsn' to arrive, that is now the last
// record LSN. (Or close enough for our purposes; the
// last-record LSN can advance immediately after we return
// anyway)
}
let last_record_lsn = timeline.get_last_record_lsn();
// Horizon = 0 (INVALID) is treated as LSN interval degenerated to point [lsn,lsn].
// It as done mostly for convenience (because such get_page commands are widely used in tests) and
// also seems to be logical: Lsn::MAX moves upper boundary of LSN interval till last_record_lsn and
// Lsn(0) moves upper boundary to lower boundary.
let request_horizon = if horizon == Lsn::INVALID {
lsn
} else {
if lsn == Lsn(0) {
return Err(PageStreamError::BadRequest(
"invalid LSN(0) in request".into(),
));
}
horizon
};
let effective_lsn = Lsn::max(lsn, Lsn::min(request_horizon, last_record_lsn));
if effective_lsn > last_record_lsn {
timeline
.wait_lsn(
lsn,
effective_lsn,
crate::tenant::timeline::WaitLsnWaiter::PageService,
ctx,
)
.await?;
// Since we waited for 'lsn' to arrive, that is now the last
// record LSN. (Or close enough for our purposes; the
// last-record LSN can advance immediately after we return
// anyway)
} else if effective_lsn == Lsn(0) {
return Err(PageStreamError::BadRequest(
"invalid LSN(0) in request".into(),
));
}
if lsn < **latest_gc_cutoff_lsn {
if effective_lsn < **latest_gc_cutoff_lsn {
return Err(PageStreamError::BadRequest(format!(
"tried to request a page version that was garbage collected. requested at {} gc cutoff {}",
lsn, **latest_gc_cutoff_lsn
effective_lsn, **latest_gc_cutoff_lsn
).into()));
}
Ok(lsn)
Ok(effective_lsn)
}
#[instrument(skip_all, fields(shard_id))]
@@ -927,11 +924,11 @@ impl PageServerHandler {
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn =
Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn, ctx)
Self::wait_or_get_last_lsn(timeline, req.lsn, req.horizon, &latest_gc_cutoff_lsn, ctx)
.await?;
let exists = timeline
.get_rel_exists(req.rel, Version::Lsn(lsn), req.latest, ctx)
.get_rel_exists(req.rel, Version::Lsn(lsn), req.horizon, ctx)
.await?;
Ok(PagestreamBeMessage::Exists(PagestreamExistsResponse {
@@ -955,11 +952,11 @@ impl PageServerHandler {
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn =
Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn, ctx)
Self::wait_or_get_last_lsn(timeline, req.lsn, req.horizon, &latest_gc_cutoff_lsn, ctx)
.await?;
let n_blocks = timeline
.get_rel_size(req.rel, Version::Lsn(lsn), req.latest, ctx)
.get_rel_size(req.rel, Version::Lsn(lsn), req.horizon, ctx)
.await?;
Ok(PagestreamBeMessage::Nblocks(PagestreamNblocksResponse {
@@ -983,7 +980,7 @@ impl PageServerHandler {
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn =
Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn, ctx)
Self::wait_or_get_last_lsn(timeline, req.lsn, req.horizon, &latest_gc_cutoff_lsn, ctx)
.await?;
let total_blocks = timeline
@@ -991,7 +988,7 @@ impl PageServerHandler {
DEFAULTTABLESPACE_OID,
req.dbnode,
Version::Lsn(lsn),
req.latest,
req.horizon,
ctx,
)
.await?;
@@ -1161,11 +1158,11 @@ impl PageServerHandler {
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn =
Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn, ctx)
Self::wait_or_get_last_lsn(timeline, req.lsn, req.horizon, &latest_gc_cutoff_lsn, ctx)
.await?;
let page = timeline
.get_rel_page_at_lsn(req.rel, req.blkno, Version::Lsn(lsn), req.latest, ctx)
.get_rel_page_at_lsn(req.rel, req.blkno, Version::Lsn(lsn), req.horizon, ctx)
.await?;
Ok(PagestreamBeMessage::GetPage(PagestreamGetPageResponse {
@@ -1189,7 +1186,7 @@ impl PageServerHandler {
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn =
Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn, ctx)
Self::wait_or_get_last_lsn(timeline, req.lsn, req.horizon, &latest_gc_cutoff_lsn, ctx)
.await?;
let kind = SlruKind::from_repr(req.kind)

View File

@@ -175,7 +175,7 @@ impl Timeline {
tag: RelTag,
blknum: BlockNumber,
version: Version<'_>,
latest: bool,
horizon: Lsn,
ctx: &RequestContext,
) -> Result<Bytes, PageReconstructError> {
if tag.relnode == 0 {
@@ -184,7 +184,7 @@ impl Timeline {
));
}
let nblocks = self.get_rel_size(tag, version, latest, ctx).await?;
let nblocks = self.get_rel_size(tag, version, horizon, ctx).await?;
if blknum >= nblocks {
debug!(
"read beyond EOF at {} blk {} at {}, size is {}: returning all-zeros page",
@@ -206,7 +206,7 @@ impl Timeline {
spcnode: Oid,
dbnode: Oid,
version: Version<'_>,
latest: bool,
horizon: Lsn,
ctx: &RequestContext,
) -> Result<usize, PageReconstructError> {
let mut total_blocks = 0;
@@ -214,7 +214,7 @@ impl Timeline {
let rels = self.list_rels(spcnode, dbnode, version, ctx).await?;
for rel in rels {
let n_blocks = self.get_rel_size(rel, version, latest, ctx).await?;
let n_blocks = self.get_rel_size(rel, version, horizon, ctx).await?;
total_blocks += n_blocks as usize;
}
Ok(total_blocks)
@@ -225,7 +225,7 @@ impl Timeline {
&self,
tag: RelTag,
version: Version<'_>,
latest: bool,
horizon: Lsn,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
if tag.relnode == 0 {
@@ -239,7 +239,7 @@ impl Timeline {
}
if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM)
&& !self.get_rel_exists(tag, version, latest, ctx).await?
&& !self.get_rel_exists(tag, version, horizon, ctx).await?
{
// FIXME: Postgres sometimes calls smgrcreate() to create
// FSM, and smgrnblocks() on it immediately afterwards,
@@ -252,14 +252,8 @@ impl Timeline {
let mut buf = version.get(self, key, ctx).await?;
let nblocks = buf.get_u32_le();
if latest {
// Update relation size cache only if "latest" flag is set.
// This flag is set by compute when it is working with most recent version of relation.
// Typically master compute node always set latest=true.
// Please notice, that even if compute node "by mistake" specifies old LSN but set
// latest=true, then it can not cause cache corruption, because with latest=true
// pageserver choose max(request_lsn, last_written_lsn) and so cached value will be
// associated with most recent value of LSN.
if horizon == Lsn::MAX {
// Update relation size cache only if latest version is requested.
self.update_cached_rel_size(tag, version.get_lsn(), nblocks);
}
Ok(nblocks)
@@ -270,7 +264,7 @@ impl Timeline {
&self,
tag: RelTag,
version: Version<'_>,
_latest: bool,
_horizon: Lsn,
ctx: &RequestContext,
) -> Result<bool, PageReconstructError> {
if tag.relnode == 0 {
@@ -1088,7 +1082,7 @@ impl<'a> DatadirModification<'a> {
) -> anyhow::Result<()> {
let total_blocks = self
.tline
.get_db_size(spcnode, dbnode, Version::Modified(self), true, ctx)
.get_db_size(spcnode, dbnode, Version::Modified(self), Lsn::MAX, ctx)
.await?;
// Remove entry from dbdir
@@ -1187,7 +1181,7 @@ impl<'a> DatadirModification<'a> {
anyhow::ensure!(rel.relnode != 0, RelationError::InvalidRelnode);
if self
.tline
.get_rel_exists(rel, Version::Modified(self), true, ctx)
.get_rel_exists(rel, Version::Modified(self), Lsn::MAX, ctx)
.await?
{
let size_key = rel_size_to_key(rel);

View File

@@ -386,7 +386,7 @@ impl WalRedoManager {
pub(crate) fn status(&self) -> Option<WalRedoManagerStatus> {
match self {
WalRedoManager::Prod(m) => Some(m.status()),
WalRedoManager::Prod(m) => m.status(),
#[cfg(test)]
WalRedoManager::Test(_) => None,
}
@@ -3190,7 +3190,7 @@ impl Tenant {
run_initdb(self.conf, &pgdata_path, pg_version, &self.cancel).await?;
// Upload the created data dir to S3
if self.tenant_shard_id().is_shard_zero() {
if self.tenant_shard_id().is_zero() {
self.upload_initdb(&timelines_path, &pgdata_path, &timeline_id)
.await?;
}
@@ -3437,7 +3437,7 @@ impl Tenant {
.store(size, Ordering::Relaxed);
// Only shard zero should be calculating synthetic sizes
debug_assert!(self.shard_identity.is_shard_zero());
debug_assert!(self.shard_identity.is_zero());
TENANT_SYNTHETIC_SIZE_METRIC
.get_metric_with_label_values(&[&self.tenant_shard_id.tenant_id.to_string()])

View File

@@ -436,11 +436,6 @@ impl DeleteTenantFlow {
.await
}
/// Check whether background deletion of this tenant is currently in progress
pub(crate) fn is_in_progress(tenant: &Tenant) -> bool {
tenant.delete_progress.try_lock().is_err()
}
async fn prepare(
tenant: &Arc<Tenant>,
) -> Result<tokio::sync::OwnedMutexGuard<Self>, DeleteTenantError> {

View File

@@ -1410,15 +1410,9 @@ impl TenantManager {
match tenant.current_state() {
TenantState::Broken { .. } | TenantState::Stopping { .. } => {
// If deletion is already in progress, return success (the semantics of this
// function are to rerturn success afterr deletion is spawned in background).
// Otherwise fall through and let [`DeleteTenantFlow`] handle this state.
if DeleteTenantFlow::is_in_progress(&tenant) {
// The `delete_progress` lock is held: deletion is already happening
// in the bacckground
slot_guard.revert();
return Ok(());
}
// If a tenant is broken or stopping, DeleteTenantFlow can
// handle it: broken tenants proceed to delete, stopping tenants
// are checked for deletion already in progress.
}
_ => {
tenant

View File

@@ -167,7 +167,7 @@ pub(crate) async fn time_travel_recover_tenant(
let warn_after = 3;
let max_attempts = 10;
let mut prefixes = Vec::with_capacity(2);
if tenant_shard_id.is_shard_zero() {
if tenant_shard_id.is_zero() {
// Also recover the unsharded prefix for a shard of zero:
// - if the tenant is totally unsharded, the unsharded prefix contains all the data
// - if the tenant is sharded, we still want to recover the initdb data, but we only

View File

@@ -1344,7 +1344,7 @@ impl Timeline {
background_jobs_can_start: Option<&completion::Barrier>,
ctx: &RequestContext,
) {
if self.tenant_shard_id.is_shard_zero() {
if self.tenant_shard_id.is_zero() {
// Logical size is only maintained accurately on shard zero.
self.spawn_initial_logical_size_computation_task(ctx);
}
@@ -2237,7 +2237,7 @@ impl Timeline {
priority: GetLogicalSizePriority,
ctx: &RequestContext,
) -> logical_size::CurrentLogicalSize {
if !self.tenant_shard_id.is_shard_zero() {
if !self.tenant_shard_id.is_zero() {
// Logical size is only accurately maintained on shard zero: when called elsewhere, for example
// when HTTP API is serving a GET for timeline zero, return zero
return logical_size::CurrentLogicalSize::Approximate(logical_size::Approximate::zero());
@@ -2533,7 +2533,7 @@ impl Timeline {
crate::span::debug_assert_current_span_has_tenant_and_timeline_id();
// We should never be calculating logical sizes on shard !=0, because these shards do not have
// accurate relation sizes, and they do not emit consumption metrics.
debug_assert!(self.tenant_shard_id.is_shard_zero());
debug_assert!(self.tenant_shard_id.is_zero());
let guard = self
.gate

View File

@@ -378,7 +378,7 @@ impl Timeline {
gate: &GateGuard,
ctx: &RequestContext,
) -> ControlFlow<()> {
if !self.tenant_shard_id.is_shard_zero() {
if !self.tenant_shard_id.is_zero() {
// Shards !=0 do not maintain accurate relation sizes, and do not need to calculate logical size
// for consumption metrics (consumption metrics are only sent from shard 0). We may therefore
// skip imitating logical size accesses for eviction purposes.

View File

@@ -427,7 +427,7 @@ pub(super) async fn handle_walreceiver_connection(
// Send the replication feedback message.
// Regular standby_status_update fields are put into this message.
let current_timeline_size = if timeline.tenant_shard_id.is_shard_zero() {
let current_timeline_size = if timeline.tenant_shard_id.is_zero() {
timeline
.get_current_logical_size(
crate::tenant::timeline::GetLogicalSizePriority::User,

View File

@@ -403,7 +403,7 @@ impl WalIngest {
);
if !key_is_local {
if self.shard.is_shard_zero() {
if self.shard.is_zero() {
// Shard 0 tracks relation sizes. Although we will not store this block, we will observe
// its blkno in case it implicitly extends a relation.
self.observe_decoded_block(modification, blk, ctx).await?;
@@ -1034,7 +1034,7 @@ impl WalIngest {
let nblocks = modification
.tline
.get_rel_size(src_rel, Version::Modified(modification), true, ctx)
.get_rel_size(src_rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?;
let dst_rel = RelTag {
spcnode: tablespace_id,
@@ -1072,7 +1072,7 @@ impl WalIngest {
src_rel,
blknum,
Version::Modified(modification),
true,
Lsn::MAX,
ctx,
)
.await?;
@@ -1242,7 +1242,7 @@ impl WalIngest {
};
if modification
.tline
.get_rel_exists(rel, Version::Modified(modification), true, ctx)
.get_rel_exists(rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?
{
self.put_rel_drop(modification, rel, ctx).await?;
@@ -1541,7 +1541,7 @@ impl WalIngest {
nblocks
} else if !modification
.tline
.get_rel_exists(rel, Version::Modified(modification), true, ctx)
.get_rel_exists(rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?
{
// create it with 0 size initially, the logic below will extend it
@@ -1553,7 +1553,7 @@ impl WalIngest {
} else {
modification
.tline
.get_rel_size(rel, Version::Modified(modification), true, ctx)
.get_rel_size(rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?
};
@@ -1650,14 +1650,14 @@ async fn get_relsize(
) -> anyhow::Result<BlockNumber> {
let nblocks = if !modification
.tline
.get_rel_exists(rel, Version::Modified(modification), true, ctx)
.get_rel_exists(rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?
{
0
} else {
modification
.tline
.get_rel_size(rel, Version::Modified(modification), true, ctx)
.get_rel_size(rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?
};
Ok(nblocks)
@@ -1732,29 +1732,29 @@ mod tests {
// The relation was created at LSN 2, not visible at LSN 1 yet.
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x10)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x10)), Lsn::INVALID, &ctx)
.await?,
false
);
assert!(tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x10)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x10)), Lsn::INVALID, &ctx)
.await
.is_err());
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
true
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
1
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
3
);
@@ -1762,46 +1762,46 @@ mod tests {
// Check page contents at each LSN
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 0 at 2")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x30)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x30)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 0 at 3")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x40)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x40)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 0 at 3")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x40)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x40)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 1 at 4")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 0 at 3")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 1 at 4")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 2, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 2, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 2 at 5")
);
@@ -1817,19 +1817,19 @@ mod tests {
// Check reported size and contents after truncation
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x60)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x60)), Lsn::INVALID, &ctx)
.await?,
2
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x60)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x60)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 0 at 3")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x60)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x60)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 1 at 4")
);
@@ -1837,13 +1837,13 @@ mod tests {
// should still see the truncated block with older LSN
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
3
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 2, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 2, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 2 at 5")
);
@@ -1856,7 +1856,7 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x68)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x68)), Lsn::INVALID, &ctx)
.await?,
0
);
@@ -1869,19 +1869,19 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x70)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x70)), Lsn::INVALID, &ctx)
.await?,
2
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x70)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x70)), Lsn::INVALID, &ctx)
.await?,
ZERO_PAGE
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x70)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x70)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 1")
);
@@ -1894,21 +1894,27 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x80)), Lsn::INVALID, &ctx)
.await?,
1501
);
for blk in 2..1500 {
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, blk, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_page_at_lsn(
TESTREL_A,
blk,
Version::Lsn(Lsn(0x80)),
Lsn::INVALID,
&ctx
)
.await?,
ZERO_PAGE
);
}
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 1500, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 1500, Version::Lsn(Lsn(0x80)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 1500")
);
@@ -1935,13 +1941,13 @@ mod tests {
// Check that rel exists and size is correct
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
true
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
1
);
@@ -1954,7 +1960,7 @@ mod tests {
// Check that rel is not visible anymore
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x30)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x30)), Lsn::INVALID, &ctx)
.await?,
false
);
@@ -1972,13 +1978,13 @@ mod tests {
// Check that rel exists and size is correct
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x40)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x40)), Lsn::INVALID, &ctx)
.await?,
true
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x40)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x40)), Lsn::INVALID, &ctx)
.await?,
1
);
@@ -2011,24 +2017,24 @@ mod tests {
// The relation was created at LSN 20, not visible at LSN 1 yet.
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x10)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x10)), Lsn::INVALID, &ctx)
.await?,
false
);
assert!(tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x10)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x10)), Lsn::INVALID, &ctx)
.await
.is_err());
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
true
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
relsize
);
@@ -2039,7 +2045,7 @@ mod tests {
let data = format!("foo blk {} at {}", blkno, lsn);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(lsn), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(lsn), Lsn::INVALID, &ctx)
.await?,
test_img(&data)
);
@@ -2056,7 +2062,7 @@ mod tests {
// Check reported size and contents after truncation
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x60)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x60)), Lsn::INVALID, &ctx)
.await?,
1
);
@@ -2066,7 +2072,13 @@ mod tests {
let data = format!("foo blk {} at {}", blkno, lsn);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(Lsn(0x60)), false, &ctx)
.get_rel_page_at_lsn(
TESTREL_A,
blkno,
Version::Lsn(Lsn(0x60)),
Lsn::INVALID,
&ctx
)
.await?,
test_img(&data)
);
@@ -2075,7 +2087,7 @@ mod tests {
// should still see all blocks with older LSN
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
relsize
);
@@ -2084,7 +2096,13 @@ mod tests {
let data = format!("foo blk {} at {}", blkno, lsn);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_page_at_lsn(
TESTREL_A,
blkno,
Version::Lsn(Lsn(0x50)),
Lsn::INVALID,
&ctx
)
.await?,
test_img(&data)
);
@@ -2104,13 +2122,13 @@ mod tests {
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x80)), Lsn::INVALID, &ctx)
.await?,
true
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x80)), Lsn::INVALID, &ctx)
.await?,
relsize
);
@@ -2120,7 +2138,13 @@ mod tests {
let data = format!("foo blk {} at {}", blkno, lsn);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_page_at_lsn(
TESTREL_A,
blkno,
Version::Lsn(Lsn(0x80)),
Lsn::INVALID,
&ctx
)
.await?,
test_img(&data)
);
@@ -2154,7 +2178,7 @@ mod tests {
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), Lsn::INVALID, &ctx)
.await?,
RELSEG_SIZE + 1
);
@@ -2168,7 +2192,7 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), Lsn::INVALID, &ctx)
.await?,
RELSEG_SIZE
);
@@ -2183,7 +2207,7 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), Lsn::INVALID, &ctx)
.await?,
RELSEG_SIZE - 1
);
@@ -2201,7 +2225,7 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), Lsn::INVALID, &ctx)
.await?,
size as BlockNumber
);

View File

@@ -20,7 +20,6 @@
/// Process lifecycle and abstracction for the IPC protocol.
mod process;
pub use process::Kind as ProcessKind;
/// Code to apply [`NeonWalRecord`]s.
pub(crate) mod apply_neon;
@@ -35,7 +34,7 @@ use crate::walrecord::NeonWalRecord;
use anyhow::Context;
use bytes::{Bytes, BytesMut};
use pageserver_api::key::key_to_rel_block;
use pageserver_api::models::{WalRedoManagerProcessStatus, WalRedoManagerStatus};
use pageserver_api::models::WalRedoManagerStatus;
use pageserver_api::shard::TenantShardId;
use std::sync::Arc;
use std::time::Duration;
@@ -55,7 +54,7 @@ pub struct PostgresRedoManager {
tenant_shard_id: TenantShardId,
conf: &'static PageServerConf,
last_redo_at: std::sync::Mutex<Option<Instant>>,
/// The current [`process::Process`] that is used by new redo requests.
/// The current [`process::WalRedoProcess`] that is used by new redo requests.
/// We use [`heavier_once_cell`] for coalescing the spawning, but the redo
/// requests don't use the [`heavier_once_cell::Guard`] to keep ahold of the
/// their process object; we use [`Arc::clone`] for that.
@@ -67,7 +66,7 @@ pub struct PostgresRedoManager {
/// still be using the old redo process. But, those other tasks will most likely
/// encounter an error as well, and errors are an unexpected condition anyway.
/// So, probably we could get rid of the `Arc` in the future.
redo_process: heavier_once_cell::OnceCell<Arc<process::Process>>,
redo_process: heavier_once_cell::OnceCell<Arc<process::WalRedoProcess>>,
}
///
@@ -140,8 +139,8 @@ impl PostgresRedoManager {
}
}
pub fn status(&self) -> WalRedoManagerStatus {
WalRedoManagerStatus {
pub(crate) fn status(&self) -> Option<WalRedoManagerStatus> {
Some(WalRedoManagerStatus {
last_redo_at: {
let at = *self.last_redo_at.lock().unwrap();
at.and_then(|at| {
@@ -150,14 +149,8 @@ impl PostgresRedoManager {
chrono::Utc::now().checked_sub_signed(chrono::Duration::from_std(age).ok()?)
})
},
process: self
.redo_process
.get()
.map(|p| WalRedoManagerProcessStatus {
pid: p.id(),
kind: std::borrow::Cow::Borrowed(p.kind().into()),
}),
}
pid: self.redo_process.get().map(|p| p.id()),
})
}
}
@@ -215,33 +208,37 @@ impl PostgresRedoManager {
const MAX_RETRY_ATTEMPTS: u32 = 1;
let mut n_attempts = 0u32;
loop {
let proc: Arc<process::Process> = match self.redo_process.get_or_init_detached().await {
Ok(guard) => Arc::clone(&guard),
Err(permit) => {
// don't hold poison_guard, the launch code can bail
let start = Instant::now();
let proc = Arc::new(
process::Process::launch(self.conf, self.tenant_shard_id, pg_version)
let proc: Arc<process::WalRedoProcess> =
match self.redo_process.get_or_init_detached().await {
Ok(guard) => Arc::clone(&guard),
Err(permit) => {
// don't hold poison_guard, the launch code can bail
let start = Instant::now();
let proc = Arc::new(
process::WalRedoProcess::launch(
self.conf,
self.tenant_shard_id,
pg_version,
)
.context("launch walredo process")?,
);
let duration = start.elapsed();
WAL_REDO_PROCESS_LAUNCH_DURATION_HISTOGRAM.observe(duration.as_secs_f64());
info!(
duration_ms = duration.as_millis(),
pid = proc.id(),
"launched walredo process"
);
self.redo_process.set(Arc::clone(&proc), permit);
proc
}
};
);
let duration = start.elapsed();
WAL_REDO_PROCESS_LAUNCH_DURATION_HISTOGRAM.observe(duration.as_secs_f64());
info!(
duration_ms = duration.as_millis(),
pid = proc.id(),
"launched walredo process"
);
self.redo_process.set(Arc::clone(&proc), permit);
proc
}
};
let started_at = std::time::Instant::now();
// Relational WAL records are applied using wal-redo-postgres
let result = proc
.apply_wal_records(rel, blknum, &base_img, records, wal_redo_timeout)
.await
.context("apply_wal_records");
let duration = started_at.elapsed();

View File

@@ -1,67 +1,186 @@
use std::time::Duration;
use self::no_leak_child::NoLeakChild;
use crate::{
config::PageServerConf,
metrics::{WalRedoKillCause, WAL_REDO_PROCESS_COUNTERS, WAL_REDO_RECORD_COUNTER},
walrecord::NeonWalRecord,
};
use anyhow::Context;
use bytes::Bytes;
use nix::poll::{PollFd, PollFlags};
use pageserver_api::{reltag::RelTag, shard::TenantShardId};
use utils::lsn::Lsn;
use crate::{config::PageServerConf, walrecord::NeonWalRecord};
use postgres_ffi::BLCKSZ;
use std::os::fd::AsRawFd;
#[cfg(feature = "testing")]
use std::sync::atomic::AtomicUsize;
use std::{
collections::VecDeque,
io::{Read, Write},
process::{ChildStdin, ChildStdout, Command, Stdio},
sync::{Mutex, MutexGuard},
time::Duration,
};
use tracing::{debug, error, instrument, Instrument};
use utils::{lsn::Lsn, nonblock::set_nonblock};
mod no_leak_child;
/// The IPC protocol that pageserver and walredo process speak over their shared pipe.
mod protocol;
mod process_impl {
pub(super) mod process_async;
pub(super) mod process_std;
pub struct WalRedoProcess {
#[allow(dead_code)]
conf: &'static PageServerConf,
tenant_shard_id: TenantShardId,
// Some() on construction, only becomes None on Drop.
child: Option<NoLeakChild>,
stdout: Mutex<ProcessOutput>,
stdin: Mutex<ProcessInput>,
/// Counter to separate same sized walredo inputs failing at the same millisecond.
#[cfg(feature = "testing")]
dump_sequence: AtomicUsize,
}
#[derive(
Clone,
Copy,
Debug,
PartialEq,
Eq,
strum_macros::EnumString,
strum_macros::Display,
strum_macros::IntoStaticStr,
serde_with::DeserializeFromStr,
serde_with::SerializeDisplay,
)]
#[strum(serialize_all = "kebab-case")]
#[repr(u8)]
pub enum Kind {
Sync,
Async,
struct ProcessInput {
stdin: ChildStdin,
n_requests: usize,
}
pub(crate) enum Process {
Sync(process_impl::process_std::WalRedoProcess),
Async(process_impl::process_async::WalRedoProcess),
struct ProcessOutput {
stdout: ChildStdout,
pending_responses: VecDeque<Option<Bytes>>,
n_processed_responses: usize,
}
impl Process {
#[inline(always)]
pub fn launch(
impl WalRedoProcess {
//
// Start postgres binary in special WAL redo mode.
//
#[instrument(skip_all,fields(pg_version=pg_version))]
pub(crate) fn launch(
conf: &'static PageServerConf,
tenant_shard_id: TenantShardId,
pg_version: u32,
) -> anyhow::Result<Self> {
Ok(match conf.walredo_process_kind {
Kind::Sync => Self::Sync(process_impl::process_std::WalRedoProcess::launch(
conf,
tenant_shard_id,
pg_version,
)?),
Kind::Async => Self::Async(process_impl::process_async::WalRedoProcess::launch(
conf,
tenant_shard_id,
pg_version,
)?),
crate::span::debug_assert_current_span_has_tenant_id();
let pg_bin_dir_path = conf.pg_bin_dir(pg_version).context("pg_bin_dir")?; // TODO these should be infallible.
let pg_lib_dir_path = conf.pg_lib_dir(pg_version).context("pg_lib_dir")?;
use no_leak_child::NoLeakChildCommandExt;
// Start postgres itself
let child = Command::new(pg_bin_dir_path.join("postgres"))
// the first arg must be --wal-redo so the child process enters into walredo mode
.arg("--wal-redo")
// the child doesn't process this arg, but, having it in the argv helps indentify the
// walredo process for a particular tenant when debugging a pagserver
.args(["--tenant-shard-id", &format!("{tenant_shard_id}")])
.stdin(Stdio::piped())
.stderr(Stdio::piped())
.stdout(Stdio::piped())
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path)
// NB: The redo process is not trusted after we sent it the first
// walredo work. Before that, it is trusted. Specifically, we trust
// it to
// 1. close all file descriptors except stdin, stdout, stderr because
// pageserver might not be 100% diligent in setting FD_CLOEXEC on all
// the files it opens, and
// 2. to use seccomp to sandbox itself before processing the first
// walredo request.
.spawn_no_leak_child(tenant_shard_id)
.context("spawn process")?;
WAL_REDO_PROCESS_COUNTERS.started.inc();
let mut child = scopeguard::guard(child, |child| {
error!("killing wal-redo-postgres process due to a problem during launch");
child.kill_and_wait(WalRedoKillCause::Startup);
});
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
let stderr = child.stderr.take().unwrap();
let stderr = tokio::process::ChildStderr::from_std(stderr)
.context("convert to tokio::ChildStderr")?;
macro_rules! set_nonblock_or_log_err {
($file:ident) => {{
let res = set_nonblock($file.as_raw_fd());
if let Err(e) = &res {
error!(error = %e, file = stringify!($file), pid = child.id(), "set_nonblock failed");
}
res
}};
}
set_nonblock_or_log_err!(stdin)?;
set_nonblock_or_log_err!(stdout)?;
// all fallible operations post-spawn are complete, so get rid of the guard
let child = scopeguard::ScopeGuard::into_inner(child);
tokio::spawn(
async move {
scopeguard::defer! {
debug!("wal-redo-postgres stderr_logger_task finished");
crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_finished.inc();
}
debug!("wal-redo-postgres stderr_logger_task started");
crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_started.inc();
use tokio::io::AsyncBufReadExt;
let mut stderr_lines = tokio::io::BufReader::new(stderr);
let mut buf = Vec::new();
let res = loop {
buf.clear();
// TODO we don't trust the process to cap its stderr length.
// Currently it can do unbounded Vec allocation.
match stderr_lines.read_until(b'\n', &mut buf).await {
Ok(0) => break Ok(()), // eof
Ok(num_bytes) => {
let output = String::from_utf8_lossy(&buf[..num_bytes]);
error!(%output, "received output");
}
Err(e) => {
break Err(e);
}
}
};
match res {
Ok(()) => (),
Err(e) => {
error!(error=?e, "failed to read from walredo stderr");
}
}
}.instrument(tracing::info_span!(parent: None, "wal-redo-postgres-stderr", pid = child.id(), tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), %pg_version))
);
Ok(Self {
conf,
tenant_shard_id,
child: Some(child),
stdin: Mutex::new(ProcessInput {
stdin,
n_requests: 0,
}),
stdout: Mutex::new(ProcessOutput {
stdout,
pending_responses: VecDeque::new(),
n_processed_responses: 0,
}),
#[cfg(feature = "testing")]
dump_sequence: AtomicUsize::default(),
})
}
#[inline(always)]
pub(crate) async fn apply_wal_records(
pub(crate) fn id(&self) -> u32 {
self.child
.as_ref()
.expect("must not call this during Drop")
.id()
}
// Apply given WAL records ('records') over an old page image. Returns
// new page image.
//
#[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), pid=%self.id()))]
pub(crate) fn apply_wal_records(
&self,
rel: RelTag,
blknum: u32,
@@ -69,29 +188,221 @@ impl Process {
records: &[(Lsn, NeonWalRecord)],
wal_redo_timeout: Duration,
) -> anyhow::Result<Bytes> {
match self {
Process::Sync(p) => {
p.apply_wal_records(rel, blknum, base_img, records, wal_redo_timeout)
.await
let tag = protocol::BufferTag { rel, blknum };
let input = self.stdin.lock().unwrap();
// Serialize all the messages to send the WAL redo process first.
//
// This could be problematic if there are millions of records to replay,
// but in practice the number of records is usually so small that it doesn't
// matter, and it's better to keep this code simple.
//
// Most requests start with a before-image with BLCKSZ bytes, followed by
// by some other WAL records. Start with a buffer that can hold that
// comfortably.
let mut writebuf: Vec<u8> = Vec::with_capacity((BLCKSZ as usize) * 3);
protocol::build_begin_redo_for_block_msg(tag, &mut writebuf);
if let Some(img) = base_img {
protocol::build_push_page_msg(tag, img, &mut writebuf);
}
for (lsn, rec) in records.iter() {
if let NeonWalRecord::Postgres {
will_init: _,
rec: postgres_rec,
} = rec
{
protocol::build_apply_record_msg(*lsn, postgres_rec, &mut writebuf);
} else {
anyhow::bail!("tried to pass neon wal record to postgres WAL redo");
}
Process::Async(p) => {
p.apply_wal_records(rel, blknum, base_img, records, wal_redo_timeout)
.await
}
protocol::build_get_page_msg(tag, &mut writebuf);
WAL_REDO_RECORD_COUNTER.inc_by(records.len() as u64);
let res = self.apply_wal_records0(&writebuf, input, wal_redo_timeout);
if res.is_err() {
// not all of these can be caused by this particular input, however these are so rare
// in tests so capture all.
self.record_and_log(&writebuf);
}
res
}
fn apply_wal_records0(
&self,
writebuf: &[u8],
input: MutexGuard<ProcessInput>,
wal_redo_timeout: Duration,
) -> anyhow::Result<Bytes> {
let mut proc = { input }; // TODO: remove this legacy rename, but this keep the patch small.
let mut nwrite = 0usize;
while nwrite < writebuf.len() {
let mut stdin_pollfds = [PollFd::new(&proc.stdin, PollFlags::POLLOUT)];
let n = loop {
match nix::poll::poll(&mut stdin_pollfds[..], wal_redo_timeout.as_millis() as i32) {
Err(nix::errno::Errno::EINTR) => continue,
res => break res,
}
}?;
if n == 0 {
anyhow::bail!("WAL redo timed out");
}
// If 'stdin' is writeable, do write.
let in_revents = stdin_pollfds[0].revents().unwrap();
if in_revents & (PollFlags::POLLERR | PollFlags::POLLOUT) != PollFlags::empty() {
nwrite += proc.stdin.write(&writebuf[nwrite..])?;
}
if in_revents.contains(PollFlags::POLLHUP) {
// We still have more data to write, but the process closed the pipe.
anyhow::bail!("WAL redo process closed its stdin unexpectedly");
}
}
let request_no = proc.n_requests;
proc.n_requests += 1;
drop(proc);
// To improve walredo performance we separate sending requests and receiving
// responses. Them are protected by different mutexes (output and input).
// If thread T1, T2, T3 send requests D1, D2, D3 to walredo process
// then there is not warranty that T1 will first granted output mutex lock.
// To address this issue we maintain number of sent requests, number of processed
// responses and ring buffer with pending responses. After sending response
// (under input mutex), threads remembers request number. Then it releases
// input mutex, locks output mutex and fetch in ring buffer all responses until
// its stored request number. The it takes correspondent element from
// pending responses ring buffer and truncate all empty elements from the front,
// advancing processed responses number.
let mut output = self.stdout.lock().unwrap();
let n_processed_responses = output.n_processed_responses;
while n_processed_responses + output.pending_responses.len() <= request_no {
// We expect the WAL redo process to respond with an 8k page image. We read it
// into this buffer.
let mut resultbuf = vec![0; BLCKSZ.into()];
let mut nresult: usize = 0; // # of bytes read into 'resultbuf' so far
while nresult < BLCKSZ.into() {
let mut stdout_pollfds = [PollFd::new(&output.stdout, PollFlags::POLLIN)];
// We do two things simultaneously: reading response from stdout
// and forward any logging information that the child writes to its stderr to the page server's log.
let n = loop {
match nix::poll::poll(
&mut stdout_pollfds[..],
wal_redo_timeout.as_millis() as i32,
) {
Err(nix::errno::Errno::EINTR) => continue,
res => break res,
}
}?;
if n == 0 {
anyhow::bail!("WAL redo timed out");
}
// If we have some data in stdout, read it to the result buffer.
let out_revents = stdout_pollfds[0].revents().unwrap();
if out_revents & (PollFlags::POLLERR | PollFlags::POLLIN) != PollFlags::empty() {
nresult += output.stdout.read(&mut resultbuf[nresult..])?;
}
if out_revents.contains(PollFlags::POLLHUP) {
anyhow::bail!("WAL redo process closed its stdout unexpectedly");
}
}
output
.pending_responses
.push_back(Some(Bytes::from(resultbuf)));
}
// Replace our request's response with None in `pending_responses`.
// Then make space in the ring buffer by clearing out any seqence of contiguous
// `None`'s from the front of `pending_responses`.
// NB: We can't pop_front() because other requests' responses because another
// requester might have grabbed the output mutex before us:
// T1: grab input mutex
// T1: send request_no 23
// T1: release input mutex
// T2: grab input mutex
// T2: send request_no 24
// T2: release input mutex
// T2: grab output mutex
// T2: n_processed_responses + output.pending_responses.len() <= request_no
// 23 0 24
// T2: enters poll loop that reads stdout
// T2: put response for 23 into pending_responses
// T2: put response for 24 into pending_resposnes
// pending_responses now looks like this: Front Some(response_23) Some(response_24) Back
// T2: takes its response_24
// pending_responses now looks like this: Front Some(response_23) None Back
// T2: does the while loop below
// pending_responses now looks like this: Front Some(response_23) None Back
// T2: releases output mutex
// T1: grabs output mutex
// T1: n_processed_responses + output.pending_responses.len() > request_no
// 23 2 23
// T1: skips poll loop that reads stdout
// T1: takes its response_23
// pending_responses now looks like this: Front None None Back
// T2: does the while loop below
// pending_responses now looks like this: Front Back
// n_processed_responses now has value 25
let res = output.pending_responses[request_no - n_processed_responses]
.take()
.expect("we own this request_no, nobody else is supposed to take it");
while let Some(front) = output.pending_responses.front() {
if front.is_none() {
output.pending_responses.pop_front();
output.n_processed_responses += 1;
} else {
break;
}
}
Ok(res)
}
#[cfg(feature = "testing")]
fn record_and_log(&self, writebuf: &[u8]) {
use std::sync::atomic::Ordering;
let millis = std::time::SystemTime::now()
.duration_since(std::time::SystemTime::UNIX_EPOCH)
.unwrap()
.as_millis();
let seq = self.dump_sequence.fetch_add(1, Ordering::Relaxed);
// these files will be collected to an allure report
let filename = format!("walredo-{millis}-{}-{seq}.walredo", writebuf.len());
let path = self.conf.tenant_path(&self.tenant_shard_id).join(&filename);
let res = std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.read(true)
.open(path)
.and_then(|mut f| f.write_all(writebuf));
// trip up allowed_errors
if let Err(e) = res {
tracing::error!(target=%filename, length=writebuf.len(), "failed to write out the walredo errored input: {e}");
} else {
tracing::error!(filename, "erroring walredo input saved");
}
}
pub(crate) fn id(&self) -> u32 {
match self {
Process::Sync(p) => p.id(),
Process::Async(p) => p.id(),
}
}
#[cfg(not(feature = "testing"))]
fn record_and_log(&self, _: &[u8]) {}
}
pub(crate) fn kind(&self) -> Kind {
match self {
Process::Sync(_) => Kind::Sync,
Process::Async(_) => Kind::Async,
}
impl Drop for WalRedoProcess {
fn drop(&mut self) {
self.child
.take()
.expect("we only do this once")
.kill_and_wait(WalRedoKillCause::WalRedoProcessDrop);
// no way to wait for stderr_logger_task from Drop because that is async only
}
}

View File

@@ -1,374 +0,0 @@
use self::no_leak_child::NoLeakChild;
use crate::{
config::PageServerConf,
metrics::{WalRedoKillCause, WAL_REDO_PROCESS_COUNTERS, WAL_REDO_RECORD_COUNTER},
walrecord::NeonWalRecord,
walredo::process::{no_leak_child, protocol},
};
use anyhow::Context;
use bytes::Bytes;
use pageserver_api::{reltag::RelTag, shard::TenantShardId};
use postgres_ffi::BLCKSZ;
#[cfg(feature = "testing")]
use std::sync::atomic::AtomicUsize;
use std::{
collections::VecDeque,
process::{Command, Stdio},
time::Duration,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{debug, error, instrument, Instrument};
use utils::{lsn::Lsn, poison::Poison};
pub struct WalRedoProcess {
#[allow(dead_code)]
conf: &'static PageServerConf,
tenant_shard_id: TenantShardId,
// Some() on construction, only becomes None on Drop.
child: Option<NoLeakChild>,
stdout: tokio::sync::Mutex<Poison<ProcessOutput>>,
stdin: tokio::sync::Mutex<Poison<ProcessInput>>,
/// Counter to separate same sized walredo inputs failing at the same millisecond.
#[cfg(feature = "testing")]
dump_sequence: AtomicUsize,
}
struct ProcessInput {
stdin: tokio::process::ChildStdin,
n_requests: usize,
}
struct ProcessOutput {
stdout: tokio::process::ChildStdout,
pending_responses: VecDeque<Option<Bytes>>,
n_processed_responses: usize,
}
impl WalRedoProcess {
//
// Start postgres binary in special WAL redo mode.
//
#[instrument(skip_all,fields(pg_version=pg_version))]
pub(crate) fn launch(
conf: &'static PageServerConf,
tenant_shard_id: TenantShardId,
pg_version: u32,
) -> anyhow::Result<Self> {
crate::span::debug_assert_current_span_has_tenant_id();
let pg_bin_dir_path = conf.pg_bin_dir(pg_version).context("pg_bin_dir")?; // TODO these should be infallible.
let pg_lib_dir_path = conf.pg_lib_dir(pg_version).context("pg_lib_dir")?;
use no_leak_child::NoLeakChildCommandExt;
// Start postgres itself
let child = Command::new(pg_bin_dir_path.join("postgres"))
// the first arg must be --wal-redo so the child process enters into walredo mode
.arg("--wal-redo")
// the child doesn't process this arg, but, having it in the argv helps indentify the
// walredo process for a particular tenant when debugging a pagserver
.args(["--tenant-shard-id", &format!("{tenant_shard_id}")])
.stdin(Stdio::piped())
.stderr(Stdio::piped())
.stdout(Stdio::piped())
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path)
// NB: The redo process is not trusted after we sent it the first
// walredo work. Before that, it is trusted. Specifically, we trust
// it to
// 1. close all file descriptors except stdin, stdout, stderr because
// pageserver might not be 100% diligent in setting FD_CLOEXEC on all
// the files it opens, and
// 2. to use seccomp to sandbox itself before processing the first
// walredo request.
.spawn_no_leak_child(tenant_shard_id)
.context("spawn process")?;
WAL_REDO_PROCESS_COUNTERS.started.inc();
let mut child = scopeguard::guard(child, |child| {
error!("killing wal-redo-postgres process due to a problem during launch");
child.kill_and_wait(WalRedoKillCause::Startup);
});
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
let stderr = child.stderr.take().unwrap();
let stderr = tokio::process::ChildStderr::from_std(stderr)
.context("convert to tokio::ChildStderr")?;
let stdin =
tokio::process::ChildStdin::from_std(stdin).context("convert to tokio::ChildStdin")?;
let stdout = tokio::process::ChildStdout::from_std(stdout)
.context("convert to tokio::ChildStdout")?;
// all fallible operations post-spawn are complete, so get rid of the guard
let child = scopeguard::ScopeGuard::into_inner(child);
tokio::spawn(
async move {
scopeguard::defer! {
debug!("wal-redo-postgres stderr_logger_task finished");
crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_finished.inc();
}
debug!("wal-redo-postgres stderr_logger_task started");
crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_started.inc();
use tokio::io::AsyncBufReadExt;
let mut stderr_lines = tokio::io::BufReader::new(stderr);
let mut buf = Vec::new();
let res = loop {
buf.clear();
// TODO we don't trust the process to cap its stderr length.
// Currently it can do unbounded Vec allocation.
match stderr_lines.read_until(b'\n', &mut buf).await {
Ok(0) => break Ok(()), // eof
Ok(num_bytes) => {
let output = String::from_utf8_lossy(&buf[..num_bytes]);
error!(%output, "received output");
}
Err(e) => {
break Err(e);
}
}
};
match res {
Ok(()) => (),
Err(e) => {
error!(error=?e, "failed to read from walredo stderr");
}
}
}.instrument(tracing::info_span!(parent: None, "wal-redo-postgres-stderr", pid = child.id(), tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), %pg_version))
);
Ok(Self {
conf,
tenant_shard_id,
child: Some(child),
stdin: tokio::sync::Mutex::new(Poison::new(
"stdin",
ProcessInput {
stdin,
n_requests: 0,
},
)),
stdout: tokio::sync::Mutex::new(Poison::new(
"stdout",
ProcessOutput {
stdout,
pending_responses: VecDeque::new(),
n_processed_responses: 0,
},
)),
#[cfg(feature = "testing")]
dump_sequence: AtomicUsize::default(),
})
}
pub(crate) fn id(&self) -> u32 {
self.child
.as_ref()
.expect("must not call this during Drop")
.id()
}
/// Apply given WAL records ('records') over an old page image. Returns
/// new page image.
///
/// # Cancel-Safety
///
/// Cancellation safe.
#[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), pid=%self.id()))]
pub(crate) async fn apply_wal_records(
&self,
rel: RelTag,
blknum: u32,
base_img: &Option<Bytes>,
records: &[(Lsn, NeonWalRecord)],
wal_redo_timeout: Duration,
) -> anyhow::Result<Bytes> {
let tag = protocol::BufferTag { rel, blknum };
// Serialize all the messages to send the WAL redo process first.
//
// This could be problematic if there are millions of records to replay,
// but in practice the number of records is usually so small that it doesn't
// matter, and it's better to keep this code simple.
//
// Most requests start with a before-image with BLCKSZ bytes, followed by
// by some other WAL records. Start with a buffer that can hold that
// comfortably.
let mut writebuf: Vec<u8> = Vec::with_capacity((BLCKSZ as usize) * 3);
protocol::build_begin_redo_for_block_msg(tag, &mut writebuf);
if let Some(img) = base_img {
protocol::build_push_page_msg(tag, img, &mut writebuf);
}
for (lsn, rec) in records.iter() {
if let NeonWalRecord::Postgres {
will_init: _,
rec: postgres_rec,
} = rec
{
protocol::build_apply_record_msg(*lsn, postgres_rec, &mut writebuf);
} else {
anyhow::bail!("tried to pass neon wal record to postgres WAL redo");
}
}
protocol::build_get_page_msg(tag, &mut writebuf);
WAL_REDO_RECORD_COUNTER.inc_by(records.len() as u64);
let Ok(res) =
tokio::time::timeout(wal_redo_timeout, self.apply_wal_records0(&writebuf)).await
else {
anyhow::bail!("WAL redo timed out");
};
if res.is_err() {
// not all of these can be caused by this particular input, however these are so rare
// in tests so capture all.
self.record_and_log(&writebuf);
}
res
}
/// # Cancel-Safety
///
/// When not polled to completion (e.g. because in `tokio::select!` another
/// branch becomes ready before this future), concurrent and subsequent
/// calls may fail due to [`utils::poison::Poison::check_and_arm`] calls.
/// Dispose of this process instance and create a new one.
async fn apply_wal_records0(&self, writebuf: &[u8]) -> anyhow::Result<Bytes> {
let request_no = {
let mut lock_guard = self.stdin.lock().await;
let mut poison_guard = lock_guard.check_and_arm()?;
let input = poison_guard.data_mut();
input
.stdin
.write_all(writebuf)
.await
.context("write to walredo stdin")?;
let request_no = input.n_requests;
input.n_requests += 1;
poison_guard.disarm();
request_no
};
// To improve walredo performance we separate sending requests and receiving
// responses. Them are protected by different mutexes (output and input).
// If thread T1, T2, T3 send requests D1, D2, D3 to walredo process
// then there is not warranty that T1 will first granted output mutex lock.
// To address this issue we maintain number of sent requests, number of processed
// responses and ring buffer with pending responses. After sending response
// (under input mutex), threads remembers request number. Then it releases
// input mutex, locks output mutex and fetch in ring buffer all responses until
// its stored request number. The it takes correspondent element from
// pending responses ring buffer and truncate all empty elements from the front,
// advancing processed responses number.
let mut lock_guard = self.stdout.lock().await;
let mut poison_guard = lock_guard.check_and_arm()?;
let output = poison_guard.data_mut();
let n_processed_responses = output.n_processed_responses;
while n_processed_responses + output.pending_responses.len() <= request_no {
// We expect the WAL redo process to respond with an 8k page image. We read it
// into this buffer.
let mut resultbuf = vec![0; BLCKSZ.into()];
output
.stdout
.read_exact(&mut resultbuf)
.await
.context("read walredo stdout")?;
output
.pending_responses
.push_back(Some(Bytes::from(resultbuf)));
}
// Replace our request's response with None in `pending_responses`.
// Then make space in the ring buffer by clearing out any seqence of contiguous
// `None`'s from the front of `pending_responses`.
// NB: We can't pop_front() because other requests' responses because another
// requester might have grabbed the output mutex before us:
// T1: grab input mutex
// T1: send request_no 23
// T1: release input mutex
// T2: grab input mutex
// T2: send request_no 24
// T2: release input mutex
// T2: grab output mutex
// T2: n_processed_responses + output.pending_responses.len() <= request_no
// 23 0 24
// T2: enters poll loop that reads stdout
// T2: put response for 23 into pending_responses
// T2: put response for 24 into pending_resposnes
// pending_responses now looks like this: Front Some(response_23) Some(response_24) Back
// T2: takes its response_24
// pending_responses now looks like this: Front Some(response_23) None Back
// T2: does the while loop below
// pending_responses now looks like this: Front Some(response_23) None Back
// T2: releases output mutex
// T1: grabs output mutex
// T1: n_processed_responses + output.pending_responses.len() > request_no
// 23 2 23
// T1: skips poll loop that reads stdout
// T1: takes its response_23
// pending_responses now looks like this: Front None None Back
// T2: does the while loop below
// pending_responses now looks like this: Front Back
// n_processed_responses now has value 25
let res = output.pending_responses[request_no - n_processed_responses]
.take()
.expect("we own this request_no, nobody else is supposed to take it");
while let Some(front) = output.pending_responses.front() {
if front.is_none() {
output.pending_responses.pop_front();
output.n_processed_responses += 1;
} else {
break;
}
}
poison_guard.disarm();
Ok(res)
}
#[cfg(feature = "testing")]
fn record_and_log(&self, writebuf: &[u8]) {
use std::sync::atomic::Ordering;
let millis = std::time::SystemTime::now()
.duration_since(std::time::SystemTime::UNIX_EPOCH)
.unwrap()
.as_millis();
let seq = self.dump_sequence.fetch_add(1, Ordering::Relaxed);
// these files will be collected to an allure report
let filename = format!("walredo-{millis}-{}-{seq}.walredo", writebuf.len());
let path = self.conf.tenant_path(&self.tenant_shard_id).join(&filename);
use std::io::Write;
let res = std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.read(true)
.open(path)
.and_then(|mut f| f.write_all(writebuf));
// trip up allowed_errors
if let Err(e) = res {
tracing::error!(target=%filename, length=writebuf.len(), "failed to write out the walredo errored input: {e}");
} else {
tracing::error!(filename, "erroring walredo input saved");
}
}
#[cfg(not(feature = "testing"))]
fn record_and_log(&self, _: &[u8]) {}
}
impl Drop for WalRedoProcess {
fn drop(&mut self) {
self.child
.take()
.expect("we only do this once")
.kill_and_wait(WalRedoKillCause::WalRedoProcessDrop);
// no way to wait for stderr_logger_task from Drop because that is async only
}
}

View File

@@ -1,405 +0,0 @@
use self::no_leak_child::NoLeakChild;
use crate::{
config::PageServerConf,
metrics::{WalRedoKillCause, WAL_REDO_PROCESS_COUNTERS, WAL_REDO_RECORD_COUNTER},
walrecord::NeonWalRecord,
walredo::process::{no_leak_child, protocol},
};
use anyhow::Context;
use bytes::Bytes;
use nix::poll::{PollFd, PollFlags};
use pageserver_api::{reltag::RelTag, shard::TenantShardId};
use postgres_ffi::BLCKSZ;
use std::os::fd::AsRawFd;
#[cfg(feature = "testing")]
use std::sync::atomic::AtomicUsize;
use std::{
collections::VecDeque,
io::{Read, Write},
process::{ChildStdin, ChildStdout, Command, Stdio},
sync::{Mutex, MutexGuard},
time::Duration,
};
use tracing::{debug, error, instrument, Instrument};
use utils::{lsn::Lsn, nonblock::set_nonblock};
pub struct WalRedoProcess {
#[allow(dead_code)]
conf: &'static PageServerConf,
tenant_shard_id: TenantShardId,
// Some() on construction, only becomes None on Drop.
child: Option<NoLeakChild>,
stdout: Mutex<ProcessOutput>,
stdin: Mutex<ProcessInput>,
/// Counter to separate same sized walredo inputs failing at the same millisecond.
#[cfg(feature = "testing")]
dump_sequence: AtomicUsize,
}
struct ProcessInput {
stdin: ChildStdin,
n_requests: usize,
}
struct ProcessOutput {
stdout: ChildStdout,
pending_responses: VecDeque<Option<Bytes>>,
n_processed_responses: usize,
}
impl WalRedoProcess {
//
// Start postgres binary in special WAL redo mode.
//
#[instrument(skip_all,fields(pg_version=pg_version))]
pub(crate) fn launch(
conf: &'static PageServerConf,
tenant_shard_id: TenantShardId,
pg_version: u32,
) -> anyhow::Result<Self> {
crate::span::debug_assert_current_span_has_tenant_id();
let pg_bin_dir_path = conf.pg_bin_dir(pg_version).context("pg_bin_dir")?; // TODO these should be infallible.
let pg_lib_dir_path = conf.pg_lib_dir(pg_version).context("pg_lib_dir")?;
use no_leak_child::NoLeakChildCommandExt;
// Start postgres itself
let child = Command::new(pg_bin_dir_path.join("postgres"))
// the first arg must be --wal-redo so the child process enters into walredo mode
.arg("--wal-redo")
// the child doesn't process this arg, but, having it in the argv helps indentify the
// walredo process for a particular tenant when debugging a pagserver
.args(["--tenant-shard-id", &format!("{tenant_shard_id}")])
.stdin(Stdio::piped())
.stderr(Stdio::piped())
.stdout(Stdio::piped())
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path)
// NB: The redo process is not trusted after we sent it the first
// walredo work. Before that, it is trusted. Specifically, we trust
// it to
// 1. close all file descriptors except stdin, stdout, stderr because
// pageserver might not be 100% diligent in setting FD_CLOEXEC on all
// the files it opens, and
// 2. to use seccomp to sandbox itself before processing the first
// walredo request.
.spawn_no_leak_child(tenant_shard_id)
.context("spawn process")?;
WAL_REDO_PROCESS_COUNTERS.started.inc();
let mut child = scopeguard::guard(child, |child| {
error!("killing wal-redo-postgres process due to a problem during launch");
child.kill_and_wait(WalRedoKillCause::Startup);
});
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
let stderr = child.stderr.take().unwrap();
let stderr = tokio::process::ChildStderr::from_std(stderr)
.context("convert to tokio::ChildStderr")?;
macro_rules! set_nonblock_or_log_err {
($file:ident) => {{
let res = set_nonblock($file.as_raw_fd());
if let Err(e) = &res {
error!(error = %e, file = stringify!($file), pid = child.id(), "set_nonblock failed");
}
res
}};
}
set_nonblock_or_log_err!(stdin)?;
set_nonblock_or_log_err!(stdout)?;
// all fallible operations post-spawn are complete, so get rid of the guard
let child = scopeguard::ScopeGuard::into_inner(child);
tokio::spawn(
async move {
scopeguard::defer! {
debug!("wal-redo-postgres stderr_logger_task finished");
crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_finished.inc();
}
debug!("wal-redo-postgres stderr_logger_task started");
crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_started.inc();
use tokio::io::AsyncBufReadExt;
let mut stderr_lines = tokio::io::BufReader::new(stderr);
let mut buf = Vec::new();
let res = loop {
buf.clear();
// TODO we don't trust the process to cap its stderr length.
// Currently it can do unbounded Vec allocation.
match stderr_lines.read_until(b'\n', &mut buf).await {
Ok(0) => break Ok(()), // eof
Ok(num_bytes) => {
let output = String::from_utf8_lossy(&buf[..num_bytes]);
error!(%output, "received output");
}
Err(e) => {
break Err(e);
}
}
};
match res {
Ok(()) => (),
Err(e) => {
error!(error=?e, "failed to read from walredo stderr");
}
}
}.instrument(tracing::info_span!(parent: None, "wal-redo-postgres-stderr", pid = child.id(), tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), %pg_version))
);
Ok(Self {
conf,
tenant_shard_id,
child: Some(child),
stdin: Mutex::new(ProcessInput {
stdin,
n_requests: 0,
}),
stdout: Mutex::new(ProcessOutput {
stdout,
pending_responses: VecDeque::new(),
n_processed_responses: 0,
}),
#[cfg(feature = "testing")]
dump_sequence: AtomicUsize::default(),
})
}
pub(crate) fn id(&self) -> u32 {
self.child
.as_ref()
.expect("must not call this during Drop")
.id()
}
// Apply given WAL records ('records') over an old page image. Returns
// new page image.
//
#[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), pid=%self.id()))]
pub(crate) async fn apply_wal_records(
&self,
rel: RelTag,
blknum: u32,
base_img: &Option<Bytes>,
records: &[(Lsn, NeonWalRecord)],
wal_redo_timeout: Duration,
) -> anyhow::Result<Bytes> {
let tag = protocol::BufferTag { rel, blknum };
let input = self.stdin.lock().unwrap();
// Serialize all the messages to send the WAL redo process first.
//
// This could be problematic if there are millions of records to replay,
// but in practice the number of records is usually so small that it doesn't
// matter, and it's better to keep this code simple.
//
// Most requests start with a before-image with BLCKSZ bytes, followed by
// by some other WAL records. Start with a buffer that can hold that
// comfortably.
let mut writebuf: Vec<u8> = Vec::with_capacity((BLCKSZ as usize) * 3);
protocol::build_begin_redo_for_block_msg(tag, &mut writebuf);
if let Some(img) = base_img {
protocol::build_push_page_msg(tag, img, &mut writebuf);
}
for (lsn, rec) in records.iter() {
if let NeonWalRecord::Postgres {
will_init: _,
rec: postgres_rec,
} = rec
{
protocol::build_apply_record_msg(*lsn, postgres_rec, &mut writebuf);
} else {
anyhow::bail!("tried to pass neon wal record to postgres WAL redo");
}
}
protocol::build_get_page_msg(tag, &mut writebuf);
WAL_REDO_RECORD_COUNTER.inc_by(records.len() as u64);
let res = self.apply_wal_records0(&writebuf, input, wal_redo_timeout);
if res.is_err() {
// not all of these can be caused by this particular input, however these are so rare
// in tests so capture all.
self.record_and_log(&writebuf);
}
res
}
fn apply_wal_records0(
&self,
writebuf: &[u8],
input: MutexGuard<ProcessInput>,
wal_redo_timeout: Duration,
) -> anyhow::Result<Bytes> {
let mut proc = { input }; // TODO: remove this legacy rename, but this keep the patch small.
let mut nwrite = 0usize;
while nwrite < writebuf.len() {
let mut stdin_pollfds = [PollFd::new(&proc.stdin, PollFlags::POLLOUT)];
let n = loop {
match nix::poll::poll(&mut stdin_pollfds[..], wal_redo_timeout.as_millis() as i32) {
Err(nix::errno::Errno::EINTR) => continue,
res => break res,
}
}?;
if n == 0 {
anyhow::bail!("WAL redo timed out");
}
// If 'stdin' is writeable, do write.
let in_revents = stdin_pollfds[0].revents().unwrap();
if in_revents & (PollFlags::POLLERR | PollFlags::POLLOUT) != PollFlags::empty() {
nwrite += proc.stdin.write(&writebuf[nwrite..])?;
}
if in_revents.contains(PollFlags::POLLHUP) {
// We still have more data to write, but the process closed the pipe.
anyhow::bail!("WAL redo process closed its stdin unexpectedly");
}
}
let request_no = proc.n_requests;
proc.n_requests += 1;
drop(proc);
// To improve walredo performance we separate sending requests and receiving
// responses. Them are protected by different mutexes (output and input).
// If thread T1, T2, T3 send requests D1, D2, D3 to walredo process
// then there is not warranty that T1 will first granted output mutex lock.
// To address this issue we maintain number of sent requests, number of processed
// responses and ring buffer with pending responses. After sending response
// (under input mutex), threads remembers request number. Then it releases
// input mutex, locks output mutex and fetch in ring buffer all responses until
// its stored request number. The it takes correspondent element from
// pending responses ring buffer and truncate all empty elements from the front,
// advancing processed responses number.
let mut output = self.stdout.lock().unwrap();
let n_processed_responses = output.n_processed_responses;
while n_processed_responses + output.pending_responses.len() <= request_no {
// We expect the WAL redo process to respond with an 8k page image. We read it
// into this buffer.
let mut resultbuf = vec![0; BLCKSZ.into()];
let mut nresult: usize = 0; // # of bytes read into 'resultbuf' so far
while nresult < BLCKSZ.into() {
let mut stdout_pollfds = [PollFd::new(&output.stdout, PollFlags::POLLIN)];
// We do two things simultaneously: reading response from stdout
// and forward any logging information that the child writes to its stderr to the page server's log.
let n = loop {
match nix::poll::poll(
&mut stdout_pollfds[..],
wal_redo_timeout.as_millis() as i32,
) {
Err(nix::errno::Errno::EINTR) => continue,
res => break res,
}
}?;
if n == 0 {
anyhow::bail!("WAL redo timed out");
}
// If we have some data in stdout, read it to the result buffer.
let out_revents = stdout_pollfds[0].revents().unwrap();
if out_revents & (PollFlags::POLLERR | PollFlags::POLLIN) != PollFlags::empty() {
nresult += output.stdout.read(&mut resultbuf[nresult..])?;
}
if out_revents.contains(PollFlags::POLLHUP) {
anyhow::bail!("WAL redo process closed its stdout unexpectedly");
}
}
output
.pending_responses
.push_back(Some(Bytes::from(resultbuf)));
}
// Replace our request's response with None in `pending_responses`.
// Then make space in the ring buffer by clearing out any seqence of contiguous
// `None`'s from the front of `pending_responses`.
// NB: We can't pop_front() because other requests' responses because another
// requester might have grabbed the output mutex before us:
// T1: grab input mutex
// T1: send request_no 23
// T1: release input mutex
// T2: grab input mutex
// T2: send request_no 24
// T2: release input mutex
// T2: grab output mutex
// T2: n_processed_responses + output.pending_responses.len() <= request_no
// 23 0 24
// T2: enters poll loop that reads stdout
// T2: put response for 23 into pending_responses
// T2: put response for 24 into pending_resposnes
// pending_responses now looks like this: Front Some(response_23) Some(response_24) Back
// T2: takes its response_24
// pending_responses now looks like this: Front Some(response_23) None Back
// T2: does the while loop below
// pending_responses now looks like this: Front Some(response_23) None Back
// T2: releases output mutex
// T1: grabs output mutex
// T1: n_processed_responses + output.pending_responses.len() > request_no
// 23 2 23
// T1: skips poll loop that reads stdout
// T1: takes its response_23
// pending_responses now looks like this: Front None None Back
// T2: does the while loop below
// pending_responses now looks like this: Front Back
// n_processed_responses now has value 25
let res = output.pending_responses[request_no - n_processed_responses]
.take()
.expect("we own this request_no, nobody else is supposed to take it");
while let Some(front) = output.pending_responses.front() {
if front.is_none() {
output.pending_responses.pop_front();
output.n_processed_responses += 1;
} else {
break;
}
}
Ok(res)
}
#[cfg(feature = "testing")]
fn record_and_log(&self, writebuf: &[u8]) {
use std::sync::atomic::Ordering;
let millis = std::time::SystemTime::now()
.duration_since(std::time::SystemTime::UNIX_EPOCH)
.unwrap()
.as_millis();
let seq = self.dump_sequence.fetch_add(1, Ordering::Relaxed);
// these files will be collected to an allure report
let filename = format!("walredo-{millis}-{}-{seq}.walredo", writebuf.len());
let path = self.conf.tenant_path(&self.tenant_shard_id).join(&filename);
let res = std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.read(true)
.open(path)
.and_then(|mut f| f.write_all(writebuf));
// trip up allowed_errors
if let Err(e) = res {
tracing::error!(target=%filename, length=writebuf.len(), "failed to write out the walredo errored input: {e}");
} else {
tracing::error!(filename, "erroring walredo input saved");
}
}
#[cfg(not(feature = "testing"))]
fn record_and_log(&self, _: &[u8]) {}
}
impl Drop for WalRedoProcess {
fn drop(&mut self) {
self.child
.take()
.expect("we only do this once")
.kill_and_wait(WalRedoKillCause::WalRedoProcessDrop);
// no way to wait for stderr_logger_task from Drop because that is async only
}
}

View File

@@ -49,6 +49,8 @@ char *neon_auth_token;
int readahead_buffer_size = 128;
int flush_every_n_requests = 8;
int neon_protocol_version;
static int n_reconnect_attempts = 0;
static int max_reconnect_attempts = 60;
static int stripe_size;
@@ -844,6 +846,14 @@ pg_init_libpagestore(void)
PGC_USERSET,
0, /* no flags required */
NULL, (GucIntAssignHook) &readahead_buffer_resize, NULL);
DefineCustomIntVariable("neon.protocol_version",
"Version of compute<->page server protocol",
NULL,
&neon_protocol_version,
NEON_PROTOCOL_VERSION, 1, INT_MAX,
PGC_USERSET,
0, /* no flags required */
NULL, NULL, NULL);
relsize_hash_init();

View File

@@ -28,10 +28,17 @@
#define MAX_SHARDS 128
#define MAX_PAGESERVER_CONNSTRING_SIZE 256
/*
* Right now protocal version is not set to the server.
* So it is ciritical that format of existed commands is not changed.
* New protocl versions can just add new commands.
*/
#define NEON_PROTOCOL_VERSION 2
typedef enum
{
/* pagestore_client -> pagestore */
T_NeonExistsRequest = 0,
T_NeonExistsRequest = 10, /* new protocol message tags start from 10 */
T_NeonNblocksRequest,
T_NeonGetPageRequest,
T_NeonDbSizeRequest,
@@ -72,14 +79,20 @@ typedef enum {
/*
* supertype of all the Neon*Request structs below
*
* If 'latest' is true, we are requesting the latest page version, and 'lsn'
* In old version of Neon we have 'latest' flag indicating that we are requesting the latest page version, and 'lsn'
* is just a hint to the server that we know there are no versions of the page
* (or relation size, for exists/nblocks requests) later than the 'lsn'.
*
* But it doesn't work for hot-standby replica because it may be not at the latest LSN position.
* So we need to be able to specify upper boundary for LSN which page server can send to us.
* This is why 'latest' flag is replaced with 'horizon'. MAX_LSN=~0 value of 'horizon' means that we are requesting latest version.
* If we need version on exact LSN (for static RO replicas), 'horizon' should be set to 0: in this case range [lsn,lsn] is used by page server.
* Otherwise for hot-standby replica we specify in 'horizon' current replay position.
*/
typedef struct
{
NeonMessageTag tag;
bool latest; /* if true, request latest page version */
XLogRecPtr horizon; /* upper boundary for page LSN */
XLogRecPtr lsn; /* request page version @ this LSN */
} NeonRequest;
@@ -193,6 +206,7 @@ extern int readahead_buffer_size;
extern char *neon_timeline;
extern char *neon_tenant;
extern int32 max_cluster_size;
extern int neon_protocol_version;
extern shardno_t get_shard_number(BufferTag* tag);

View File

@@ -110,6 +110,20 @@ static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;
static bool neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id);
static bool (*old_redo_read_buffer_filter) (XLogReaderState *record, uint8 block_id) = NULL;
#define MAX_LSN ((XLogRecPtr)~0)
/*
* There are three kinds of get_page :
* 1. Master compute: get the latest page not older than specified LSN (horizon=Lsn::MAX)
* 2. RO replica: get the latest page not newer than current WAL position replica already applied (horizon=GetXLogReplayRecPtr(NULL))
* 3. Snapshot: get latest page not new than specified LSN (horizon=request_lsn)
*/
static XLogRecPtr
neon_get_horizon(bool latest)
{
return latest ? MAX_LSN : RecoveryInProgress() ? GetXLogReplayRecPtr(NULL) : InvalidXLogRecPtr; /* horizon=InvalidXlogRecPtr is replaced with request_lsn at PS */
}
/*
* Prefetch implementation:
*
@@ -687,9 +701,10 @@ static void
prefetch_do_request(PrefetchRequest *slot, bool *force_latest, XLogRecPtr *force_lsn)
{
bool found;
bool latest;
NeonGetPageRequest request = {
.req.tag = T_NeonGetPageRequest,
.req.latest = false,
.req.horizon = 0,
.req.lsn = 0,
.rinfo = BufTagGetNRelFileInfo(slot->buftag),
.forknum = slot->buftag.forkNum,
@@ -699,13 +714,13 @@ prefetch_do_request(PrefetchRequest *slot, bool *force_latest, XLogRecPtr *force
if (force_lsn && force_latest)
{
request.req.lsn = *force_lsn;
request.req.latest = *force_latest;
latest = *force_latest;
slot->actual_request_lsn = slot->effective_request_lsn = *force_lsn;
}
else
{
XLogRecPtr lsn = neon_get_request_lsn(
&request.req.latest,
&latest,
BufTagGetNRelFileInfo(slot->buftag),
slot->buftag.forkNum,
slot->buftag.blockNum
@@ -733,6 +748,7 @@ prefetch_do_request(PrefetchRequest *slot, bool *force_latest, XLogRecPtr *force
prefetch_lsn = Max(prefetch_lsn, lsn);
slot->effective_request_lsn = prefetch_lsn;
}
request.req.horizon = neon_get_horizon(latest);
Assert(slot->response == NULL);
Assert(slot->my_ring_index == MyPState->ring_unused);
@@ -997,7 +1013,19 @@ nm_pack_request(NeonRequest *msg)
StringInfoData s;
initStringInfo(&s);
pq_sendbyte(&s, msg->tag);
if (neon_protocol_version >= 2)
{
pq_sendbyte(&s, msg->tag);
pq_sendint64(&s, msg->horizon);
}
else
{
/* Old protocol with latest flag */
pq_sendbyte(&s, msg->tag - T_NeonExistsRequest); /* old protocol command tags start from zero */
pq_sendbyte(&s, msg->horizon == MAX_LSN);
}
pq_sendint64(&s, msg->lsn);
switch (messageTag(msg))
{
@@ -1006,8 +1034,6 @@ nm_pack_request(NeonRequest *msg)
{
NeonExistsRequest *msg_req = (NeonExistsRequest *) msg;
pq_sendbyte(&s, msg_req->req.latest);
pq_sendint64(&s, msg_req->req.lsn);
pq_sendint32(&s, NInfoGetSpcOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetDbOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetRelNumber(msg_req->rinfo));
@@ -1019,8 +1045,6 @@ nm_pack_request(NeonRequest *msg)
{
NeonNblocksRequest *msg_req = (NeonNblocksRequest *) msg;
pq_sendbyte(&s, msg_req->req.latest);
pq_sendint64(&s, msg_req->req.lsn);
pq_sendint32(&s, NInfoGetSpcOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetDbOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetRelNumber(msg_req->rinfo));
@@ -1032,8 +1056,6 @@ nm_pack_request(NeonRequest *msg)
{
NeonDbSizeRequest *msg_req = (NeonDbSizeRequest *) msg;
pq_sendbyte(&s, msg_req->req.latest);
pq_sendint64(&s, msg_req->req.lsn);
pq_sendint32(&s, msg_req->dbNode);
break;
@@ -1042,8 +1064,6 @@ nm_pack_request(NeonRequest *msg)
{
NeonGetPageRequest *msg_req = (NeonGetPageRequest *) msg;
pq_sendbyte(&s, msg_req->req.latest);
pq_sendint64(&s, msg_req->req.lsn);
pq_sendint32(&s, NInfoGetSpcOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetDbOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetRelNumber(msg_req->rinfo));
@@ -1057,8 +1077,6 @@ nm_pack_request(NeonRequest *msg)
{
NeonGetSlruSegmentRequest *msg_req = (NeonGetSlruSegmentRequest *) msg;
pq_sendbyte(&s, msg_req->req.latest);
pq_sendint64(&s, msg_req->req.lsn);
pq_sendbyte(&s, msg_req->kind);
pq_sendint32(&s, msg_req->segno);
@@ -1209,7 +1227,7 @@ nm_to_string(NeonMessage *msg)
appendStringInfo(&s, ", \"rinfo\": \"%u/%u/%u\"", RelFileInfoFmt(msg_req->rinfo));
appendStringInfo(&s, ", \"forknum\": %d", msg_req->forknum);
appendStringInfo(&s, ", \"lsn\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.lsn));
appendStringInfo(&s, ", \"latest\": %d", msg_req->req.latest);
appendStringInfo(&s, ", \"horizon\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.horizon));
appendStringInfoChar(&s, '}');
break;
}
@@ -1222,7 +1240,7 @@ nm_to_string(NeonMessage *msg)
appendStringInfo(&s, ", \"rinfo\": \"%u/%u/%u\"", RelFileInfoFmt(msg_req->rinfo));
appendStringInfo(&s, ", \"forknum\": %d", msg_req->forknum);
appendStringInfo(&s, ", \"lsn\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.lsn));
appendStringInfo(&s, ", \"latest\": %d", msg_req->req.latest);
appendStringInfo(&s, ", \"horizon\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.horizon));
appendStringInfoChar(&s, '}');
break;
}
@@ -1236,7 +1254,7 @@ nm_to_string(NeonMessage *msg)
appendStringInfo(&s, ", \"forknum\": %d", msg_req->forknum);
appendStringInfo(&s, ", \"blkno\": %u", msg_req->blkno);
appendStringInfo(&s, ", \"lsn\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.lsn));
appendStringInfo(&s, ", \"latest\": %d", msg_req->req.latest);
appendStringInfo(&s, ", \"horizon\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.horizon));
appendStringInfoChar(&s, '}');
break;
}
@@ -1247,7 +1265,7 @@ nm_to_string(NeonMessage *msg)
appendStringInfoString(&s, "{\"type\": \"NeonDbSizeRequest\"");
appendStringInfo(&s, ", \"dbnode\": \"%u\"", msg_req->dbNode);
appendStringInfo(&s, ", \"lsn\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.lsn));
appendStringInfo(&s, ", \"latest\": %d", msg_req->req.latest);
appendStringInfo(&s, ", \"horizon\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.horizon));
appendStringInfoChar(&s, '}');
break;
}
@@ -1259,7 +1277,7 @@ nm_to_string(NeonMessage *msg)
appendStringInfo(&s, ", \"kind\": %u", msg_req->kind);
appendStringInfo(&s, ", \"segno\": %u", msg_req->segno);
appendStringInfo(&s, ", \"lsn\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.lsn));
appendStringInfo(&s, ", \"latest\": %d", msg_req->req.latest);
appendStringInfo(&s, ", \"horizon\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.horizon));
appendStringInfoChar(&s, '}');
break;
}
@@ -1664,7 +1682,7 @@ neon_exists(SMgrRelation reln, ForkNumber forkNum)
{
NeonExistsRequest request = {
.req.tag = T_NeonExistsRequest,
.req.latest = latest,
.req.horizon = neon_get_horizon(latest),
.req.lsn = request_lsn,
.rinfo = InfoFromSMgrRel(reln),
.forknum = forkNum};
@@ -2474,7 +2492,7 @@ neon_nblocks(SMgrRelation reln, ForkNumber forknum)
{
NeonNblocksRequest request = {
.req.tag = T_NeonNblocksRequest,
.req.latest = latest,
.req.horizon = neon_get_horizon(latest),
.req.lsn = request_lsn,
.rinfo = InfoFromSMgrRel(reln),
.forknum = forknum,
@@ -2531,7 +2549,7 @@ neon_dbsize(Oid dbNode)
{
NeonDbSizeRequest request = {
.req.tag = T_NeonDbSizeRequest,
.req.latest = latest,
.req.horizon = neon_get_horizon(latest),
.req.lsn = request_lsn,
.dbNode = dbNode,
};
@@ -2827,7 +2845,7 @@ neon_read_slru_segment(SMgrRelation reln, const char* path, int segno, void* buf
NeonResponse *resp;
NeonGetSlruSegmentRequest request = {
.req.tag = T_NeonGetSlruSegmentRequest,
.req.latest = false,
.req.horizon = InvalidXLogRecPtr,
.req.lsn = request_lsn,
.kind = kind,
@@ -2980,7 +2998,7 @@ neon_extend_rel_size(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
NeonNblocksRequest request = {
.req = (NeonRequest) {
.lsn = end_recptr,
.latest = false,
.horizon = neon_get_horizon(false),
.tag = T_NeonNblocksRequest,
},
.rinfo = rinfo,

19
poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]]
name = "aiohttp"
@@ -1191,13 +1191,13 @@ files = [
[[package]]
name = "idna"
version = "3.7"
version = "3.3"
description = "Internationalized Domain Names in Applications (IDNA)"
optional = false
python-versions = ">=3.5"
files = [
{file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"},
{file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"},
{file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"},
{file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"},
]
[[package]]
@@ -2182,7 +2182,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -2653,16 +2652,6 @@ files = [
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"},
{file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"},
{file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"},
{file = "wrapt-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55"},
{file = "wrapt-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be"},
{file = "wrapt-1.14.1-cp311-cp311-win32.whl", hash = "sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204"},
{file = "wrapt-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"},

View File

@@ -44,7 +44,6 @@ ipnet.workspace = true
itertools.workspace = true
lasso = { workspace = true, features = ["multi-threaded"] }
md5.workspace = true
measured = { workspace = true, features = ["lasso"] }
metrics.workspace = true
once_cell.workspace = true
opentelemetry.workspace = true

View File

@@ -2,15 +2,8 @@ mod classic;
mod hacks;
mod link;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use ipnet::{Ipv4Net, Ipv6Net};
pub use link::LinkAuthError;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::AuthKeys;
use tracing::{info, warn};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::validate_password_and_exchange;
@@ -20,10 +13,9 @@ use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
use crate::console::{AuthSecret, NodeInfo};
use crate::context::RequestMonitoring;
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::metrics::{AUTH_RATE_LIMIT_HITS, ENDPOINTS_AUTH_RATE_LIMITED};
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, RateBucketInfo};
use crate::stream::Stream;
use crate::{
auth::{self, ComputeUserInfoMaybeEndpoint},
@@ -35,7 +27,10 @@ use crate::{
},
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, Normalize, RoleName};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
@@ -181,51 +176,17 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
}
}
#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
pub struct MaskedIp(IpAddr);
impl MaskedIp {
fn new(value: IpAddr, prefix: u8) -> Self {
match value {
IpAddr::V4(v4) => Self(IpAddr::V4(
Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
)),
IpAddr::V6(v6) => Self(IpAddr::V6(
Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
)),
}
}
}
// This can't be just per IP because that would limit some PaaS that share IP addresses
pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
impl RateBucketInfo {
/// All of these are per endpoint-maskedip pair.
/// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
///
/// First bucket: 1000mcpus total per endpoint-ip pair
/// * 4096000 requests per second with 1 hash rounds.
/// * 1000 requests per second with 4096 hash rounds.
/// * 6.8 requests per second with 600000 hash rounds.
pub const DEFAULT_AUTH_SET: [Self; 3] = [
Self::new(1000 * 4096, Duration::from_secs(1)),
Self::new(600 * 4096, Duration::from_secs(60)),
Self::new(300 * 4096, Duration::from_secs(600)),
];
}
impl AuthenticationConfig {
pub fn check_rate_limit(
&self,
ctx: &mut RequestMonitoring,
config: &AuthenticationConfig,
secret: AuthSecret,
endpoint: &EndpointId,
is_cleartext: bool,
) -> auth::Result<AuthSecret> {
// we have validated the endpoint exists, so let's intern it.
let endpoint_int = EndpointIdInt::from(endpoint.normalize());
let endpoint_int = EndpointIdInt::from(endpoint);
// only count the full hash count if password hack or websocket flow.
// in other words, if proxy needs to run the hashing
@@ -240,25 +201,17 @@ impl AuthenticationConfig {
1
};
let limit_not_exceeded = self.rate_limiter.check(
(
endpoint_int,
MaskedIp::new(ctx.peer_addr, config.rate_limit_ip_subnet),
),
password_weight,
);
let limit_not_exceeded = self
.rate_limiter
.check((endpoint_int, ctx.peer_addr), password_weight);
if !limit_not_exceeded {
warn!(
enabled = self.rate_limiter_enabled,
"rate limiting authentication"
);
Metrics::get().proxy.requests_auth_rate_limits_total.inc();
Metrics::get()
.proxy
.endpoints_auth_rate_limits
.get_metric()
.measure(endpoint);
AUTH_RATE_LIMIT_HITS.inc();
ENDPOINTS_AUTH_RATE_LIMITED.measure(endpoint);
if self.rate_limiter_enabled {
return Err(auth::AuthError::too_many_connections());
@@ -314,7 +267,6 @@ async fn auth_quirks(
let secret = match secret {
Some(secret) => config.check_rate_limit(
ctx,
config,
secret,
&info.endpoint,
unauthenticated_password.is_some() || allow_cleartext,
@@ -517,7 +469,7 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
#[cfg(test)]
mod tests {
use std::{net::IpAddr, sync::Arc, time::Duration};
use std::sync::Arc;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
@@ -530,7 +482,7 @@ mod tests {
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use crate::{
auth::{backend::MaskedIp, ComputeUserInfoMaybeEndpoint, IpPattern},
auth::{ComputeUserInfoMaybeEndpoint, IpPattern},
config::AuthenticationConfig,
console::{
self,
@@ -539,12 +491,12 @@ mod tests {
},
context::RequestMonitoring,
proxy::NeonOptions,
rate_limiter::RateBucketInfo,
rate_limiter::{AuthRateLimiter, RateBucketInfo},
scram::ServerSecret,
stream::{PqStream, Stream},
};
use super::{auth_quirks, AuthRateLimiter};
use super::auth_quirks;
struct Auth {
ips: Vec<IpPattern>,
@@ -585,7 +537,6 @@ mod tests {
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64,
});
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
@@ -597,51 +548,6 @@ mod tests {
}
}
#[test]
fn masked_ip() {
let ip_a = IpAddr::V4([127, 0, 0, 1].into());
let ip_b = IpAddr::V4([127, 0, 0, 2].into());
let ip_c = IpAddr::V4([192, 168, 1, 101].into());
let ip_d = IpAddr::V4([192, 168, 1, 102].into());
let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());
assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));
assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
}
#[test]
fn test_default_auth_rate_limit_set() {
// these values used to exceed u32::MAX
assert_eq!(
RateBucketInfo::DEFAULT_AUTH_SET,
[
RateBucketInfo {
interval: Duration::from_secs(1),
max_rpi: 1000 * 4096,
},
RateBucketInfo {
interval: Duration::from_secs(60),
max_rpi: 600 * 4096 * 60,
},
RateBucketInfo {
interval: Duration::from_secs(600),
max_rpi: 300 * 4096 * 600,
}
]
);
for x in RateBucketInfo::DEFAULT_AUTH_SET {
let y = x.to_string().parse().unwrap();
assert_eq!(x, y);
}
}
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);

View File

@@ -4,7 +4,7 @@ use crate::{
auth::password_hack::parse_endpoint_param,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::{Metrics, SniKind},
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI,
proxy::NeonOptions,
serverless::SERVERLESS_DRIVER_SNI,
EndpointId, RoleName,
@@ -144,22 +144,21 @@ impl ComputeUserInfoMaybeEndpoint {
ctx.set_endpoint_id(ep.clone());
}
let metrics = Metrics::get();
info!(%user, "credentials");
if sni.is_some() {
info!("Connection with sni");
metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni);
NUM_CONNECTION_ACCEPTED_BY_SNI
.with_label_values(&["sni"])
.inc();
} else if endpoint.is_some() {
metrics
.proxy
.accepted_connections_by_sni
.inc(SniKind::NoSni);
NUM_CONNECTION_ACCEPTED_BY_SNI
.with_label_values(&["no_sni"])
.inc();
info!("Connection without sni");
} else {
metrics
.proxy
.accepted_connections_by_sni
.inc(SniKind::PasswordHack);
NUM_CONNECTION_ACCEPTED_BY_SNI
.with_label_values(&["password_hack"])
.inc();
info!("Connection with password hack");
}

View File

@@ -9,13 +9,15 @@ use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled};
use proxy::proxy::run_until_cancelled;
use proxy::{BranchId, EndpointId, ProjectId};
use rustls::pki_types::PrivateKeyDer;
use tokio::net::TcpListener;
use anyhow::{anyhow, bail, ensure, Context};
use clap::Arg;
use futures::TryFutureExt;
use proxy::console::messages::MetricsAuxInfo;
use proxy::stream::{PqStream, Stream};
use tokio::io::{AsyncRead, AsyncWrite};
@@ -174,12 +176,7 @@ async fn task_main(
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
let ctx = RequestMonitoring::new(
session_id,
peer_addr.ip(),
proxy::metrics::Protocol::SniRouter,
"sni",
);
let ctx = RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni");
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
}
.unwrap_or_else(|e| {
@@ -202,7 +199,6 @@ async fn task_main(
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &mut RequestMonitoring,
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
@@ -232,10 +228,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
Ok(Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls: Box::new(raw.upgrade(tls_config).await?),
tls_server_end_point,
})
}
@@ -258,7 +251,7 @@ async fn handle_client(
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
let mut tls_stream = ssl_handshake(&mut ctx, stream, tls_config, tls_server_end_point).await?;
let tls_stream = ssl_handshake(stream, tls_config, tls_server_end_point).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of
@@ -275,15 +268,18 @@ async fn handle_client(
info!("destination: {}", destination);
let mut client = tokio::net::TcpStream::connect(destination).await?;
let client = tokio::net::TcpStream::connect(destination).await?;
let metrics_aux: MetricsAuxInfo = MetricsAuxInfo {
endpoint_id: (&EndpointId::from("")).into(),
project_id: (&ProjectId::from("")).into(),
branch_id: (&BranchId::from("")).into(),
cold_start_info: proxy::console::messages::ColdStartInfo::Unknown,
};
// doesn't yet matter as pg-sni-router doesn't report analytics logs
ctx.set_success();
ctx.log();
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = copy_bidirectional_client_compute(&mut tls_stream, &mut client).await?;
Ok(())
proxy::proxy::passthrough::proxy_pass(tls_stream, client, metrics_aux).await
}

View File

@@ -7,7 +7,6 @@ use aws_config::provider_config::ProviderConfig;
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
@@ -19,10 +18,11 @@ use proxy::config::ProjectInfoCacheOptions;
use proxy::console;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
use proxy::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT;
use proxy::rate_limiter::AuthRateLimiter;
use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::RateLimiterConfig;
use proxy::redis::cancellation_publisher::RedisPublisherClient;
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use proxy::redis::elasticache;
@@ -42,7 +42,6 @@ use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::warn;
use tracing::Instrument;
use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};
project_git_version!(GIT_VERSION);
@@ -132,8 +131,14 @@ struct ProxyCliArgs {
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
require_client_ip: bool,
/// Disable dynamic rate limiter and store the metrics to ensure its production behaviour.
#[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_dynamic_rate_limiter: bool,
/// Rate limit algorithm. Makes sense only if `disable_rate_limiter` is `false`.
#[clap(value_enum, long, default_value_t = proxy::rate_limiter::RateLimitAlgorithm::Aimd)]
rate_limit_algorithm: proxy::rate_limiter::RateLimitAlgorithm,
/// Timeout for rate limiter. If it didn't manage to aquire a permit in this time, it will return an error.
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
rate_limiter_timeout: tokio::time::Duration,
/// Endpoint rate limiter max number of requests per second.
///
/// Provided in the form '<Requests Per Second>@<Bucket Duration Size>'.
@@ -146,12 +151,14 @@ struct ProxyCliArgs {
/// Authentication rate limiter max number of hashes per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
auth_rate_limit: Vec<RateBucketInfo>,
/// The IP subnet to use when considering whether two IP addresses are considered the same.
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
#[clap(long, default_value_t = 100)]
initial_limit: usize,
#[clap(flatten)]
aimd_config: proxy::rate_limiter::AimdConfig,
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
allowed_ips_cache: String,
@@ -182,9 +189,7 @@ struct ProxyCliArgs {
/// cache for `project_info` (use `size=0` to disable)
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
project_info_cache: String,
/// cache for all valid endpoints
#[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
endpoint_cache_config: String,
#[clap(flatten)]
parquet_upload: ParquetUploadArgs,
@@ -244,18 +249,14 @@ async fn main() -> anyhow::Result<()> {
info!("Version: {GIT_VERSION}");
info!("Build_tag: {BUILD_TAG}");
let neon_metrics = ::metrics::NeonMetrics::new(::metrics::BuildInfo {
revision: GIT_VERSION,
build_tag: BUILD_TAG,
});
::metrics::set_build_info_metric(GIT_VERSION, BUILD_TAG);
let jemalloc = match proxy::jemalloc::MetricRecorder::new() {
Ok(t) => Some(t),
Err(e) => {
tracing::error!(error = ?e, "could not start jemalloc metrics loop");
None
match proxy::jemalloc::MetricRecorder::new(prometheus::default_registry()) {
Ok(t) => {
t.start();
}
};
Err(e) => tracing::error!(error = ?e, "could not start jemalloc metrics loop"),
}
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
@@ -295,27 +296,27 @@ async fn main() -> anyhow::Result<()> {
),
aws_credentials_provider,
));
let regional_redis_client = match (args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host,
port,
elasticache_credentials_provider.clone(),
let redis_notifications_client =
match (args.redis_notifications, (args.redis_host, args.redis_port)) {
(Some(url), _) => {
info!("Starting redis notifications listener ({url})");
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url))
}
(None, (Some(host), Some(port))) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host,
port,
elasticache_credentials_provider.clone(),
),
),
),
(None, None) => {
warn!("Redis events from console are disabled");
None
}
_ => {
bail!("redis-host and redis-port must be specified together");
}
};
let redis_notifications_client = if let Some(url) = args.redis_notifications {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url))
} else {
regional_redis_client.clone()
};
(None, (None, None)) => {
warn!("Redis is disabled");
None
}
_ => {
bail!("redis-host and redis-port must be specified together");
}
};
// Check that we can bind to address before further initialization
let http_address: SocketAddr = args.http.parse()?;
@@ -334,7 +335,8 @@ async fn main() -> anyhow::Result<()> {
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
let cancel_map = CancelMap::default();
let redis_publisher = match &regional_redis_client {
// let redis_notifications_client = redis_notifications_client.map(|x| Box::leak(Box::new(x)));
let redis_publisher = match &redis_notifications_client {
Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
redis_publisher.clone(),
args.region.clone(),
@@ -347,7 +349,7 @@ async fn main() -> anyhow::Result<()> {
>::new(
cancel_map.clone(),
redis_publisher,
proxy::metrics::CancellationSource::FromClient,
NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT,
));
// client facing tasks. these will exit on error or on cancellation
@@ -385,14 +387,7 @@ async fn main() -> anyhow::Result<()> {
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone()));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
jemalloc,
neon_metrics,
proxy: proxy::metrics::Metrics::get(),
},
));
maintenance_tasks.spawn(http::health_server::task_main(http_listener));
maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
if let Some(metrics_config) = &config.metric_collection {
@@ -409,19 +404,13 @@ async fn main() -> anyhow::Result<()> {
if let Some(redis_notifications_client) = redis_notifications_client {
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(
redis_notifications_client,
redis_notifications_client.clone(),
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
if let Some(regional_redis_client) = regional_redis_client {
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(async move { cache.do_read(con).await }.instrument(span));
}
}
}
@@ -487,27 +476,27 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
and metric-collection-interval must be specified"
),
};
if !args.disable_dynamic_rate_limiter {
bail!("dynamic rate limiter should be disabled");
}
let rate_limiter_config = RateLimiterConfig {
disable: args.disable_dynamic_rate_limiter,
algorithm: args.rate_limit_algorithm,
timeout: args.rate_limiter_timeout,
initial_limit: args.initial_limit,
aimd_config: Some(args.aimd_config),
};
let auth_backend = match &args.auth_backend {
AuthBackend::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(console::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::WakeComputeLockOptions {
@@ -518,20 +507,13 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
} = args.wake_compute_lock.parse()?;
info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(
console::locks::ApiLocks::new(
"wake_compute_lock",
permits,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)
.unwrap(),
console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout)
.unwrap(),
));
tokio::spawn(locks.garbage_collect_worker());
tokio::spawn(locks.garbage_collect_worker(epoch));
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());
let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));
let api = console::provider::neon::Api::new(endpoint, caches, locks);
let api = console::provider::ConsoleBackend::Console(api);
@@ -564,7 +546,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
};
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();

View File

@@ -1,5 +1,4 @@
pub mod common;
pub mod endpoints;
pub mod project_info;
mod timed_lru;

View File

@@ -1,233 +0,0 @@
use std::{
convert::Infallible,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use dashmap::DashSet;
use redis::{
streams::{StreamReadOptions, StreamReadReply},
AsyncCommands, FromRedisValue, Value,
};
use serde::Deserialize;
use tokio::sync::Mutex;
use tracing::info;
use crate::{
config::EndpointCacheConfig,
context::RequestMonitoring,
intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
metrics::{Metrics, RedisErrors},
rate_limiter::GlobalRateLimiter,
redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
EndpointId,
};
#[derive(Deserialize, Debug, Clone)]
pub struct ControlPlaneEventKey {
endpoint_created: Option<EndpointCreated>,
branch_created: Option<BranchCreated>,
project_created: Option<ProjectCreated>,
}
#[derive(Deserialize, Debug, Clone)]
struct EndpointCreated {
endpoint_id: String,
}
#[derive(Deserialize, Debug, Clone)]
struct BranchCreated {
branch_id: String,
}
#[derive(Deserialize, Debug, Clone)]
struct ProjectCreated {
project_id: String,
}
pub struct EndpointsCache {
config: EndpointCacheConfig,
endpoints: DashSet<EndpointIdInt>,
branches: DashSet<BranchIdInt>,
projects: DashSet<ProjectIdInt>,
ready: AtomicBool,
limiter: Arc<Mutex<GlobalRateLimiter>>,
}
impl EndpointsCache {
pub fn new(config: EndpointCacheConfig) -> Self {
Self {
limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
config.limiter_info.clone(),
))),
config,
endpoints: DashSet::new(),
branches: DashSet::new(),
projects: DashSet::new(),
ready: AtomicBool::new(false),
}
}
pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool {
if !self.ready.load(Ordering::Acquire) {
return true;
}
// If cache is disabled, just collect the metrics and return.
if self.config.disable_cache {
let rejected = self.should_reject(endpoint);
ctx.set_rejected(rejected);
info!(?rejected, "check endpoint is valid, disabled cache");
return true;
}
// If the limiter allows, we don't need to check the cache.
if self.limiter.lock().await.check() {
return true;
}
let rejected = self.should_reject(endpoint);
info!(?rejected, "check endpoint is valid, enabled cache");
ctx.set_rejected(rejected);
!rejected
}
fn should_reject(&self, endpoint: &EndpointId) -> bool {
if endpoint.is_endpoint() {
!self.endpoints.contains(&EndpointIdInt::from(endpoint))
} else if endpoint.is_branch() {
!self
.branches
.contains(&BranchIdInt::from(&endpoint.as_branch()))
} else {
!self
.projects
.contains(&ProjectIdInt::from(&endpoint.as_project()))
}
}
fn insert_event(&self, key: ControlPlaneEventKey) {
// Do not do normalization here, we expect the events to be normalized.
if let Some(endpoint_created) = key.endpoint_created {
self.endpoints
.insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into()));
}
if let Some(branch_created) = key.branch_created {
self.branches
.insert(BranchIdInt::from(&branch_created.branch_id.into()));
}
if let Some(project_created) = key.project_created {
self.projects
.insert(ProjectIdInt::from(&project_created.project_id.into()));
}
}
pub async fn do_read(
&self,
mut con: ConnectionWithCredentialsProvider,
) -> anyhow::Result<Infallible> {
let mut last_id = "0-0".to_string();
loop {
self.ready.store(false, Ordering::Release);
if let Err(e) = con.connect().await {
tracing::error!("error connecting to redis: {:?}", e);
continue;
}
if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
tracing::error!("error reading from redis: {:?}", e);
}
tokio::time::sleep(self.config.retry_interval).await;
}
}
async fn read_from_stream(
&self,
con: &mut ConnectionWithCredentialsProvider,
last_id: &mut String,
) -> anyhow::Result<()> {
tracing::info!("reading endpoints/branches/projects from redis");
self.batch_read(
con,
StreamReadOptions::default().count(self.config.initial_batch_size),
last_id,
true,
)
.await?;
tracing::info!("ready to filter user requests");
self.ready.store(true, Ordering::Release);
self.batch_read(
con,
StreamReadOptions::default()
.count(self.config.default_batch_size)
.block(self.config.xread_timeout.as_millis() as usize),
last_id,
false,
)
.await
}
fn parse_key_value(value: &Value) -> anyhow::Result<ControlPlaneEventKey> {
let s: String = FromRedisValue::from_redis_value(value)?;
Ok(serde_json::from_str(&s)?)
}
async fn batch_read(
&self,
conn: &mut ConnectionWithCredentialsProvider,
opts: StreamReadOptions,
last_id: &mut String,
return_when_finish: bool,
) -> anyhow::Result<()> {
let mut total: usize = 0;
loop {
let mut res: StreamReadReply = conn
.xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
.await?;
if res.keys.is_empty() {
if return_when_finish {
if total != 0 {
break;
}
anyhow::bail!(
"Redis stream {} is empty, cannot be used to filter endpoints",
self.config.stream_name
);
}
// If we are not returning when finish, we should wait for more data.
continue;
}
if res.keys.len() != 1 {
anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
}
let res = res.keys.pop().expect("Checked length above");
let len = res.ids.len();
for x in res.ids {
total += 1;
for (_, v) in x.map {
let key = match Self::parse_key_value(&v) {
Ok(x) => x,
Err(e) => {
Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
channel: &self.config.stream_name,
});
tracing::error!("error parsing value {v:?}: {e:?}");
continue;
}
};
self.insert_event(key);
}
if total.is_power_of_two() {
tracing::debug!("endpoints read {}", total);
}
*last_id = x.id;
}
if return_when_finish && len <= self.config.default_batch_size {
break;
}
}
tracing::info!("read {} endpoints/branches/projects from redis", total);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::ControlPlaneEventKey;
#[test]
fn test() {
let s = "{\"branch_created\":null,\"endpoint_created\":{\"endpoint_id\":\"ep-rapid-thunder-w0qqw2q9\"},\"project_created\":null,\"type\":\"endpoint_created\"}";
let _: ControlPlaneEventKey = serde_json::from_str(s).unwrap();
}
}

View File

@@ -10,7 +10,7 @@ use uuid::Uuid;
use crate::{
error::ReportableError,
metrics::{CancellationRequest, CancellationSource, Metrics},
metrics::NUM_CANCELLATION_REQUESTS,
redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
},
@@ -28,7 +28,7 @@ pub struct CancellationHandler<P> {
client: P,
/// This field used for the monitoring purposes.
/// Represents the source of the cancellation request.
from: CancellationSource,
from: &'static str,
}
#[derive(Debug, Error)]
@@ -89,13 +89,9 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
Metrics::get()
.proxy
.cancellation_requests_total
.inc(CancellationRequest {
source: self.from,
kind: crate::metrics::CancellationOutcome::NotFound,
});
NUM_CANCELLATION_REQUESTS
.with_label_values(&[self.from, "not_found"])
.inc();
match self.client.try_publish(key, session_id).await {
Ok(()) => {} // do nothing
Err(e) => {
@@ -107,13 +103,9 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}
return Ok(());
};
Metrics::get()
.proxy
.cancellation_requests_total
.inc(CancellationRequest {
source: self.from,
kind: crate::metrics::CancellationOutcome::Found,
});
NUM_CANCELLATION_REQUESTS
.with_label_values(&[self.from, "found"])
.inc();
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
}
@@ -130,7 +122,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}
impl CancellationHandler<()> {
pub fn new(map: CancelMap, from: CancellationSource) -> Self {
pub fn new(map: CancelMap, from: &'static str) -> Self {
Self {
map,
client: (),
@@ -140,7 +132,7 @@ impl CancellationHandler<()> {
}
impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: &'static str) -> Self {
Self { map, client, from }
}
}
@@ -200,13 +192,15 @@ impl<P> Drop for Session<P> {
#[cfg(test)]
mod tests {
use crate::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS;
use super::*;
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
CancelMap::default(),
CancellationSource::FromRedis,
NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS,
));
let session = cancellation_handler.clone().get_session();
@@ -220,7 +214,7 @@ mod tests {
#[tokio::test]
async fn cancel_session_noop_regression() {
let handler = CancellationHandler::<()>::new(Default::default(), CancellationSource::Local);
let handler = CancellationHandler::<()>::new(Default::default(), "local");
handler
.cancel_session(
CancelKeyData {

View File

@@ -4,11 +4,12 @@ use crate::{
console::{errors::WakeComputeError, messages::MetricsAuxInfo},
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::{Metrics, NumDbConnectionsGuard},
metrics::NUM_DB_CONNECTIONS_GAUGE,
proxy::neon_option,
};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use metrics::IntCounterPairGuard;
use pq_proto::StartupMessageParams;
use std::{io, net::SocketAddr, time::Duration};
use thiserror::Error;
@@ -248,7 +249,7 @@ pub struct PostgresConnection {
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
_guage: NumDbConnectionsGuard<'static>,
_guage: IntCounterPairGuard,
}
impl ConnCfg {
@@ -294,7 +295,9 @@ impl ConnCfg {
params,
cancel_closure,
aux,
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol),
_guage: NUM_DB_CONNECTIONS_GAUGE
.with_label_values(&[ctx.protocol])
.guard(),
};
Ok(connection)

View File

@@ -1,6 +1,6 @@
use crate::{
auth::{self, backend::AuthRateLimiter},
rate_limiter::RateBucketInfo,
auth,
rate_limiter::{AuthRateLimiter, RateBucketInfo},
serverless::GlobalConnPoolOptions,
};
use anyhow::{bail, ensure, Context, Ok};
@@ -58,7 +58,6 @@ pub struct AuthenticationConfig {
pub scram_protocol_timeout: tokio::time::Duration,
pub rate_limiter_enabled: bool,
pub rate_limiter: AuthRateLimiter,
pub rate_limit_ip_subnet: u8,
}
impl TlsConfig {
@@ -314,80 +313,6 @@ impl CertResolver {
}
}
#[derive(Debug)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.
pub initial_batch_size: usize,
/// Batch size to receive endpoints.
pub default_batch_size: usize,
/// Timeouts for the stream read operation.
pub xread_timeout: Duration,
/// Stream name to read from.
pub stream_name: String,
/// Limiter info (to distinguish when to enable cache).
pub limiter_info: Vec<RateBucketInfo>,
/// Disable cache.
/// If true, cache is ignored, but reports all statistics.
pub disable_cache: bool,
/// Retry interval for the stream read operation.
pub retry_interval: Duration,
}
impl EndpointCacheConfig {
/// Default options for [`crate::console::provider::NodeInfoCache`].
/// Notice that by default the limiter is empty, which means that cache is disabled.
pub const CACHE_DEFAULT_OPTIONS: &'static str =
"initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
/// Parse cache options passed via cmdline.
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
fn parse(options: &str) -> anyhow::Result<Self> {
let mut initial_batch_size = None;
let mut default_batch_size = None;
let mut xread_timeout = None;
let mut stream_name = None;
let mut limiter_info = vec![];
let mut disable_cache = false;
let mut retry_interval = None;
for option in options.split(',') {
let (key, value) = option
.split_once('=')
.with_context(|| format!("bad key-value pair: {option}"))?;
match key {
"initial_batch_size" => initial_batch_size = Some(value.parse()?),
"default_batch_size" => default_batch_size = Some(value.parse()?),
"xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
"stream_name" => stream_name = Some(value.to_string()),
"limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
"disable_cache" => disable_cache = value.parse()?,
"retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
unknown => bail!("unknown key: {unknown}"),
}
}
RateBucketInfo::validate(&mut limiter_info)?;
Ok(Self {
initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
stream_name: stream_name.context("missing `stream_name`")?,
disable_cache,
limiter_info,
retry_interval: retry_interval.context("missing `retry_interval`")?,
})
}
}
impl FromStr for EndpointCacheConfig {
type Err = anyhow::Error;
fn from_str(options: &str) -> Result<Self, Self::Err> {
let error = || format!("failed to parse endpoint cache options '{options}'");
Self::parse(options).with_context(error)
}
}
#[derive(Debug)]
pub struct MetricBackupCollectionConfig {
pub interval: Duration,

View File

@@ -1,4 +1,3 @@
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use std::fmt;
@@ -103,7 +102,7 @@ pub struct MetricsAuxInfo {
pub cold_start_info: ColdStartInfo,
}
#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)]
#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum ColdStartInfo {
#[default]
@@ -111,11 +110,9 @@ pub enum ColdStartInfo {
/// Compute was already running
Warm,
#[serde(rename = "pool_hit")]
#[label(rename = "pool_hit")]
/// Compute was not running but there was an available VM
VmPoolHit,
#[serde(rename = "pool_miss")]
#[label(rename = "pool_miss")]
/// Compute was not running and there were no VMs available
VmPoolMiss,

View File

@@ -8,12 +8,11 @@ use crate::{
backend::{ComputeCredentialKeys, ComputeUserInfo},
IpPattern,
},
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
compute,
config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
config::{CacheOptions, ProjectInfoCacheOptions},
context::RequestMonitoring,
intern::ProjectIdInt,
metrics::ApiLockMetrics,
scram, EndpointCacheKey,
};
use dashmap::DashMap;
@@ -417,15 +416,12 @@ pub struct ApiCaches {
pub node_info: NodeInfoCache,
/// Cache which stores project_id -> endpoint_ids mapping.
pub project_info: Arc<ProjectInfoCacheImpl>,
/// List of all valid endpoints.
pub endpoints_cache: Arc<EndpointsCache>,
}
impl ApiCaches {
pub fn new(
wake_compute_cache_config: CacheOptions,
project_info_cache_config: ProjectInfoCacheOptions,
endpoint_cache_config: EndpointCacheConfig,
) -> Self {
Self {
node_info: NodeInfoCache::new(
@@ -435,7 +431,6 @@ impl ApiCaches {
true,
),
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
}
}
}
@@ -446,8 +441,10 @@ pub struct ApiLocks {
node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
permits: usize,
timeout: Duration,
epoch: std::time::Duration,
metrics: &'static ApiLockMetrics,
registered: prometheus::IntCounter,
unregistered: prometheus::IntCounter,
reclamation_lag: prometheus::Histogram,
lock_acquire_lag: prometheus::Histogram,
}
impl ApiLocks {
@@ -456,16 +453,54 @@ impl ApiLocks {
permits: usize,
shards: usize,
timeout: Duration,
epoch: std::time::Duration,
metrics: &'static ApiLockMetrics,
) -> prometheus::Result<Self> {
let registered = prometheus::IntCounter::with_opts(
prometheus::Opts::new(
"semaphores_registered",
"Number of semaphores registered in this api lock",
)
.namespace(name),
)?;
prometheus::register(Box::new(registered.clone()))?;
let unregistered = prometheus::IntCounter::with_opts(
prometheus::Opts::new(
"semaphores_unregistered",
"Number of semaphores unregistered in this api lock",
)
.namespace(name),
)?;
prometheus::register(Box::new(unregistered.clone()))?;
let reclamation_lag = prometheus::Histogram::with_opts(
prometheus::HistogramOpts::new(
"reclamation_lag_seconds",
"Time it takes to reclaim unused semaphores in the api lock",
)
.namespace(name)
// 1us -> 65ms
// benchmarks on my mac indicate it's usually in the range of 256us and 512us
.buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
)?;
prometheus::register(Box::new(reclamation_lag.clone()))?;
let lock_acquire_lag = prometheus::Histogram::with_opts(
prometheus::HistogramOpts::new(
"semaphore_acquire_seconds",
"Time it takes to reclaim unused semaphores in the api lock",
)
.namespace(name)
// 0.1ms -> 6s
.buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
)?;
prometheus::register(Box::new(lock_acquire_lag.clone()))?;
Ok(Self {
name,
node_locks: DashMap::with_shard_amount(shards),
permits,
timeout,
epoch,
metrics,
lock_acquire_lag,
registered,
unregistered,
reclamation_lag,
})
}
@@ -485,7 +520,7 @@ impl ApiLocks {
self.node_locks
.entry(key.clone())
.or_insert_with(|| {
self.metrics.semaphores_registered.inc();
self.registered.inc();
Arc::new(Semaphore::new(self.permits))
})
.clone()
@@ -493,21 +528,20 @@ impl ApiLocks {
};
let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
self.metrics
.semaphore_acquire_seconds
.observe(now.elapsed().as_secs_f64());
self.lock_acquire_lag
.observe((Instant::now() - now).as_secs_f64());
Ok(WakeComputePermit {
permit: Some(permit??),
})
}
pub async fn garbage_collect_worker(&self) {
pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
if self.permits == 0 {
return;
}
let mut interval =
tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
loop {
for (i, shard) in self.node_locks.shards().iter().enumerate() {
interval.tick().await;
@@ -520,13 +554,13 @@ impl ApiLocks {
"performing epoch reclamation on api lock"
);
let mut lock = shard.write();
let timer = self.metrics.reclamation_lag_seconds.start_timer();
let timer = self.reclamation_lag.start_timer();
let count = lock
.extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
.count();
drop(lock);
self.metrics.semaphores_unregistered.inc_by(count as u64);
timer.observe();
self.unregistered.inc_by(count as u64);
timer.observe_duration()
}
}
}

View File

@@ -7,14 +7,13 @@ use super::{
NodeInfo,
};
use crate::{
auth::backend::ComputeUserInfo,
compute,
console::messages::ColdStartInfo,
http,
metrics::{CacheOutcome, Metrics},
scram, Normalize,
auth::backend::ComputeUserInfo, compute, console::messages::ColdStartInfo, http, scram,
};
use crate::{
cache::Cached,
context::RequestMonitoring,
metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
};
use crate::{cache::Cached, context::RequestMonitoring};
use futures::TryFutureExt;
use std::sync::Arc;
use tokio::time::Instant;
@@ -24,7 +23,7 @@ use tracing::{error, info, info_span, warn, Instrument};
pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub locks: &'static ApiLocks,
locks: &'static ApiLocks,
jwt: String,
}
@@ -56,15 +55,6 @@ impl Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &user_info.endpoint.normalize())
.await
{
info!("endpoint is not valid, skipping the request");
return Ok(AuthInfo::default());
}
let request_id = ctx.session_id.to_string();
let application_name = ctx.console_application_name();
async {
@@ -91,9 +81,7 @@ impl Api {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
Err(e) => match e.http_status_code() {
Some(http::StatusCode::NOT_FOUND) => {
return Ok(AuthInfo::default());
}
Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
_otherwise => return Err(e.into()),
},
};
@@ -107,10 +95,7 @@ impl Api {
Some(secret)
};
let allowed_ips = body.allowed_ips.unwrap_or_default();
Metrics::get()
.proxy
.allowed_ips_number
.observe(allowed_ips.len() as f64);
ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
Ok(AuthInfo {
secret,
allowed_ips,
@@ -189,27 +174,23 @@ impl super::Api for Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
let ep = &user_info.endpoint;
let user = &user_info.user;
if let Some(role_secret) = self
.caches
.project_info
.get_role_secret(normalized_ep, user)
{
if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
return Ok(role_secret);
}
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
let ep_int = ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
ep_int,
Arc::new(auth_info.allowed_ips),
);
ctx.set_project_id(project_id);
@@ -223,34 +204,30 @@ impl super::Api for Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
Metrics::get()
.proxy
.allowed_ips_cache_misses
.inc(CacheOutcome::Hit);
let ep = &user_info.endpoint;
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["hit"])
.inc();
return Ok((allowed_ips, None));
}
Metrics::get()
.proxy
.allowed_ips_cache_misses
.inc(CacheOutcome::Miss);
ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["miss"])
.inc();
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let user = &user_info.user;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
let ep_int = ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
self.caches
.project_info
.insert_allowed_ips(project_id, ep_int, allowed_ips.clone());
ctx.set_project_id(project_id);
}
Ok((

View File

@@ -5,14 +5,14 @@ use once_cell::sync::OnceCell;
use smol_str::SmolStr;
use std::net::IpAddr;
use tokio::sync::mpsc;
use tracing::{field::display, info, info_span, Span};
use tracing::{field::display, info_span, Span};
use uuid::Uuid;
use crate::{
console::messages::{ColdStartInfo, MetricsAuxInfo},
error::ErrorKind,
intern::{BranchIdInt, ProjectIdInt},
metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol},
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
DbName, EndpointId, RoleName,
};
@@ -29,7 +29,7 @@ static LOG_CHAN: OnceCell<mpsc::WeakUnboundedSender<RequestData>> = OnceCell::ne
pub struct RequestMonitoring {
pub peer_addr: IpAddr,
pub session_id: Uuid,
pub protocol: Protocol,
pub protocol: &'static str,
first_packet: chrono::DateTime<Utc>,
region: &'static str,
pub span: Span,
@@ -50,8 +50,6 @@ pub struct RequestMonitoring {
// This sender is here to keep the request monitoring channel open while requests are taking place.
sender: Option<mpsc::UnboundedSender<RequestData>>,
pub latency_timer: LatencyTimer,
// Whether proxy decided that it's not a valid endpoint end rejected it before going to cplane.
rejected: bool,
}
#[derive(Clone, Debug)]
@@ -67,7 +65,7 @@ impl RequestMonitoring {
pub fn new(
session_id: Uuid,
peer_addr: IpAddr,
protocol: Protocol,
protocol: &'static str,
region: &'static str,
) -> Self {
let span = info_span!(
@@ -76,7 +74,6 @@ impl RequestMonitoring {
?session_id,
%peer_addr,
ep = tracing::field::Empty,
role = tracing::field::Empty,
);
Self {
@@ -96,7 +93,6 @@ impl RequestMonitoring {
error_kind: None,
auth_method: None,
success: false,
rejected: false,
cold_start_info: ColdStartInfo::Unknown,
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
@@ -106,7 +102,7 @@ impl RequestMonitoring {
#[cfg(test)]
pub fn test() -> Self {
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test")
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test")
}
pub fn console_application_name(&self) -> String {
@@ -117,10 +113,6 @@ impl RequestMonitoring {
)
}
pub fn set_rejected(&mut self, rejected: bool) {
self.rejected = rejected;
}
pub fn set_cold_start_info(&mut self, info: ColdStartInfo) {
self.cold_start_info = info;
self.latency_timer.cold_start_info(info);
@@ -142,9 +134,9 @@ impl RequestMonitoring {
pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
if self.endpoint_id.is_none() {
self.span.record("ep", display(&endpoint_id));
let metric = &Metrics::get().proxy.connecting_endpoints;
let label = metric.with_labels(self.protocol);
metric.get_metric(label).measure(&endpoint_id);
crate::metrics::CONNECTING_ENDPOINTS
.with_label_values(&[self.protocol])
.measure(&endpoint_id);
self.endpoint_id = Some(endpoint_id);
}
}
@@ -158,7 +150,6 @@ impl RequestMonitoring {
}
pub fn set_user(&mut self, user: RoleName) {
self.span.record("role", display(&user));
self.user = Some(user);
}
@@ -166,22 +157,14 @@ impl RequestMonitoring {
self.auth_method = Some(auth_method);
}
pub fn has_private_peer_addr(&self) -> bool {
match self.peer_addr {
IpAddr::V4(ip) => ip.is_private(),
_ => false,
}
}
pub fn set_error_kind(&mut self, kind: ErrorKind) {
// Do not record errors from the private address to metrics.
if !self.has_private_peer_addr() {
Metrics::get().proxy.errors_total.inc(kind);
}
ERROR_BY_KIND
.with_label_values(&[kind.to_metric_label()])
.inc();
if let Some(ep) = &self.endpoint_id {
let metric = &Metrics::get().proxy.endpoints_affected_by_errors;
let label = metric.with_labels(kind);
metric.get_metric(label).measure(ep);
ENDPOINT_ERRORS_BY_KIND
.with_label_values(&[kind.to_metric_label()])
.measure(ep);
}
self.error_kind = Some(kind);
}
@@ -195,32 +178,6 @@ impl RequestMonitoring {
impl Drop for RequestMonitoring {
fn drop(&mut self) {
let outcome = if self.success {
ConnectOutcome::Success
} else {
ConnectOutcome::Failed
};
let rejected = self.rejected;
let ep = self
.endpoint_id
.as_ref()
.map(|x| x.as_str())
.unwrap_or_default();
// This makes sense only if cache is disabled
info!(
?ep,
?outcome,
?rejected,
"check endpoint is valid with outcome"
);
Metrics::get()
.proxy
.invalid_endpoints_total
.inc(InvalidEndpointsGroup {
protocol: self.protocol,
rejected: rejected.into(),
outcome,
});
if let Some(tx) = self.sender.take() {
let _: Result<(), _> = tx.send(RequestData::from(&*self));
}

View File

@@ -111,7 +111,7 @@ impl From<&RequestMonitoring> for RequestData {
super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus",
super::AuthMethod::Cleartext => "cleartext",
}),
protocol: value.protocol.as_str(),
protocol: value.protocol,
region: value.region,
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
success: value.success,

View File

@@ -1,7 +1,5 @@
use std::{error::Error as StdError, fmt, io};
use measured::FixedCardinalityLabel;
/// Upcast (almost) any error into an opaque [`io::Error`].
pub fn io_error(e: impl Into<Box<dyn StdError + Send + Sync>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
@@ -31,29 +29,24 @@ pub trait UserFacingError: ReportableError {
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, FixedCardinalityLabel)]
#[label(singleton = "type")]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ErrorKind {
/// Wrong password, unknown endpoint, protocol violation, etc...
User,
/// Network error between user and proxy. Not necessarily user error
#[label(rename = "clientdisconnect")]
ClientDisconnect,
/// Proxy self-imposed user rate limits
#[label(rename = "ratelimit")]
RateLimit,
/// Proxy self-imposed service-wise rate limits
#[label(rename = "serviceratelimit")]
ServiceRateLimit,
/// internal errors
Service,
/// Error communicating with control plane
#[label(rename = "controlplane")]
ControlPlane,
/// Postgres error

View File

@@ -13,16 +13,13 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::time::Instant;
use tracing::trace;
use crate::{
metrics::{ConsoleRequest, Metrics},
url::ApiUrl,
};
use crate::{metrics::CONSOLE_REQUEST_LATENCY, rate_limiter, url::ApiUrl};
use reqwest_middleware::RequestBuilder;
/// This is the preferred way to create new http clients,
/// because it takes care of observability (OpenTelemetry).
/// We deliberately don't want to replace this with a public static.
pub fn new_client() -> ClientWithMiddleware {
pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> ClientWithMiddleware {
let client = reqwest::ClientBuilder::new()
.dns_resolver(Arc::new(GaiResolver::default()))
.connection_verbose(true)
@@ -31,6 +28,7 @@ pub fn new_client() -> ClientWithMiddleware {
reqwest_middleware::ClientBuilder::new(client)
.with(reqwest_tracing::TracingMiddleware::default())
.with(rate_limiter::Limiter::new(rate_limiter_config))
.build()
}
@@ -92,14 +90,13 @@ impl Endpoint {
/// Execute a [request](reqwest::Request).
pub async fn execute(&self, request: Request) -> Result<Response, Error> {
let _timer = Metrics::get()
.proxy
.console_request_latency
.start_timer(ConsoleRequest {
request: request.url().path(),
});
self.client.execute(request).await
let path = request.url().path().to_string();
let start = Instant::now();
let res = self.client.execute(request).await;
CONSOLE_REQUEST_LATENCY
.with_label_values(&[&path])
.observe(start.elapsed().as_secs_f64());
res
}
}

View File

@@ -1,49 +1,30 @@
use anyhow::{anyhow, bail};
use hyper::{header::CONTENT_TYPE, Body, Request, Response, StatusCode};
use measured::{text::BufferedTextEncoder, MetricGroup};
use metrics::NeonMetrics;
use std::{
convert::Infallible,
net::TcpListener,
sync::{Arc, Mutex},
};
use tracing::{info, info_span};
use hyper::{Body, Request, Response, StatusCode};
use std::{convert::Infallible, net::TcpListener};
use tracing::info;
use utils::http::{
endpoint::{self, request_span},
endpoint::{self, prometheus_metrics_handler, request_span},
error::ApiError,
json::json_response,
RouterBuilder, RouterService,
};
use crate::jemalloc;
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, "")
}
fn make_router(metrics: AppMetrics) -> RouterBuilder<hyper::Body, ApiError> {
let state = Arc::new(Mutex::new(PrometheusHandler {
encoder: BufferedTextEncoder::new(),
metrics,
}));
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
endpoint::make_router()
.get("/metrics", move |r| {
let state = state.clone();
request_span(r, move |b| prometheus_metrics_handler(b, state))
})
.get("/metrics", |r| request_span(r, prometheus_metrics_handler))
.get("/v1/status", status_handler)
}
pub async fn task_main(
http_listener: TcpListener,
metrics: AppMetrics,
) -> anyhow::Result<Infallible> {
pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<Infallible> {
scopeguard::defer! {
info!("http has shut down");
}
let service = || RouterService::new(make_router(metrics).build()?);
let service = || RouterService::new(make_router().build()?);
hyper::Server::from_tcp(http_listener)?
.serve(service().map_err(|e| anyhow!(e))?)
@@ -51,57 +32,3 @@ pub async fn task_main(
bail!("hyper server without shutdown handling cannot shutdown successfully");
}
struct PrometheusHandler {
encoder: BufferedTextEncoder,
metrics: AppMetrics,
}
#[derive(MetricGroup)]
pub struct AppMetrics {
#[metric(namespace = "jemalloc")]
pub jemalloc: Option<jemalloc::MetricRecorder>,
#[metric(flatten)]
pub neon_metrics: NeonMetrics,
#[metric(flatten)]
pub proxy: &'static crate::metrics::Metrics,
}
async fn prometheus_metrics_handler(
_req: Request<Body>,
state: Arc<Mutex<PrometheusHandler>>,
) -> Result<Response<Body>, ApiError> {
let started_at = std::time::Instant::now();
let span = info_span!("blocking");
let body = tokio::task::spawn_blocking(move || {
let _span = span.entered();
let mut state = state.lock().unwrap();
let PrometheusHandler { encoder, metrics } = &mut *state;
metrics
.collect_group_into(&mut *encoder)
.unwrap_or_else(|infallible| match infallible {});
let body = encoder.finish();
tracing::info!(
bytes = body.len(),
elapsed_ms = started_at.elapsed().as_millis(),
"responded /metrics"
);
body
})
.await
.unwrap();
let response = Response::builder()
.status(200)
.header(CONTENT_TYPE, "text/plain; version=0.0.4")
.body(Body::from(body))
.unwrap();
Ok(response)
}

View File

@@ -160,11 +160,6 @@ impl From<&EndpointId> for EndpointIdInt {
EndpointIdTag::get_interner().get_or_intern(value)
}
}
impl From<EndpointId> for EndpointIdInt {
fn from(value: EndpointId) -> Self {
EndpointIdTag::get_interner().get_or_intern(&value)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct BranchIdTag;
@@ -180,11 +175,6 @@ impl From<&BranchId> for BranchIdInt {
BranchIdTag::get_interner().get_or_intern(value)
}
}
impl From<BranchId> for BranchIdInt {
fn from(value: BranchId) -> Self {
BranchIdTag::get_interner().get_or_intern(&value)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ProjectIdTag;
@@ -200,11 +190,6 @@ impl From<&ProjectId> for ProjectIdInt {
ProjectIdTag::get_interner().get_or_intern(value)
}
}
impl From<ProjectId> for ProjectIdInt {
fn from(value: ProjectId) -> Self {
ProjectIdTag::get_interner().get_or_intern(&value)
}
}
#[cfg(test)]
mod tests {

View File

@@ -1,45 +1,27 @@
use std::marker::PhantomData;
use std::time::Duration;
use measured::{
label::NoLabels,
metric::{
gauge::GaugeState, group::Encoding, group::MetricValue, name::MetricNameEncoder,
MetricEncoding, MetricFamilyEncoding, MetricType,
},
text::TextEncoder,
LabelGroup, MetricGroup,
};
use metrics::IntGauge;
use prometheus::{register_int_gauge_with_registry, Registry};
use tikv_jemalloc_ctl::{config, epoch, epoch_mib, stats, version};
pub struct MetricRecorder {
epoch: epoch_mib,
inner: Metrics,
}
#[derive(MetricGroup)]
struct Metrics {
active_bytes: JemallocGaugeFamily<stats::active_mib>,
allocated_bytes: JemallocGaugeFamily<stats::allocated_mib>,
mapped_bytes: JemallocGaugeFamily<stats::mapped_mib>,
metadata_bytes: JemallocGaugeFamily<stats::metadata_mib>,
resident_bytes: JemallocGaugeFamily<stats::resident_mib>,
retained_bytes: JemallocGaugeFamily<stats::retained_mib>,
}
impl<Enc: Encoding> MetricGroup<Enc> for MetricRecorder
where
Metrics: MetricGroup<Enc>,
{
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
if self.epoch.advance().is_ok() {
self.inner.collect_group_into(enc)?;
}
Ok(())
}
active: stats::active_mib,
active_gauge: IntGauge,
allocated: stats::allocated_mib,
allocated_gauge: IntGauge,
mapped: stats::mapped_mib,
mapped_gauge: IntGauge,
metadata: stats::metadata_mib,
metadata_gauge: IntGauge,
resident: stats::resident_mib,
resident_gauge: IntGauge,
retained: stats::retained_mib,
retained_gauge: IntGauge,
}
impl MetricRecorder {
pub fn new() -> Result<Self, anyhow::Error> {
pub fn new(registry: &Registry) -> Result<Self, anyhow::Error> {
tracing::info!(
config = config::malloc_conf::read()?,
version = version::read()?,
@@ -48,69 +30,71 @@ impl MetricRecorder {
Ok(Self {
epoch: epoch::mib()?,
inner: Metrics {
active_bytes: JemallocGaugeFamily(stats::active::mib()?),
allocated_bytes: JemallocGaugeFamily(stats::allocated::mib()?),
mapped_bytes: JemallocGaugeFamily(stats::mapped::mib()?),
metadata_bytes: JemallocGaugeFamily(stats::metadata::mib()?),
resident_bytes: JemallocGaugeFamily(stats::resident::mib()?),
retained_bytes: JemallocGaugeFamily(stats::retained::mib()?),
},
active: stats::active::mib()?,
active_gauge: register_int_gauge_with_registry!(
"jemalloc_active_bytes",
"Total number of bytes in active pages allocated by the process",
registry
)?,
allocated: stats::allocated::mib()?,
allocated_gauge: register_int_gauge_with_registry!(
"jemalloc_allocated_bytes",
"Total number of bytes allocated by the process",
registry
)?,
mapped: stats::mapped::mib()?,
mapped_gauge: register_int_gauge_with_registry!(
"jemalloc_mapped_bytes",
"Total number of bytes in active extents mapped by the allocator",
registry
)?,
metadata: stats::metadata::mib()?,
metadata_gauge: register_int_gauge_with_registry!(
"jemalloc_metadata_bytes",
"Total number of bytes dedicated to jemalloc metadata",
registry
)?,
resident: stats::resident::mib()?,
resident_gauge: register_int_gauge_with_registry!(
"jemalloc_resident_bytes",
"Total number of bytes in physically resident data pages mapped by the allocator",
registry
)?,
retained: stats::retained::mib()?,
retained_gauge: register_int_gauge_with_registry!(
"jemalloc_retained_bytes",
"Total number of bytes in virtual memory mappings that were retained rather than being returned to the operating system",
registry
)?,
})
}
fn _poll(&self) -> Result<(), anyhow::Error> {
self.epoch.advance()?;
self.active_gauge.set(self.active.read()? as i64);
self.allocated_gauge.set(self.allocated.read()? as i64);
self.mapped_gauge.set(self.mapped.read()? as i64);
self.metadata_gauge.set(self.metadata.read()? as i64);
self.resident_gauge.set(self.resident.read()? as i64);
self.retained_gauge.set(self.retained.read()? as i64);
Ok(())
}
#[inline]
pub fn poll(&self) {
if let Err(error) = self._poll() {
tracing::warn!(%error, "Failed to poll jemalloc stats");
}
}
pub fn start(self) -> tokio::task::JoinHandle<()> {
tokio::task::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(15));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
self.poll();
interval.tick().await;
}
})
}
}
struct JemallocGauge<T>(PhantomData<T>);
impl<T> Default for JemallocGauge<T> {
fn default() -> Self {
JemallocGauge(PhantomData)
}
}
impl<T> MetricType for JemallocGauge<T> {
type Metadata = T;
}
struct JemallocGaugeFamily<T>(T);
impl<M, T: Encoding> MetricFamilyEncoding<T> for JemallocGaugeFamily<M>
where
JemallocGauge<M>: MetricEncoding<T, Metadata = M>,
{
fn collect_family_into(&self, name: impl MetricNameEncoder, enc: &mut T) -> Result<(), T::Err> {
JemallocGauge::write_type(&name, enc)?;
JemallocGauge(PhantomData).collect_into(&self.0, NoLabels, name, enc)
}
}
macro_rules! jemalloc_gauge {
($stat:ident, $mib:ident) => {
impl<W: std::io::Write> MetricEncoding<TextEncoder<W>> for JemallocGauge<stats::$mib> {
fn write_type(
name: impl MetricNameEncoder,
enc: &mut TextEncoder<W>,
) -> Result<(), std::io::Error> {
GaugeState::write_type(name, enc)
}
fn collect_into(
&self,
mib: &stats::$mib,
labels: impl LabelGroup,
name: impl MetricNameEncoder,
enc: &mut TextEncoder<W>,
) -> Result<(), std::io::Error> {
if let Ok(v) = mib.read() {
enc.write_metric_value(name, labels, MetricValue::Int(v as i64))?;
}
Ok(())
}
}
};
}
jemalloc_gauge!(active, active_mib);
jemalloc_gauge!(allocated, allocated_mib);
jemalloc_gauge!(mapped, mapped_mib);
jemalloc_gauge!(metadata, metadata_mib);
jemalloc_gauge!(resident, resident_mib);
jemalloc_gauge!(retained, retained_mib);

View File

@@ -127,24 +127,6 @@ macro_rules! smol_str_wrapper {
};
}
const POOLER_SUFFIX: &str = "-pooler";
pub trait Normalize {
fn normalize(&self) -> Self;
}
impl<S: Clone + AsRef<str> + From<String>> Normalize for S {
fn normalize(&self) -> Self {
if self.as_ref().ends_with(POOLER_SUFFIX) {
let mut s = self.as_ref().to_string();
s.truncate(s.len() - POOLER_SUFFIX.len());
s.into()
} else {
self.clone()
}
}
}
// 90% of role name strings are 20 characters or less.
smol_str_wrapper!(RoleName);
// 50% of endpoint strings are 23 characters or less.
@@ -158,22 +140,3 @@ smol_str_wrapper!(ProjectId);
smol_str_wrapper!(EndpointCacheKey);
smol_str_wrapper!(DbName);
// Endpoints are a bit tricky. Rare they might be branches or projects.
impl EndpointId {
pub fn is_endpoint(&self) -> bool {
self.0.starts_with("ep-")
}
pub fn is_branch(&self) -> bool {
self.0.starts_with("br-")
}
pub fn is_project(&self) -> bool {
!self.is_endpoint() && !self.is_branch()
}
pub fn as_branch(&self) -> BranchId {
BranchId(self.0.clone())
}
pub fn as_project(&self) -> ProjectId {
ProjectId(self.0.clone())
}
}

View File

@@ -1,348 +1,176 @@
use std::sync::OnceLock;
use lasso::ThreadedRodeo;
use measured::{
label::StaticLabelSet,
metric::{histogram::Thresholds, name::MetricName},
Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup,
MetricGroup,
use ::metrics::{
exponential_buckets, register_histogram, register_histogram_vec, register_hll_vec,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge,
register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec,
IntCounterVec, IntGauge, IntGaugeVec,
};
use metrics::{
register_hll, register_int_counter, register_int_counter_pair, HyperLogLog, IntCounter,
IntCounterPair,
};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec};
use once_cell::sync::Lazy;
use tokio::time::{self, Instant};
use crate::console::messages::ColdStartInfo;
#[derive(MetricGroup)]
pub struct Metrics {
#[metric(namespace = "proxy")]
pub proxy: ProxyMetrics,
pub static NUM_DB_CONNECTIONS_GAUGE: Lazy<IntCounterPairVec> = Lazy::new(|| {
register_int_counter_pair_vec!(
"proxy_opened_db_connections_total",
"Number of opened connections to a database.",
"proxy_closed_db_connections_total",
"Number of closed connections to a database.",
&["protocol"],
)
.unwrap()
});
#[metric(namespace = "wake_compute_lock")]
pub wake_compute_lock: ApiLockMetrics,
}
pub static NUM_CLIENT_CONNECTION_GAUGE: Lazy<IntCounterPairVec> = Lazy::new(|| {
register_int_counter_pair_vec!(
"proxy_opened_client_connections_total",
"Number of opened connections from a client.",
"proxy_closed_client_connections_total",
"Number of closed connections from a client.",
&["protocol"],
)
.unwrap()
});
impl Metrics {
pub fn get() -> &'static Self {
static SELF: OnceLock<Metrics> = OnceLock::new();
SELF.get_or_init(|| Metrics {
proxy: ProxyMetrics::default(),
wake_compute_lock: ApiLockMetrics::new(),
})
}
}
pub static NUM_CONNECTION_REQUESTS_GAUGE: Lazy<IntCounterPairVec> = Lazy::new(|| {
register_int_counter_pair_vec!(
"proxy_accepted_connections_total",
"Number of client connections accepted.",
"proxy_closed_connections_total",
"Number of client connections closed.",
&["protocol"],
)
.unwrap()
});
#[derive(MetricGroup)]
#[metric(new())]
pub struct ProxyMetrics {
#[metric(flatten)]
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
#[metric(flatten)]
pub client_connections: CounterPairVec<NumClientConnectionsGauge>,
#[metric(flatten)]
pub connection_requests: CounterPairVec<NumConnectionRequestsGauge>,
#[metric(flatten)]
pub http_endpoint_pools: HttpEndpointPools,
pub static COMPUTE_CONNECTION_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"proxy_compute_connection_latency_seconds",
"Time it took for proxy to establish a connection to the compute endpoint",
// http/ws/tcp, true/false, true/false, success/failure, client/client_and_cplane
// 3 * 6 * 2 * 2 = 72 counters
&["protocol", "cold_start_info", "outcome", "excluded"],
// largest bucket = 2^16 * 0.5ms = 32s
exponential_buckets(0.0005, 2.0, 16).unwrap(),
)
.unwrap()
});
/// Time it took for proxy to establish a connection to the compute endpoint.
// largest bucket = 2^16 * 0.5ms = 32s
#[metric(metadata = Thresholds::exponential_buckets(0.0005, 2.0))]
pub compute_connection_latency_seconds: HistogramVec<ComputeConnectionLatencySet, 16>,
/// Time it took for proxy to receive a response from control plane.
#[metric(
pub static CONSOLE_REQUEST_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"proxy_console_request_latency",
"Time it took for proxy to establish a connection to the compute endpoint",
// proxy_wake_compute/proxy_get_role_info
&["request"],
// largest bucket = 2^16 * 0.2ms = 13s
metadata = Thresholds::exponential_buckets(0.0002, 2.0),
)]
pub console_request_latency: HistogramVec<ConsoleRequestSet, 16>,
exponential_buckets(0.0002, 2.0, 16).unwrap(),
)
.unwrap()
});
/// Time it takes to acquire a token to call console plane.
// largest bucket = 3^16 * 0.05ms = 2.15s
#[metric(metadata = Thresholds::exponential_buckets(0.00005, 3.0))]
pub control_plane_token_acquire_seconds: Histogram<16>,
pub static ALLOWED_IPS_BY_CACHE_OUTCOME: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_allowed_ips_cache_misses",
"Number of cache hits/misses for allowed ips",
// hit/miss
&["outcome"],
)
.unwrap()
});
/// Size of the HTTP request body lengths.
// smallest bucket = 16 bytes
// largest bucket = 4^12 * 16 bytes = 256MB
#[metric(metadata = Thresholds::exponential_buckets(16.0, 4.0))]
pub http_conn_content_length_bytes: HistogramVec<StaticLabelSet<HttpDirection>, 12>,
pub static RATE_LIMITER_ACQUIRE_LATENCY: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_control_plane_token_acquire_seconds",
"Time it took for proxy to establish a connection to the compute endpoint",
// largest bucket = 3^16 * 0.05ms = 2.15s
exponential_buckets(0.00005, 3.0, 16).unwrap(),
)
.unwrap()
});
/// Time it takes to reclaim unused connection pools.
#[metric(metadata = Thresholds::exponential_buckets(1e-6, 2.0))]
pub http_pool_reclaimation_lag_seconds: Histogram<16>,
pub static RATE_LIMITER_LIMIT: Lazy<IntGaugeVec> = Lazy::new(|| {
register_int_gauge_vec!(
"semaphore_control_plane_limit",
"Current limit of the semaphore control plane",
&["limit"], // 2 counters
)
.unwrap()
});
/// Number of opened connections to a database.
pub http_pool_opened_connections: Gauge,
pub static NUM_CONNECTION_ACCEPTED_BY_SNI: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_accepted_connections_by_sni",
"Number of connections (per sni).",
&["kind"],
)
.unwrap()
});
/// Number of cache hits/misses for allowed ips.
pub allowed_ips_cache_misses: CounterVec<StaticLabelSet<CacheOutcome>>,
pub static ALLOWED_IPS_NUMBER: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_allowed_ips_number",
"Number of allowed ips",
vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0],
)
.unwrap()
});
/// Number of allowed ips
#[metric(metadata = Thresholds::with_buckets([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0]))]
pub allowed_ips_number: Histogram<10>,
pub static HTTP_CONTENT_LENGTH: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"proxy_http_conn_content_length_bytes",
"Number of bytes the HTTP response content consumes",
// request/response
&["direction"],
// smallest bucket = 16 bytes
// largest bucket = 4^12 * 16 bytes = 256MB
exponential_buckets(16.0, 4.0, 12).unwrap()
)
.unwrap()
});
/// Number of connections (per sni).
pub accepted_connections_by_sni: CounterVec<StaticLabelSet<SniKind>>,
pub static GC_LATENCY: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_http_pool_reclaimation_lag_seconds",
"Time it takes to reclaim unused connection pools",
// 1us -> 65ms
exponential_buckets(1e-6, 2.0, 16).unwrap(),
)
.unwrap()
});
/// Number of connection failures (per kind).
pub connection_failures_total: CounterVec<StaticLabelSet<ConnectionFailureKind>>,
pub static ENDPOINT_POOLS: Lazy<IntCounterPair> = Lazy::new(|| {
register_int_counter_pair!(
"proxy_http_pool_endpoints_registered_total",
"Number of endpoints we have registered pools for",
"proxy_http_pool_endpoints_unregistered_total",
"Number of endpoints we have unregistered pools for",
)
.unwrap()
});
/// Number of wake-up failures (per kind).
pub connection_failures_breakdown: CounterVec<ConnectionFailuresBreakdownSet>,
pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy<IntGauge> = Lazy::new(|| {
register_int_gauge!(
"proxy_http_pool_opened_connections",
"Number of opened connections to a database.",
)
.unwrap()
});
/// Number of bytes sent/received between all clients and backends.
pub io_bytes: CounterVec<StaticLabelSet<Direction>>,
pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_cancellation_requests_total",
"Number of cancellation requests (per found/not_found).",
&["source", "kind"],
)
.unwrap()
});
/// Number of errors by a given classification.
pub errors_total: CounterVec<StaticLabelSet<crate::error::ErrorKind>>,
/// Number of cancellation requests (per found/not_found).
pub cancellation_requests_total: CounterVec<CancellationRequestSet>,
/// Number of errors by a given classification
pub redis_errors_total: CounterVec<RedisErrorsSet>,
/// Number of TLS handshake failures
pub tls_handshake_failures: Counter,
/// Number of connection requests affected by authentication rate limits
pub requests_auth_rate_limits_total: Counter,
/// HLL approximate cardinality of endpoints that are connecting
pub connecting_endpoints: HyperLogLogVec<StaticLabelSet<Protocol>, 32>,
/// Number of endpoints affected by errors of a given classification
pub endpoints_affected_by_errors: HyperLogLogVec<StaticLabelSet<crate::error::ErrorKind>, 32>,
/// Number of endpoints affected by authentication rate limits
pub endpoints_auth_rate_limits: HyperLogLog<32>,
/// Number of invalid endpoints (per protocol, per rejected).
pub invalid_endpoints_total: CounterVec<InvalidEndpointsSet>,
}
#[derive(MetricGroup)]
#[metric(new())]
pub struct ApiLockMetrics {
/// Number of semaphores registered in this api lock
pub semaphores_registered: Counter,
/// Number of semaphores unregistered in this api lock
pub semaphores_unregistered: Counter,
/// Time it takes to reclaim unused semaphores in the api lock
#[metric(metadata = Thresholds::exponential_buckets(1e-6, 2.0))]
pub reclamation_lag_seconds: Histogram<16>,
/// Time it takes to acquire a semaphore lock
#[metric(metadata = Thresholds::exponential_buckets(1e-4, 2.0))]
pub semaphore_acquire_seconds: Histogram<16>,
}
impl Default for ProxyMetrics {
fn default() -> Self {
Self::new()
}
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "direction")]
pub enum HttpDirection {
Request,
Response,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "direction")]
pub enum Direction {
Tx,
Rx,
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
#[label(singleton = "protocol")]
pub enum Protocol {
Http,
Ws,
Tcp,
SniRouter,
}
impl Protocol {
pub fn as_str(&self) -> &'static str {
match self {
Protocol::Http => "http",
Protocol::Ws => "ws",
Protocol::Tcp => "tcp",
Protocol::SniRouter => "sni_router",
}
}
}
impl std::fmt::Display for Protocol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
pub enum Bool {
True,
False,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum Outcome {
Success,
Failed,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum CacheOutcome {
Hit,
Miss,
}
#[derive(LabelGroup)]
#[label(set = ConsoleRequestSet)]
pub struct ConsoleRequest<'a> {
#[label(dynamic_with = ThreadedRodeo, default)]
pub request: &'a str,
}
#[derive(MetricGroup, Default)]
pub struct HttpEndpointPools {
/// Number of endpoints we have registered pools for
pub http_pool_endpoints_registered_total: Counter,
/// Number of endpoints we have unregistered pools for
pub http_pool_endpoints_unregistered_total: Counter,
}
pub struct HttpEndpointPoolsGuard<'a> {
dec: &'a Counter,
}
impl Drop for HttpEndpointPoolsGuard<'_> {
fn drop(&mut self) {
self.dec.inc();
}
}
impl HttpEndpointPools {
pub fn guard(&self) -> HttpEndpointPoolsGuard {
self.http_pool_endpoints_registered_total.inc();
HttpEndpointPoolsGuard {
dec: &self.http_pool_endpoints_unregistered_total,
}
}
}
pub struct NumDbConnectionsGauge;
impl CounterPairAssoc for NumDbConnectionsGauge {
const INC_NAME: &'static MetricName = MetricName::from_str("opened_db_connections_total");
const DEC_NAME: &'static MetricName = MetricName::from_str("closed_db_connections_total");
const INC_HELP: &'static str = "Number of opened connections to a database.";
const DEC_HELP: &'static str = "Number of closed connections to a database.";
type LabelGroupSet = StaticLabelSet<Protocol>;
}
pub type NumDbConnectionsGuard<'a> = metrics::MeasuredCounterPairGuard<'a, NumDbConnectionsGauge>;
pub struct NumClientConnectionsGauge;
impl CounterPairAssoc for NumClientConnectionsGauge {
const INC_NAME: &'static MetricName = MetricName::from_str("opened_client_connections_total");
const DEC_NAME: &'static MetricName = MetricName::from_str("closed_client_connections_total");
const INC_HELP: &'static str = "Number of opened connections from a client.";
const DEC_HELP: &'static str = "Number of closed connections from a client.";
type LabelGroupSet = StaticLabelSet<Protocol>;
}
pub type NumClientConnectionsGuard<'a> =
metrics::MeasuredCounterPairGuard<'a, NumClientConnectionsGauge>;
pub struct NumConnectionRequestsGauge;
impl CounterPairAssoc for NumConnectionRequestsGauge {
const INC_NAME: &'static MetricName = MetricName::from_str("accepted_connections_total");
const DEC_NAME: &'static MetricName = MetricName::from_str("closed_connections_total");
const INC_HELP: &'static str = "Number of client connections accepted.";
const DEC_HELP: &'static str = "Number of client connections closed.";
type LabelGroupSet = StaticLabelSet<Protocol>;
}
pub type NumConnectionRequestsGuard<'a> =
metrics::MeasuredCounterPairGuard<'a, NumConnectionRequestsGauge>;
#[derive(LabelGroup)]
#[label(set = ComputeConnectionLatencySet)]
pub struct ComputeConnectionLatencyGroup {
protocol: Protocol,
cold_start_info: ColdStartInfo,
outcome: ConnectOutcome,
excluded: LatencyExclusions,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
pub enum LatencyExclusions {
Client,
ClientAndCplane,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "kind")]
pub enum SniKind {
Sni,
NoSni,
PasswordHack,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "kind")]
pub enum ConnectionFailureKind {
ComputeCached,
ComputeUncached,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "kind")]
pub enum WakeupFailureKind {
BadComputeAddress,
ApiTransportError,
QuotaExceeded,
ApiConsoleLocked,
ApiConsoleBadRequest,
ApiConsoleOtherServerError,
ApiConsoleOtherError,
TimeoutError,
}
#[derive(LabelGroup)]
#[label(set = ConnectionFailuresBreakdownSet)]
pub struct ConnectionFailuresBreakdownGroup {
pub kind: WakeupFailureKind,
pub retry: Bool,
}
#[derive(LabelGroup, Copy, Clone)]
#[label(set = RedisErrorsSet)]
pub struct RedisErrors<'a> {
#[label(dynamic_with = ThreadedRodeo, default)]
pub channel: &'a str,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
pub enum CancellationSource {
FromClient,
FromRedis,
Local,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
pub enum CancellationOutcome {
NotFound,
Found,
}
#[derive(LabelGroup)]
#[label(set = CancellationRequestSet)]
pub struct CancellationRequest {
pub source: CancellationSource,
pub kind: CancellationOutcome,
}
pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client";
pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis";
pub enum Waiting {
Cplane,
@@ -357,6 +185,20 @@ struct Accumulated {
compute: time::Duration,
}
enum Outcome {
Success,
Failed,
}
impl Outcome {
fn as_str(&self) -> &'static str {
match self {
Outcome::Success => "success",
Outcome::Failed => "failed",
}
}
}
pub struct LatencyTimer {
// time since the stopwatch was started
start: time::Instant,
@@ -365,9 +207,9 @@ pub struct LatencyTimer {
// accumulated time on the stopwatch
accumulated: Accumulated,
// label data
protocol: Protocol,
protocol: &'static str,
cold_start_info: ColdStartInfo,
outcome: ConnectOutcome,
outcome: Outcome,
}
pub struct LatencyTimerPause<'a> {
@@ -377,7 +219,7 @@ pub struct LatencyTimerPause<'a> {
}
impl LatencyTimer {
pub fn new(protocol: Protocol) -> Self {
pub fn new(protocol: &'static str) -> Self {
Self {
start: time::Instant::now(),
stop: None,
@@ -385,7 +227,7 @@ impl LatencyTimer {
protocol,
cold_start_info: ColdStartInfo::Unknown,
// assume failed unless otherwise specified
outcome: ConnectOutcome::Failed,
outcome: Outcome::Failed,
}
}
@@ -406,7 +248,7 @@ impl LatencyTimer {
self.stop = Some(time::Instant::now());
// success
self.outcome = ConnectOutcome::Success;
self.outcome = Outcome::Success;
}
}
@@ -421,62 +263,128 @@ impl Drop for LatencyTimerPause<'_> {
}
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
pub enum ConnectOutcome {
Success,
Failed,
}
impl Drop for LatencyTimer {
fn drop(&mut self) {
let duration = self
.stop
.unwrap_or_else(time::Instant::now)
.duration_since(self.start);
let metric = &Metrics::get().proxy.compute_connection_latency_seconds;
// Excluding client communication from the accumulated time.
metric.observe(
ComputeConnectionLatencyGroup {
protocol: self.protocol,
cold_start_info: self.cold_start_info,
outcome: self.outcome,
excluded: LatencyExclusions::Client,
},
duration
.saturating_sub(self.accumulated.client)
.as_secs_f64(),
);
// Excluding cplane communication from the accumulated time.
COMPUTE_CONNECTION_LATENCY
.with_label_values(&[
self.protocol,
self.cold_start_info.as_str(),
self.outcome.as_str(),
"client",
])
.observe((duration.saturating_sub(self.accumulated.client)).as_secs_f64());
// Exclude client and cplane communication from the accumulated time.
let accumulated_total = self.accumulated.client + self.accumulated.cplane;
metric.observe(
ComputeConnectionLatencyGroup {
protocol: self.protocol,
cold_start_info: self.cold_start_info,
outcome: self.outcome,
excluded: LatencyExclusions::ClientAndCplane,
},
duration.saturating_sub(accumulated_total).as_secs_f64(),
);
COMPUTE_CONNECTION_LATENCY
.with_label_values(&[
self.protocol,
self.cold_start_info.as_str(),
self.outcome.as_str(),
"client_and_cplane",
])
.observe((duration.saturating_sub(accumulated_total)).as_secs_f64());
}
}
impl From<bool> for Bool {
fn from(value: bool) -> Self {
if value {
Bool::True
} else {
Bool::False
}
pub static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_connection_failures_total",
"Number of connection failures (per kind).",
&["kind"],
)
.unwrap()
});
pub static NUM_WAKEUP_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_connection_failures_breakdown",
"Number of wake-up failures (per kind).",
&["retry", "kind"],
)
.unwrap()
});
pub static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_io_bytes",
"Number of bytes sent/received between all clients and backends.",
&["direction"],
)
.unwrap()
});
pub const fn bool_to_str(x: bool) -> &'static str {
if x {
"true"
} else {
"false"
}
}
#[derive(LabelGroup)]
#[label(set = InvalidEndpointsSet)]
pub struct InvalidEndpointsGroup {
pub protocol: Protocol,
pub rejected: Bool,
pub outcome: ConnectOutcome,
}
pub static CONNECTING_ENDPOINTS: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
register_hll_vec!(
32,
"proxy_connecting_endpoints",
"HLL approximate cardinality of endpoints that are connecting",
&["protocol"],
)
.unwrap()
});
pub static ERROR_BY_KIND: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_errors_total",
"Number of errors by a given classification",
&["type"],
)
.unwrap()
});
pub static ENDPOINT_ERRORS_BY_KIND: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
register_hll_vec!(
32,
"proxy_endpoints_affected_by_errors",
"Number of endpoints affected by errors of a given classification",
&["type"],
)
.unwrap()
});
pub static REDIS_BROKEN_MESSAGES: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_redis_errors_total",
"Number of errors by a given classification",
&["channel"],
)
.unwrap()
});
pub static TLS_HANDSHAKE_FAILURES: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"proxy_tls_handshake_failures",
"Number of TLS handshake failures",
)
.unwrap()
});
pub static ENDPOINTS_AUTH_RATE_LIMITED: Lazy<HyperLogLog<32>> = Lazy::new(|| {
register_hll!(
32,
"proxy_endpoints_auth_rate_limits",
"Number of endpoints affected by authentication rate limits",
)
.unwrap()
});
pub static AUTH_RATE_LIMIT_HITS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"proxy_requests_auth_rate_limits_total",
"Number of connection requests affected by authentication rate limits",
)
.unwrap()
});

View File

@@ -7,7 +7,6 @@ pub mod handshake;
pub mod passthrough;
pub mod retry;
pub mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
use crate::{
auth,
@@ -16,15 +15,16 @@ use crate::{
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
error::ReportableError,
metrics::{Metrics, NumClientConnectionsGuard},
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
protocol2::WithClientIp,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey, Normalize,
EndpointCacheKey,
};
use futures::TryFutureExt;
use itertools::Itertools;
use metrics::IntCounterPairGuard;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
@@ -79,10 +79,9 @@ pub async fn task_main(
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let conn_gauge = NUM_CLIENT_CONNECTION_GAUGE
.with_label_values(&["tcp"])
.guard();
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
@@ -114,12 +113,7 @@ pub async fn task_main(
},
};
let mut ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Tcp,
&config.region,
);
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
let span = ctx.span.clone();
let res = handle_client(
@@ -243,23 +237,19 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
conn_gauge: IntCounterPairGuard,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
info!(
protocol = %ctx.protocol,
"handling interactive connection from client"
);
info!("handling interactive connection from client");
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol;
// let _client_gauge = metrics.client_connections.guard(proto);
let _request_gauge = metrics.connection_requests.guard(proto);
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
.with_label_values(&[proto])
.guard();
let tls = config.tls_config.as_ref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(stream, mode.handshake_tls(tls), record_handshake_error);
let do_handshake = handshake(stream, mode.handshake_tls(tls));
let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
@@ -290,7 +280,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep.normalize(), 1) {
if !endpoint_rate_limiter.check(ep, 1) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await?;

View File

@@ -4,7 +4,7 @@ use crate::{
console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo},
context::RequestMonitoring,
error::ReportableError,
metrics::{ConnectionFailureKind, Metrics},
metrics::NUM_CONNECTION_FAILURES,
proxy::{
retry::{retry_after, ShouldRetry},
wake_compute::wake_compute,
@@ -27,10 +27,10 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
warn!("invalidating stalled compute node info cache entry");
}
let label = match is_cached {
true => ConnectionFailureKind::ComputeCached,
false => ConnectionFailureKind::ComputeUncached,
true => "compute_cached",
false => "compute_uncached",
};
Metrics::get().proxy.connection_failures_total.inc(label);
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
node_info.invalidate()
}

View File

@@ -41,7 +41,7 @@ where
}
#[tracing::instrument(skip_all)]
pub async fn copy_bidirectional_client_compute<Client, Compute>(
pub(super) async fn copy_bidirectional_client_compute<Client, Compute>(
client: &mut Client,
compute: &mut Compute,
) -> Result<(u64, u64), std::io::Error>

View File

@@ -63,7 +63,6 @@ pub enum HandshakeData<S> {
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
record_handshake_error: bool,
) -> Result<HandshakeData<S>, HandshakeError> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
@@ -96,9 +95,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
if !read_buf.is_empty() {
return Err(HandshakeError::EarlyData);
}
let tls_stream = raw
.upgrade(tls.to_server_config(), record_handshake_error)
.await?;
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver

View File

@@ -2,10 +2,11 @@ use crate::{
cancellation,
compute::PostgresConnection,
console::messages::MetricsAuxInfo,
metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard},
metrics::NUM_BYTES_PROXIED_COUNTER,
stream::Stream,
usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS},
};
use metrics::IntCounterPairGuard;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;
@@ -22,25 +23,24 @@ pub async fn proxy_pass(
branch_id: aux.branch_id,
});
let metrics = &Metrics::get().proxy.io_bytes;
let m_sent = metrics.with_labels(Direction::Tx);
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
metrics.get_metric(m_sent).inc_by(cnt as u64);
m_sent.inc_by(cnt as u64);
usage.record_egress(cnt as u64);
},
);
let m_recv = metrics.with_labels(Direction::Rx);
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
metrics.get_metric(m_recv).inc_by(cnt as u64);
m_recv.inc_by(cnt as u64);
},
);
@@ -60,8 +60,8 @@ pub struct ProxyPassthrough<P, S> {
pub compute: PostgresConnection,
pub aux: MetricsAuxInfo,
pub req: NumConnectionRequestsGuard<'static>,
pub conn: NumClientConnectionsGuard<'static>,
pub req: IntCounterPairGuard,
pub conn: IntCounterPairGuard,
pub cancel: cancellation::Session<P>,
}

View File

@@ -175,7 +175,7 @@ async fn dummy_proxy(
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let client = WithClientIp::new(client);
let mut stream = match handshake(client, tls.as_ref(), false).await? {
let mut stream = match handshake(client, tls.as_ref()).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};

View File

@@ -34,10 +34,7 @@ async fn proxy_mitm(
tokio::spawn(async move {
// begin handshake with end_server
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
let (end_client, startup) = match handshake(client1, Some(&server_config1), false)
.await
.unwrap()
{
let (end_client, startup) = match handshake(client1, Some(&server_config1)).await.unwrap() {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
};

View File

@@ -1,6 +1,6 @@
use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo};
use crate::context::RequestMonitoring;
use crate::metrics::{ConnectionFailuresBreakdownGroup, Metrics, WakeupFailureKind};
use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES};
use crate::proxy::retry::retry_after;
use hyper::StatusCode;
use std::ops::ControlFlow;
@@ -57,46 +57,39 @@ pub fn handle_try_wake(
fn report_error(e: &WakeComputeError, retry: bool) {
use crate::console::errors::ApiError;
let retry = bool_to_str(retry);
let kind = match e {
WakeComputeError::BadComputeAddress(_) => WakeupFailureKind::BadComputeAddress,
WakeComputeError::ApiError(ApiError::Transport(_)) => WakeupFailureKind::ApiTransportError,
WakeComputeError::BadComputeAddress(_) => "bad_compute_address",
WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error",
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::LOCKED,
ref text,
}) if text.contains("written data quota exceeded")
|| text.contains("the limit for current plan reached") =>
{
WakeupFailureKind::QuotaExceeded
"quota_exceeded"
}
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::UNPROCESSABLE_ENTITY,
ref text,
}) if text.contains("compute time quota of non-primary branches is exceeded") => {
WakeupFailureKind::QuotaExceeded
"quota_exceeded"
}
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::LOCKED,
..
}) => WakeupFailureKind::ApiConsoleLocked,
}) => "api_console_locked",
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::BAD_REQUEST,
..
}) => WakeupFailureKind::ApiConsoleBadRequest,
}) => "api_console_bad_request",
WakeComputeError::ApiError(ApiError::Console { status, .. })
if status.is_server_error() =>
{
WakeupFailureKind::ApiConsoleOtherServerError
"api_console_other_server_error"
}
WakeComputeError::ApiError(ApiError::Console { .. }) => {
WakeupFailureKind::ApiConsoleOtherError
}
WakeComputeError::TimeoutError => WakeupFailureKind::TimeoutError,
WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error",
WakeComputeError::TimeoutError => "timeout_error",
};
Metrics::get()
.proxy
.connection_failures_breakdown
.inc(ConnectionFailuresBreakdownGroup {
kind,
retry: retry.into(),
});
NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc();
}

View File

@@ -1,2 +1,7 @@
mod aimd;
mod limit_algorithm;
mod limiter;
pub use limiter::{BucketRateLimiter, EndpointRateLimiter, GlobalRateLimiter, RateBucketInfo};
pub use aimd::Aimd;
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
pub use limiter::Limiter;
pub use limiter::{AuthRateLimiter, EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};

View File

@@ -0,0 +1,166 @@
use std::usize;
use async_trait::async_trait;
use super::limit_algorithm::{AimdConfig, LimitAlgorithm, Sample};
use super::limiter::Outcome;
/// Loss-based congestion avoidance.
///
/// Additive-increase, multiplicative decrease.
///
/// Adds available currency when:
/// 1. no load-based errors are observed, and
/// 2. the utilisation of the current limit is high.
///
/// Reduces available concurrency by a factor when load-based errors are detected.
pub struct Aimd {
min_limit: usize,
max_limit: usize,
decrease_factor: f32,
increase_by: usize,
min_utilisation_threshold: f32,
}
impl Aimd {
pub fn new(config: AimdConfig) -> Self {
Self {
min_limit: config.aimd_min_limit,
max_limit: config.aimd_max_limit,
decrease_factor: config.aimd_decrease_factor,
increase_by: config.aimd_increase_by,
min_utilisation_threshold: config.aimd_min_utilisation_threshold,
}
}
}
#[async_trait]
impl LimitAlgorithm for Aimd {
async fn update(&mut self, old_limit: usize, sample: Sample) -> usize {
use Outcome::*;
match sample.outcome {
Success => {
let utilisation = sample.in_flight as f32 / old_limit as f32;
if utilisation > self.min_utilisation_threshold {
let limit = old_limit + self.increase_by;
limit.clamp(self.min_limit, self.max_limit)
} else {
old_limit
}
}
Overload => {
let limit = old_limit as f32 * self.decrease_factor;
// Floor instead of round, so the limit reduces even with small numbers.
// E.g. round(2 * 0.9) = 2, but floor(2 * 0.9) = 1
let limit = limit.floor() as usize;
limit.clamp(self.min_limit, self.max_limit)
}
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use tokio::sync::Notify;
use super::*;
use crate::rate_limiter::{Limiter, RateLimiterConfig};
#[tokio::test]
async fn should_decrease_limit_on_overload() {
let config = RateLimiterConfig {
initial_limit: 10,
aimd_config: Some(AimdConfig {
aimd_decrease_factor: 0.5,
..Default::default()
}),
disable: false,
..Default::default()
};
let release_notifier = Arc::new(Notify::new());
let limiter = Limiter::new(config).with_release_notifier(release_notifier.clone());
let token = limiter.try_acquire().unwrap();
limiter.release(token, Some(Outcome::Overload)).await;
release_notifier.notified().await;
assert_eq!(limiter.state().limit(), 5, "overload: decrease");
}
#[tokio::test]
async fn should_increase_limit_on_success_when_using_gt_util_threshold() {
let config = RateLimiterConfig {
initial_limit: 4,
aimd_config: Some(AimdConfig {
aimd_decrease_factor: 0.5,
aimd_min_utilisation_threshold: 0.5,
aimd_increase_by: 1,
..Default::default()
}),
disable: false,
..Default::default()
};
let limiter = Limiter::new(config);
let token = limiter.try_acquire().unwrap();
let _token = limiter.try_acquire().unwrap();
let _token = limiter.try_acquire().unwrap();
limiter.release(token, Some(Outcome::Success)).await;
assert_eq!(limiter.state().limit(), 5, "success: increase");
}
#[tokio::test]
async fn should_not_change_limit_on_success_when_using_lt_util_threshold() {
let config = RateLimiterConfig {
initial_limit: 4,
aimd_config: Some(AimdConfig {
aimd_decrease_factor: 0.5,
aimd_min_utilisation_threshold: 0.5,
..Default::default()
}),
disable: false,
..Default::default()
};
let limiter = Limiter::new(config);
let token = limiter.try_acquire().unwrap();
limiter.release(token, Some(Outcome::Success)).await;
assert_eq!(
limiter.state().limit(),
4,
"success: ignore when < half limit"
);
}
#[tokio::test]
async fn should_not_change_limit_when_no_outcome() {
let config = RateLimiterConfig {
initial_limit: 10,
aimd_config: Some(AimdConfig {
aimd_decrease_factor: 0.5,
aimd_min_utilisation_threshold: 0.5,
..Default::default()
}),
disable: false,
..Default::default()
};
let limiter = Limiter::new(config);
let token = limiter.try_acquire().unwrap();
limiter.release(token, None).await;
assert_eq!(limiter.state().limit(), 10, "ignore");
}
}

View File

@@ -0,0 +1,98 @@
//! Algorithms for controlling concurrency limits.
use async_trait::async_trait;
use std::time::Duration;
use super::{limiter::Outcome, Aimd};
/// An algorithm for controlling a concurrency limit.
#[async_trait]
pub trait LimitAlgorithm: Send + Sync + 'static {
/// Update the concurrency limit in response to a new job completion.
async fn update(&mut self, old_limit: usize, sample: Sample) -> usize;
}
/// The result of a job (or jobs), including the [Outcome] (loss) and latency (delay).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Sample {
pub(crate) latency: Duration,
/// Jobs in flight when the sample was taken.
pub(crate) in_flight: usize,
pub(crate) outcome: Outcome,
}
#[derive(Clone, Copy, Debug, Default, clap::ValueEnum)]
pub enum RateLimitAlgorithm {
Fixed,
#[default]
Aimd,
}
pub struct Fixed;
#[async_trait]
impl LimitAlgorithm for Fixed {
async fn update(&mut self, old_limit: usize, _sample: Sample) -> usize {
old_limit
}
}
#[derive(Clone, Copy, Debug)]
pub struct RateLimiterConfig {
pub disable: bool,
pub algorithm: RateLimitAlgorithm,
pub timeout: Duration,
pub initial_limit: usize,
pub aimd_config: Option<AimdConfig>,
}
impl RateLimiterConfig {
pub fn create_rate_limit_algorithm(self) -> Box<dyn LimitAlgorithm> {
match self.algorithm {
RateLimitAlgorithm::Fixed => Box::new(Fixed),
RateLimitAlgorithm::Aimd => Box::new(Aimd::new(self.aimd_config.unwrap())), // For aimd algorithm config is mandatory.
}
}
}
impl Default for RateLimiterConfig {
fn default() -> Self {
Self {
disable: true,
algorithm: RateLimitAlgorithm::Aimd,
timeout: Duration::from_secs(1),
initial_limit: 100,
aimd_config: Some(AimdConfig::default()),
}
}
}
#[derive(clap::Parser, Clone, Copy, Debug)]
pub struct AimdConfig {
/// Minimum limit for AIMD algorithm. Makes sense only if `rate_limit_algorithm` is `Aimd`.
#[clap(long, default_value_t = 1)]
pub aimd_min_limit: usize,
/// Maximum limit for AIMD algorithm. Makes sense only if `rate_limit_algorithm` is `Aimd`.
#[clap(long, default_value_t = 1500)]
pub aimd_max_limit: usize,
/// Increase AIMD increase by value in case of success. Makes sense only if `rate_limit_algorithm` is `Aimd`.
#[clap(long, default_value_t = 10)]
pub aimd_increase_by: usize,
/// Decrease AIMD decrease by value in case of timout/429. Makes sense only if `rate_limit_algorithm` is `Aimd`.
#[clap(long, default_value_t = 0.9)]
pub aimd_decrease_factor: f32,
/// A threshold below which the limit won't be increased. Makes sense only if `rate_limit_algorithm` is `Aimd`.
#[clap(long, default_value_t = 0.8)]
pub aimd_min_utilisation_threshold: f32,
}
impl Default for AimdConfig {
fn default() -> Self {
Self {
aimd_min_limit: 1,
aimd_max_limit: 1500,
aimd_increase_by: 10,
aimd_decrease_factor: 0.9,
aimd_min_utilisation_threshold: 0.8,
}
}
}

View File

@@ -2,9 +2,10 @@ use std::{
borrow::Cow,
collections::hash_map::RandomState,
hash::{BuildHasher, Hash},
net::IpAddr,
sync::{
atomic::{AtomicUsize, Ordering},
Mutex,
Arc, Mutex,
},
};
@@ -12,18 +13,24 @@ use anyhow::bail;
use dashmap::DashMap;
use itertools::Itertools;
use rand::{rngs::StdRng, Rng, SeedableRng};
use tokio::time::{Duration, Instant};
use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
use tokio::time::{timeout, Duration, Instant};
use tracing::info;
use crate::EndpointId;
use crate::{intern::EndpointIdInt, EndpointId};
pub struct GlobalRateLimiter {
use super::{
limit_algorithm::{LimitAlgorithm, Sample},
RateLimiterConfig,
};
pub struct RedisRateLimiter {
data: Vec<RateBucket>,
info: Vec<RateBucketInfo>,
info: &'static [RateBucketInfo],
}
impl GlobalRateLimiter {
pub fn new(info: Vec<RateBucketInfo>) -> Self {
impl RedisRateLimiter {
pub fn new(info: &'static [RateBucketInfo]) -> Self {
Self {
data: vec![
RateBucket {
@@ -43,7 +50,7 @@ impl GlobalRateLimiter {
let should_allow_request = self
.data
.iter_mut()
.zip(&self.info)
.zip(self.info)
.all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
if should_allow_request {
@@ -68,6 +75,9 @@ impl GlobalRateLimiter {
// I went with a more expensive way that yields user-friendlier error messages.
pub type EndpointRateLimiter = BucketRateLimiter<EndpointId, StdRng, RandomState>;
// This can't be just per IP because that would limit some PaaS that share IP addresses
pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, IpAddr), StdRng, RandomState>;
pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
map: DashMap<Key, Vec<RateBucket>, Hasher>,
info: Cow<'static, [RateBucketInfo]>,
@@ -139,6 +149,19 @@ impl RateBucketInfo {
Self::new(100, Duration::from_secs(600)),
];
/// All of these are per endpoint-ip pair.
/// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
///
/// First bucket: 300mcpus total per endpoint-ip pair
/// * 1228800 requests per second with 1 hash rounds. (endpoint rate limiter will catch this first)
/// * 300 requests per second with 4096 hash rounds.
/// * 2 requests per second with 600000 hash rounds.
pub const DEFAULT_AUTH_SET: [Self; 3] = [
Self::new(300 * 4096, Duration::from_secs(1)),
Self::new(200 * 4096, Duration::from_secs(60)),
Self::new(100 * 4096, Duration::from_secs(600)),
];
pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
info.sort_unstable_by_key(|info| info.interval);
let invalid = info
@@ -236,16 +259,419 @@ impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
}
}
/// Limits the number of concurrent jobs.
///
/// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
/// token once the job is finished.
///
/// The limit will be automatically adjusted based on observed latency (delay) and/or failures
/// caused by overload (loss).
pub struct Limiter {
limit_algo: AsyncMutex<Box<dyn LimitAlgorithm>>,
semaphore: std::sync::Arc<Semaphore>,
config: RateLimiterConfig,
// ONLY WRITE WHEN LIMIT_ALGO IS LOCKED
limits: AtomicUsize,
// ONLY USE ATOMIC ADD/SUB
in_flight: Arc<AtomicUsize>,
#[cfg(test)]
notifier: Option<std::sync::Arc<tokio::sync::Notify>>,
}
/// A concurrency token, required to run a job.
///
/// Release the token back to the [Limiter] after the job is complete.
#[derive(Debug)]
pub struct Token<'t> {
permit: Option<tokio::sync::SemaphorePermit<'t>>,
start: Instant,
in_flight: Arc<AtomicUsize>,
}
/// A snapshot of the state of the [Limiter].
///
/// Not guaranteed to be consistent under high concurrency.
#[derive(Debug, Clone, Copy)]
pub struct LimiterState {
limit: usize,
in_flight: usize,
}
/// Whether a job succeeded or failed as a result of congestion/overload.
///
/// Errors not considered to be caused by overload should be ignored.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
/// The job succeeded, or failed in a way unrelated to overload.
Success,
/// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
/// was observed.
Overload,
}
impl Outcome {
fn from_reqwest_error(error: &reqwest_middleware::Error) -> Self {
match error {
reqwest_middleware::Error::Middleware(_) => Outcome::Success,
reqwest_middleware::Error::Reqwest(e) => {
if let Some(status) = e.status() {
if status.is_server_error()
|| reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16() == status
{
Outcome::Overload
} else {
Outcome::Success
}
} else {
Outcome::Success
}
}
}
}
fn from_reqwest_response(response: &reqwest::Response) -> Self {
if response.status().is_server_error()
|| response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS
{
Outcome::Overload
} else {
Outcome::Success
}
}
}
impl Limiter {
/// Create a limiter with a given limit control algorithm.
pub fn new(config: RateLimiterConfig) -> Self {
assert!(config.initial_limit > 0);
Self {
limit_algo: AsyncMutex::new(config.create_rate_limit_algorithm()),
semaphore: Arc::new(Semaphore::new(config.initial_limit)),
config,
limits: AtomicUsize::new(config.initial_limit),
in_flight: Arc::new(AtomicUsize::new(0)),
#[cfg(test)]
notifier: None,
}
}
// pub fn new(limit_algorithm: T, timeout: Duration, initial_limit: usize) -> Self {
// assert!(initial_limit > 0);
// Self {
// limit_algo: AsyncMutex::new(limit_algorithm),
// semaphore: Arc::new(Semaphore::new(initial_limit)),
// timeout,
// limits: AtomicUsize::new(initial_limit),
// in_flight: Arc::new(AtomicUsize::new(0)),
// #[cfg(test)]
// notifier: None,
// }
// }
/// In some cases [Token]s are acquired asynchronously when updating the limit.
#[cfg(test)]
pub fn with_release_notifier(mut self, n: std::sync::Arc<tokio::sync::Notify>) -> Self {
self.notifier = Some(n);
self
}
/// Try to immediately acquire a concurrency [Token].
///
/// Returns `None` if there are none available.
pub fn try_acquire(&self) -> Option<Token> {
let result = if self.config.disable {
// If the rate limiter is disabled, we can always acquire a token.
Some(Token::new(None, self.in_flight.clone()))
} else {
self.semaphore
.try_acquire()
.map(|permit| Token::new(Some(permit), self.in_flight.clone()))
.ok()
};
if result.is_some() {
self.in_flight.fetch_add(1, Ordering::AcqRel);
}
result
}
/// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
///
/// Returns `None` if there are none available after `duration`.
pub async fn acquire_timeout(&self, duration: Duration) -> Option<Token<'_>> {
info!("acquiring token: {:?}", self.semaphore.available_permits());
let result = if self.config.disable {
// If the rate limiter is disabled, we can always acquire a token.
Some(Token::new(None, self.in_flight.clone()))
} else {
match timeout(duration, self.semaphore.acquire()).await {
Ok(maybe_permit) => maybe_permit
.map(|permit| Token::new(Some(permit), self.in_flight.clone()))
.ok(),
Err(_) => None,
}
};
if result.is_some() {
self.in_flight.fetch_add(1, Ordering::AcqRel);
}
result
}
/// Return the concurrency [Token], along with the outcome of the job.
///
/// The [Outcome] of the job, and the time taken to perform it, may be used
/// to update the concurrency limit.
///
/// Set the outcome to `None` to ignore the job.
pub async fn release(&self, mut token: Token<'_>, outcome: Option<Outcome>) {
tracing::info!("outcome is {:?}", outcome);
let in_flight = self.in_flight.load(Ordering::Acquire);
let old_limit = self.limits.load(Ordering::Acquire);
let available = if self.config.disable {
0 // This is not used in the algorithm and can be anything. If the config disable it makes sense to set it to 0.
} else {
self.semaphore.available_permits()
};
let total = in_flight + available;
let mut algo = self.limit_algo.lock().await;
let new_limit = if let Some(outcome) = outcome {
let sample = Sample {
latency: token.start.elapsed(),
in_flight,
outcome,
};
algo.update(old_limit, sample).await
} else {
old_limit
};
tracing::info!("new limit is {}", new_limit);
let actual_limit = if new_limit < total {
token.forget();
total.saturating_sub(1)
} else {
if !self.config.disable {
self.semaphore.add_permits(new_limit.saturating_sub(total));
}
new_limit
};
crate::metrics::RATE_LIMITER_LIMIT
.with_label_values(&["expected"])
.set(new_limit as i64);
crate::metrics::RATE_LIMITER_LIMIT
.with_label_values(&["actual"])
.set(actual_limit as i64);
self.limits.store(new_limit, Ordering::Release);
#[cfg(test)]
if let Some(n) = &self.notifier {
n.notify_one();
}
}
/// The current state of the limiter.
pub fn state(&self) -> LimiterState {
let limit = self.limits.load(Ordering::Relaxed);
let in_flight = self.in_flight.load(Ordering::Relaxed);
LimiterState { limit, in_flight }
}
}
impl<'t> Token<'t> {
fn new(permit: Option<SemaphorePermit<'t>>, in_flight: Arc<AtomicUsize>) -> Self {
Self {
permit,
start: Instant::now(),
in_flight,
}
}
pub fn forget(&mut self) {
if let Some(permit) = self.permit.take() {
permit.forget();
}
}
}
impl Drop for Token<'_> {
fn drop(&mut self) {
self.in_flight.fetch_sub(1, Ordering::AcqRel);
}
}
impl LimiterState {
/// The current concurrency limit.
pub fn limit(&self) -> usize {
self.limit
}
/// The number of jobs in flight.
pub fn in_flight(&self) -> usize {
self.in_flight
}
}
#[async_trait::async_trait]
impl reqwest_middleware::Middleware for Limiter {
async fn handle(
&self,
req: reqwest::Request,
extensions: &mut task_local_extensions::Extensions,
next: reqwest_middleware::Next<'_>,
) -> reqwest_middleware::Result<reqwest::Response> {
let start = Instant::now();
let token = self
.acquire_timeout(self.config.timeout)
.await
.ok_or_else(|| {
reqwest_middleware::Error::Middleware(
// TODO: Should we map it into user facing errors?
crate::console::errors::ApiError::Console {
status: crate::http::StatusCode::TOO_MANY_REQUESTS,
text: "Too many requests".into(),
}
.into(),
)
})?;
info!(duration = ?start.elapsed(), "waiting for token to connect to the control plane");
crate::metrics::RATE_LIMITER_ACQUIRE_LATENCY.observe(start.elapsed().as_secs_f64());
match next.run(req, extensions).await {
Ok(response) => {
self.release(token, Some(Outcome::from_reqwest_response(&response)))
.await;
Ok(response)
}
Err(e) => {
self.release(token, Some(Outcome::from_reqwest_error(&e)))
.await;
Err(e)
}
}
}
}
#[cfg(test)]
mod tests {
use std::{hash::BuildHasherDefault, time::Duration};
use std::{hash::BuildHasherDefault, pin::pin, task::Context, time::Duration};
use futures::{task::noop_waker_ref, Future};
use rand::SeedableRng;
use rustc_hash::FxHasher;
use tokio::time;
use super::{BucketRateLimiter, EndpointRateLimiter};
use crate::{rate_limiter::RateBucketInfo, EndpointId};
use super::{BucketRateLimiter, EndpointRateLimiter, Limiter, Outcome};
use crate::{
rate_limiter::{RateBucketInfo, RateLimitAlgorithm},
EndpointId,
};
#[tokio::test]
async fn it_works() {
let config = super::RateLimiterConfig {
algorithm: RateLimitAlgorithm::Fixed,
timeout: Duration::from_secs(1),
initial_limit: 10,
disable: false,
..Default::default()
};
let limiter = Limiter::new(config);
let token = limiter.try_acquire().unwrap();
limiter.release(token, Some(Outcome::Success)).await;
assert_eq!(limiter.state().limit(), 10);
}
#[tokio::test]
async fn is_fair() {
let config = super::RateLimiterConfig {
algorithm: RateLimitAlgorithm::Fixed,
timeout: Duration::from_secs(1),
initial_limit: 1,
disable: false,
..Default::default()
};
let limiter = Limiter::new(config);
// === TOKEN 1 ===
let token1 = limiter.try_acquire().unwrap();
let mut token2_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
assert!(
token2_fut
.as_mut()
.poll(&mut Context::from_waker(noop_waker_ref()))
.is_pending(),
"token is acquired by token1"
);
let mut token3_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
assert!(
token3_fut
.as_mut()
.poll(&mut Context::from_waker(noop_waker_ref()))
.is_pending(),
"token is acquired by token1"
);
limiter.release(token1, Some(Outcome::Success)).await;
// === END TOKEN 1 ===
// === TOKEN 2 ===
assert!(
limiter.try_acquire().is_none(),
"token is acquired by token2"
);
assert!(
token3_fut
.as_mut()
.poll(&mut Context::from_waker(noop_waker_ref()))
.is_pending(),
"token is acquired by token2"
);
let token2 = token2_fut.await.unwrap();
limiter.release(token2, Some(Outcome::Success)).await;
// === END TOKEN 2 ===
// === TOKEN 3 ===
assert!(
limiter.try_acquire().is_none(),
"token is acquired by token3"
);
let token3 = token3_fut.await.unwrap();
limiter.release(token3, Some(Outcome::Success)).await;
// === END TOKEN 3 ===
// === TOKEN 4 ===
let token4 = limiter.try_acquire().unwrap();
limiter.release(token4, Some(Outcome::Success)).await;
}
#[tokio::test]
async fn disable() {
let config = super::RateLimiterConfig {
algorithm: RateLimitAlgorithm::Fixed,
timeout: Duration::from_secs(1),
initial_limit: 1,
disable: true,
..Default::default()
};
let limiter = Limiter::new(config);
// === TOKEN 1 ===
let token1 = limiter.try_acquire().unwrap();
let token2 = limiter.try_acquire().unwrap();
let state = limiter.state();
assert_eq!(state.limit(), 1);
assert_eq!(state.in_flight(), 2); // For disabled limiter, it's expected.
limiter.release(token1, None).await;
limiter.release(token2, None).await;
}
#[test]
fn rate_bucket_rpi() {
@@ -347,4 +773,31 @@ mod tests {
}
assert!(limiter.map.len() < 150_000);
}
#[test]
fn test_default_auth_set() {
// these values used to exceed u32::MAX
assert_eq!(
RateBucketInfo::DEFAULT_AUTH_SET,
[
RateBucketInfo {
interval: Duration::from_secs(1),
max_rpi: 300 * 4096,
},
RateBucketInfo {
interval: Duration::from_secs(60),
max_rpi: 200 * 4096 * 60,
},
RateBucketInfo {
interval: Duration::from_secs(600),
max_rpi: 100 * 4096 * 600,
}
]
);
for x in RateBucketInfo::DEFAULT_AUTH_SET {
let y = x.to_string().parse().unwrap();
assert_eq!(x, y);
}
}
}

View File

@@ -5,7 +5,7 @@ use redis::AsyncCommands;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
use super::{
connection_with_credentials_provider::ConnectionWithCredentialsProvider,
@@ -80,7 +80,7 @@ impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
pub struct RedisPublisherClient {
client: ConnectionWithCredentialsProvider,
region_id: String,
limiter: GlobalRateLimiter,
limiter: RedisRateLimiter,
}
impl RedisPublisherClient {
@@ -92,7 +92,7 @@ impl RedisPublisherClient {
Ok(Self {
client,
region_id,
limiter: GlobalRateLimiter::new(info.into()),
limiter: RedisRateLimiter::new(info),
})
}

View File

@@ -77,14 +77,10 @@ impl ConnectionWithCredentialsProvider {
}
}
async fn ping(con: &mut MultiplexedConnection) -> RedisResult<()> {
redis::cmd("PING").query_async(con).await
}
pub async fn connect(&mut self) -> anyhow::Result<()> {
let _guard = self.mutex.lock().await;
if let Some(con) = self.con.as_mut() {
match Self::ping(con).await {
match redis::cmd("PING").query_async(con).await {
Ok(()) => {
return Ok(());
}
@@ -100,7 +96,7 @@ impl ConnectionWithCredentialsProvider {
if let Some(f) = self.refresh_token_task.take() {
f.abort()
}
let mut con = self
let con = self
.get_client()
.await?
.get_multiplexed_tokio_connection()
@@ -113,14 +109,6 @@ impl ConnectionWithCredentialsProvider {
});
self.refresh_token_task = Some(f);
}
match Self::ping(&mut con).await {
Ok(()) => {
info!("Connection succesfully established");
}
Err(e) => {
error!("Connection is broken. Error during PING: {e:?}");
}
}
self.con = Some(con);
Ok(())
}

View File

@@ -11,7 +11,7 @@ use crate::{
cache::project_info::ProjectInfoCache,
cancellation::{CancelMap, CancellationHandler},
intern::{ProjectIdInt, RoleNameInt},
metrics::{Metrics, RedisErrors},
metrics::{NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, REDIS_BROKEN_MESSAGES},
};
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
@@ -104,9 +104,9 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
let msg: Notification = match serde_json::from_str(&payload) {
Ok(msg) => msg,
Err(e) => {
Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
channel: msg.get_channel_name(),
});
REDIS_BROKEN_MESSAGES
.with_label_values(&[msg.get_channel_name()])
.inc();
tracing::error!("broken message: {e}");
return Ok(());
}
@@ -183,7 +183,7 @@ where
cache,
Arc::new(CancellationHandler::<()>::new(
cancel_map,
crate::metrics::CancellationSource::FromRedis,
NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS,
)),
region_id,
);

View File

@@ -32,7 +32,7 @@ use tokio_util::task::TaskTracker;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::metrics::Metrics;
use crate::metrics::{NUM_CLIENT_CONNECTION_GAUGE, TLS_HANDSHAKE_FAILURES};
use crate::protocol2::WithClientIp;
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
@@ -156,10 +156,9 @@ async fn connection_handler(
) {
let session_id = uuid::Uuid::new_v4();
let _gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Http);
let _gauge = NUM_CLIENT_CONNECTION_GAUGE
.with_label_values(&["http"])
.guard();
// handle PROXY protocol
let mut conn = WithClientIp::new(conn);
@@ -172,10 +171,6 @@ async fn connection_handler(
};
let peer_addr = peer.unwrap_or(peer_addr).ip();
let has_private_peer_addr = match peer_addr {
IpAddr::V4(ip) => ip.is_private(),
_ => false,
};
info!(?session_id, %peer_addr, "accepted new TCP connection");
// try upgrade to TLS, but with a timeout.
@@ -186,17 +181,13 @@ async fn connection_handler(
}
// The handshake failed
Ok(Err(e)) => {
if !has_private_peer_addr {
Metrics::get().proxy.tls_handshake_failures.inc();
}
TLS_HANDSHAKE_FAILURES.inc();
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
return;
}
// The handshake timed out
Err(e) => {
if !has_private_peer_addr {
Metrics::get().proxy.tls_handshake_failures.inc();
}
TLS_HANDSHAKE_FAILURES.inc();
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
return;
}
@@ -283,13 +274,7 @@ async fn request_handler(
// Check if the request is a websocket upgrade request.
if hyper_tungstenite::is_upgrade_request(&request) {
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Ws,
&config.region,
);
let ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
let span = ctx.span.clone();
info!(parent: &span, "performing websocket upgrade");
@@ -317,12 +302,7 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Http,
&config.region,
);
let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
let span = ctx.span.clone();
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)

View File

@@ -6,7 +6,7 @@ use tracing::{field::display, info};
use crate::{
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
compute,
config::{AuthenticationConfig, ProxyConfig},
config::ProxyConfig,
console::{
errors::{GetAuthInfoError, WakeComputeError},
CachedNodeInfo,
@@ -27,7 +27,6 @@ impl PoolingBackend {
pub async fn authenticate(
&self,
ctx: &mut RequestMonitoring,
config: &AuthenticationConfig,
conn_info: &ConnInfo,
) -> Result<ComputeCredentials, AuthError> {
let user_info = conn_info.user_info.clone();
@@ -44,7 +43,6 @@ impl PoolingBackend {
let secret = match cached_secret.value.clone() {
Some(secret) => self.config.authentication_config.check_rate_limit(
ctx,
config,
secret,
&user_info.endpoint,
true,

View File

@@ -1,5 +1,6 @@
use dashmap::DashMap;
use futures::{future::poll_fn, Future};
use metrics::IntCounterPairGuard;
use parking_lot::RwLock;
use rand::Rng;
use smallvec::SmallVec;
@@ -15,13 +16,13 @@ use std::{
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_util::sync::CancellationToken;
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::metrics::{ENDPOINT_POOLS, GC_LATENCY, NUM_OPEN_CLIENTS_IN_HTTP_POOL};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{
auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
auth::backend::ComputeUserInfo, context::RequestMonitoring, metrics::NUM_DB_CONNECTIONS_GAUGE,
DbName, EndpointCacheKey, RoleName,
};
use tracing::{debug, error, warn, Span};
@@ -77,7 +78,7 @@ pub struct EndpointConnPool<C: ClientInnerExt> {
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
total_conns: usize,
max_conns: usize,
_guard: HttpEndpointPoolsGuard<'static>,
_guard: IntCounterPairGuard,
global_connections_count: Arc<AtomicUsize>,
global_pool_size_max_conns: usize,
}
@@ -109,11 +110,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
let removed = old_len - new_len;
if removed > 0 {
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64);
}
*total_conns -= removed;
removed > 0
@@ -159,11 +156,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
pool.total_conns += 1;
pool.global_connections_count
.fetch_add(1, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.inc();
NUM_OPEN_CLIENTS_IN_HTTP_POOL.inc();
}
pool.total_conns
@@ -183,11 +176,7 @@ impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
if self.total_conns > 0 {
self.global_connections_count
.fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.total_conns as i64);
NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(self.total_conns as i64);
}
}
}
@@ -226,11 +215,7 @@ impl<C: ClientInnerExt> DbUserConnPool<C> {
removed += 1;
}
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64);
conn
}
}
@@ -318,10 +303,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let timer = GC_LATENCY.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
@@ -349,7 +331,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
let new_len = shard.len();
drop(shard);
timer.observe();
timer.observe_duration();
// Do logging outside of the lock.
if clients_removed > 0 {
@@ -357,11 +339,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
@@ -432,7 +410,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
pools: HashMap::new(),
total_conns: 0,
max_conns: self.config.pool_options.max_conns_per_endpoint,
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
_guard: ENDPOINT_POOLS.guard(),
global_connections_count: self.global_connections_count.clone(),
global_pool_size_max_conns: self.config.pool_options.max_total_conns,
}));
@@ -472,7 +450,9 @@ pub fn poll_client<C: ClientInnerExt>(
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<C> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol);
let conn_gauge = NUM_DB_CONNECTIONS_GAUGE
.with_label_values(&[ctx.protocol])
.guard();
let mut session_id = ctx.session_id;
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
@@ -489,32 +469,15 @@ pub fn poll_client<C: ClientInnerExt>(
let db_user = conn_info.db_and_user();
let idle = global_pool.get_idle_timeout();
let cancel = CancellationToken::new();
let cancelled = cancel.clone().cancelled_owned();
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let mut idle_timeout = pin!(tokio::time::sleep(idle));
let mut cancelled = pin!(cancelled);
poll_fn(move |cx| {
if cancelled.as_mut().poll(cx).is_ready() {
info!("connection dropped");
return Poll::Ready(())
}
match rx.has_changed() {
Ok(true) => {
session_id = *rx.borrow_and_update();
info!(%session_id, "changed session");
idle_timeout.as_mut().reset(Instant::now() + idle);
}
Err(_) => {
info!("connection dropped");
return Poll::Ready(())
}
_ => {}
if matches!(rx.has_changed(), Ok(true)) {
session_id = *rx.borrow_and_update();
info!(%session_id, "changed session");
idle_timeout.as_mut().reset(Instant::now() + idle);
}
// 5 minute idle connection timeout
@@ -569,7 +532,6 @@ pub fn poll_client<C: ClientInnerExt>(
let inner = ClientInner {
inner: client,
session: tx,
cancel,
aux,
conn_id,
};
@@ -579,18 +541,10 @@ pub fn poll_client<C: ClientInnerExt>(
struct ClientInner<C: ClientInnerExt> {
inner: C,
session: tokio::sync::watch::Sender<uuid::Uuid>,
cancel: CancellationToken,
aux: MetricsAuxInfo,
conn_id: uuid::Uuid,
}
impl<C: ClientInnerExt> Drop for ClientInner<C> {
fn drop(&mut self) {
// on client drop, tell the conn to shut down
self.cancel.cancel();
}
}
pub trait ClientInnerExt: Sync + Send + 'static {
fn is_closed(&self) -> bool;
fn get_process_id(&self) -> i32;
@@ -743,7 +697,6 @@ mod tests {
ClientInner {
inner: client,
session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
cancel: CancellationToken::new(),
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),

View File

@@ -43,8 +43,8 @@ use crate::context::RequestMonitoring;
use crate::error::ErrorKind;
use crate::error::ReportableError;
use crate::error::UserFacingError;
use crate::metrics::HttpDirection;
use crate::metrics::Metrics;
use crate::metrics::HTTP_CONTENT_LENGTH;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::run_until_cancelled;
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
@@ -494,11 +494,10 @@ async fn handle_inner(
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get().proxy.connection_requests.guard(ctx.protocol);
info!(
protocol = %ctx.protocol,
"handling interactive connection from client"
);
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
.with_label_values(&[ctx.protocol])
.guard();
info!("handling interactive connection from client");
//
// Determine the destination and connection params
@@ -521,10 +520,9 @@ async fn handle_inner(
None => MAX_REQUEST_SIZE + 1,
};
info!(request_content_length, "request size in bytes");
Metrics::get()
.proxy
.http_conn_content_length_bytes
.observe(HttpDirection::Request, request_content_length as f64);
HTTP_CONTENT_LENGTH
.with_label_values(&["request"])
.observe(request_content_length as f64);
// we don't have a streaming request support yet so this is to prevent OOM
// from a malicious user sending an extremely large request body
@@ -541,9 +539,7 @@ async fn handle_inner(
.map_err(SqlOverHttpError::from);
let authenticate_and_connect = async {
let keys = backend
.authenticate(ctx, &config.authentication_config, &conn_info)
.await?;
let keys = backend.authenticate(ctx, &conn_info).await?;
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
@@ -611,10 +607,9 @@ async fn handle_inner(
// count the egress bytes - we miss the TLS and header overhead but oh well...
// moving this later in the stack is going to be a lot of effort and ehhhh
metrics.record_egress(len as u64);
Metrics::get()
.proxy
.http_conn_content_length_bytes
.observe(HttpDirection::Response, len as f64);
HTTP_CONTENT_LENGTH
.with_label_values(&["response"])
.observe(len as f64);
Ok(response)
}

View File

@@ -3,7 +3,7 @@ use crate::{
config::ProxyConfig,
context::RequestMonitoring,
error::{io_error, ReportableError},
metrics::Metrics,
metrics::NUM_CLIENT_CONNECTION_GAUGE,
proxy::{handle_client, ClientMode},
rate_limiter::EndpointRateLimiter,
};
@@ -139,10 +139,9 @@ pub async fn serve_websocket(
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Ws);
let conn_gauge = NUM_CLIENT_CONNECTION_GAUGE
.with_label_values(&["ws"])
.guard();
let res = handle_client(
config,

View File

@@ -1,6 +1,6 @@
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::metrics::TLS_HANDSHAKE_FAILURES;
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
@@ -223,20 +223,12 @@ pub enum StreamUpgradeError {
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(
self,
cfg: Arc<ServerConfig>,
record_handshake_error: bool,
) -> Result<TlsStream<S>, StreamUpgradeError> {
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
match self {
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
.accept(raw)
.await
.inspect_err(|_| {
if record_handshake_error {
Metrics::get().proxy.tls_handshake_failures.inc()
}
})?),
.inspect_err(|_| TLS_HANDSHAKE_FAILURES.inc())?),
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
}
}

View File

@@ -495,7 +495,7 @@ mod tests {
use url::Url;
use super::*;
use crate::{http, BranchId, EndpointId};
use crate::{http, rate_limiter::RateLimiterConfig, BranchId, EndpointId};
#[tokio::test]
async fn metrics() {
@@ -525,7 +525,7 @@ mod tests {
tokio::spawn(server);
let metrics = Metrics::default();
let client = http::new_client();
let client = http::new_client(RateLimiterConfig::default());
let endpoint = Url::parse(&format!("http://{addr}")).unwrap();
let now = Utc::now();

View File

@@ -17,8 +17,6 @@ use crate::service::Config;
const SLOWDOWN_DELAY: Duration = Duration::from_secs(5);
const NOTIFY_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
pub(crate) const API_CONCURRENCY: usize = 32;
struct UnshardedComputeHookTenant {
@@ -244,10 +242,6 @@ pub(super) struct ComputeHook {
// This lock is only used in testing enviroments, to serialize calls into neon_lock
neon_local_lock: tokio::sync::Mutex<()>,
// We share a client across all notifications to enable connection re-use etc when
// sending large numbers of notifications
client: reqwest::Client,
}
impl ComputeHook {
@@ -257,18 +251,12 @@ impl ComputeHook {
.clone()
.map(|jwt| format!("Bearer {}", jwt));
let client = reqwest::ClientBuilder::new()
.timeout(NOTIFY_REQUEST_TIMEOUT)
.build()
.expect("Failed to construct HTTP client");
Self {
state: Default::default(),
config,
authorization_header,
neon_local_lock: Default::default(),
api_concurrency: tokio::sync::Semaphore::new(API_CONCURRENCY),
client,
}
}
@@ -322,11 +310,12 @@ impl ComputeHook {
async fn do_notify_iteration(
&self,
client: &reqwest::Client,
url: &String,
reconfigure_request: &ComputeHookNotifyRequest,
cancel: &CancellationToken,
) -> Result<(), NotifyError> {
let req = self.client.request(Method::PUT, url);
let req = client.request(Method::PUT, url);
let req = if let Some(value) = &self.authorization_header {
req.header(reqwest::header::AUTHORIZATION, value)
} else {
@@ -392,6 +381,8 @@ impl ComputeHook {
reconfigure_request: &ComputeHookNotifyRequest,
cancel: &CancellationToken,
) -> Result<(), NotifyError> {
let client = reqwest::Client::new();
// We hold these semaphore units across all retries, rather than only across each
// HTTP request: this is to preserve fairness and avoid a situation where a retry might
// time out waiting for a semaphore.
@@ -403,7 +394,7 @@ impl ComputeHook {
.map_err(|_| NotifyError::ShuttingDown)?;
backoff::retry(
|| self.do_notify_iteration(url, reconfigure_request, cancel),
|| self.do_notify_iteration(&client, url, reconfigure_request, cancel),
|e| {
matches!(
e,

View File

@@ -84,20 +84,6 @@ impl std::ops::Add for AffinityScore {
}
}
/// Hint for whether this is a sincere attempt to schedule, or a speculative
/// check for where we _would_ schedule (done during optimization)
#[derive(Debug)]
pub(crate) enum ScheduleMode {
Normal,
Speculative,
}
impl Default for ScheduleMode {
fn default() -> Self {
Self::Normal
}
}
// For carrying state between multiple calls to [`TenantShard::schedule`], e.g. when calling
// it for many shards in the same tenant.
#[derive(Debug, Default)]
@@ -107,8 +93,6 @@ pub(crate) struct ScheduleContext {
/// Specifically how many _attached_ locations are on each node
pub(crate) attached_nodes: HashMap<NodeId, usize>,
pub(crate) mode: ScheduleMode,
}
impl ScheduleContext {
@@ -345,34 +329,27 @@ impl Scheduler {
scores.sort_by_key(|i| (i.1, i.2, i.0));
if scores.is_empty() {
// After applying constraints, no pageservers were left.
if !matches!(context.mode, ScheduleMode::Speculative) {
// If this was not a speculative attempt, log details to understand why we couldn't
// schedule: this may help an engineer understand if some nodes are marked offline
// in a way that's preventing progress.
// After applying constraints, no pageservers were left. We log some detail about
// the state of nodes to help understand why this happened. This is not logged as an error because
// it is legitimately possible for enough nodes to be Offline to prevent scheduling a shard.
tracing::info!("Scheduling failure, while excluding {hard_exclude:?}, node states:");
for (node_id, node) in &self.nodes {
tracing::info!(
"Scheduling failure, while excluding {hard_exclude:?}, node states:"
"Node {node_id}: may_schedule={} shards={}",
node.may_schedule != MaySchedule::No,
node.shard_count
);
for (node_id, node) in &self.nodes {
tracing::info!(
"Node {node_id}: may_schedule={} shards={}",
node.may_schedule != MaySchedule::No,
node.shard_count
);
}
}
return Err(ScheduleError::ImpossibleConstraint);
}
// Lowest score wins
let node_id = scores.first().unwrap().0;
if !matches!(context.mode, ScheduleMode::Speculative) {
tracing::info!(
tracing::info!(
"scheduler selected node {node_id} (elegible nodes {:?}, hard exclude: {hard_exclude:?}, soft exclude: {context:?})",
scores.iter().map(|i| i.0 .0).collect::<Vec<_>>()
);
}
// Note that we do not update shard count here to reflect the scheduling: that
// is IntentState's job when the scheduled location is used.

View File

@@ -11,7 +11,7 @@ use crate::{
id_lock_map::IdLockMap,
persistence::{AbortShardSplitStatus, TenantFilter},
reconciler::ReconcileError,
scheduler::{ScheduleContext, ScheduleMode},
scheduler::ScheduleContext,
};
use anyhow::Context;
use control_plane::storage_controller::{
@@ -2744,7 +2744,7 @@ impl Service {
let mut describe_shards = Vec::new();
for shard in shards {
if shard.tenant_shard_id.is_shard_zero() {
if shard.tenant_shard_id.is_zero() {
shard_zero = Some(shard);
}
@@ -4084,7 +4084,7 @@ impl Service {
let mut reconciles_spawned = 0;
for (tenant_shard_id, shard) in tenants.iter_mut() {
if tenant_shard_id.is_shard_zero() {
if tenant_shard_id.is_zero() {
schedule_context = ScheduleContext::default();
}
@@ -4134,10 +4134,9 @@ impl Service {
let mut work = Vec::new();
for (tenant_shard_id, shard) in tenants.iter() {
if tenant_shard_id.is_shard_zero() {
if tenant_shard_id.is_zero() {
// Reset accumulators on the first shard in a tenant
schedule_context = ScheduleContext::default();
schedule_context.mode = ScheduleMode::Speculative;
tenant_shards.clear();
}

View File

@@ -2449,12 +2449,10 @@ class NeonPageserver(PgProtocol):
if cur_line_no < skip_until_line_no:
cur_line_no += 1
continue
elif contains_re.search(line):
if contains_re.search(line):
# found it!
cur_line_no += 1
return (line, LogCursor(cur_line_no))
else:
cur_line_no += 1
return None
def tenant_attach(

View File

@@ -192,6 +192,9 @@ def test_backward_compatibility(
assert not breaking_changes_allowed, "Breaking changes are allowed by ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE, but the test has passed without any breakage"
# Forward compatibility is broken due to https://github.com/neondatabase/neon/pull/6530
# The test is disabled until the next release deployment
@pytest.mark.xfail
@check_ondisk_data_compatibility_if_enabled
@pytest.mark.xdist_group("compatibility")
@pytest.mark.order(after="test_create_snapshot")

View File

@@ -1,35 +0,0 @@
import pytest
from fixtures.neon_fixtures import (
NeonEnvBuilder,
last_flush_lsn_upload,
)
@pytest.mark.parametrize("kind", ["sync", "async"])
def test_walredo_process_kind_config(neon_env_builder: NeonEnvBuilder, kind: str):
neon_env_builder.pageserver_config_override = f"walredo_process_kind = '{kind}'"
# ensure it starts
env = neon_env_builder.init_start()
# ensure the metric is set
ps_http = env.pageserver.http_client()
metrics = ps_http.get_metrics()
samples = metrics.query_all("pageserver_wal_redo_process_kind")
assert [(s.labels, s.value) for s in samples] == [({"kind": kind}, 1)]
# ensure default tenant's config kind matches
# => write some data to force-spawn walredo
ep = env.endpoints.create_start("main")
with ep.connect() as conn:
with conn.cursor() as cur:
cur.execute("create table foo(bar text)")
cur.execute("insert into foo select from generate_series(1, 100)")
last_flush_lsn_upload(env, ep, env.initial_tenant, env.initial_timeline)
ep.stop()
ep.start()
with ep.connect() as conn:
with conn.cursor() as cur:
cur.execute("select count(*) from foo")
[(count,)] = cur.fetchall()
assert count == 100
status = ps_http.tenant_status(env.initial_tenant)
assert status["walredo"]["process"]["kind"] == kind

View File

@@ -0,0 +1,9 @@
from fixtures.neon_fixtures import NeonEnv
def test_protocol_version(neon_simple_env: NeonEnv):
env = neon_simple_env
endpoint = env.endpoints.create_start("main", config_lines=["neon.protocol_version=1"])
cur = endpoint.connect().cursor()
cur.execute("show neon.protocol_version")
assert cur.fetchone() == ("1",)

View File

@@ -0,0 +1,84 @@
import asyncio
import time
from pathlib import Path
from typing import Iterator
import pytest
from fixtures.neon_fixtures import (
PSQL,
NeonProxy,
)
from fixtures.port_distributor import PortDistributor
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.response import Response
def waiting_handler(status_code: int) -> Response:
# wait more than timeout to make sure that both (two) connections are open.
# It would be better to use a barrier here, but I don't know how to do that together with pytest-httpserver.
time.sleep(2)
return Response(status=status_code)
@pytest.fixture(scope="function")
def proxy_with_rate_limit(
port_distributor: PortDistributor,
neon_binpath: Path,
httpserver_listen_address,
test_output_dir: Path,
) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres."""
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()
external_http_port = port_distributor.get_port()
(host, port) = httpserver_listen_address
endpoint = f"http://{host}:{port}/billing/api/v1/usage_events"
with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
external_http_port=external_http_port,
auth_backend=NeonProxy.Console(endpoint, fixed_rate_limit=5),
) as proxy:
proxy.start()
yield proxy
@pytest.mark.asyncio
async def test_proxy_rate_limit(
httpserver: HTTPServer,
proxy_with_rate_limit: NeonProxy,
):
uri = "/billing/api/v1/usage_events/proxy_get_role_secret"
# mock control plane service
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: Response(status=200)
)
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: waiting_handler(429)
)
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: waiting_handler(500)
)
psql = PSQL(host=proxy_with_rate_limit.host, port=proxy_with_rate_limit.proxy_port)
f = await psql.run("select 42;")
await proxy_with_rate_limit.find_auth_link(uri, f)
# Limit should be 2.
# Run two queries in parallel.
f1, f2 = await asyncio.gather(psql.run("select 42;"), psql.run("select 42;"))
await proxy_with_rate_limit.find_auth_link(uri, f1)
await proxy_with_rate_limit.find_auth_link(uri, f2)
# Now limit should be 0.
f = await psql.run("select 42;")
await proxy_with_rate_limit.find_auth_link(uri, f)
# There last query shouldn't reach the http-server.
assert httpserver.assertions == []

Some files were not shown because too many files have changed in this diff Show More