Compare commits

..

8 Commits

Author SHA1 Message Date
Stas Kelvich
8fbecb1565 bump lsn a bit 2024-01-27 13:48:06 +02:00
Stas Kelvich
715772bdd6 new values 2024-01-27 13:34:28 +02:00
Heikki Linnakangas
1626a1f333 Fix the test
Now it passes on my laptop at least
2024-01-27 13:21:13 +02:00
Heikki Linnakangas
12e39001ce Apply the hack for all timelines of the target tenant
This gives us more flexibility to try it on a branch first
2024-01-27 13:21:13 +02:00
Heikki Linnakangas
65cd16de86 Fix the test 2024-01-27 13:21:13 +02:00
Heikki Linnakangas
b308be20df Fix typos and formatting in test, per 'ruff' 2024-01-27 13:21:13 +02:00
Heikki Linnakangas
a2d08cfc97 Fix formatting 2024-01-27 13:21:13 +02:00
Heikki Linnakangas
4ee11d9dfc Retroactively fix the nextXid on a known broken timeline
This one particular timeline in production hit the nextXid bug. Add a
one-off hack that will fix the nextXid on that particular timeline.
2024-01-27 13:21:13 +02:00
35 changed files with 934 additions and 1712 deletions

View File

@@ -508,7 +508,7 @@ jobs:
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
TEST_RESULT_CONNSTR: "${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}"
PAGESERVER_VIRTUAL_FILE_IO_ENGINE: std-fs
PAGESERVER_VIRTUAL_FILE_IO_ENGINE: tokio-epoll-uring
# XXX: no coverage data handling here, since benchmarks are run on release builds,
# while coverage is currently collected for the debug ones

20
Cargo.lock generated
View File

@@ -2736,12 +2736,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "linux-raw-sys"
version = "0.1.4"
@@ -2838,9 +2832,6 @@ dependencies = [
"libc",
"once_cell",
"prometheus",
"rand 0.8.5",
"rand_distr",
"twox-hash",
"workspace_hack",
]
@@ -3066,7 +3057,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
"autocfg",
"libm",
]
[[package]]
@@ -4181,16 +4171,6 @@ dependencies = [
"getrandom 0.2.11",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
dependencies = [
"num-traits",
"rand 0.8.5",
]
[[package]]
name = "rand_hc"
version = "0.2.0"

View File

@@ -165,7 +165,6 @@ tracing = "0.1"
tracing-error = "0.2.0"
tracing-opentelemetry = "0.20.0"
tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
twox-hash = { version = "1.6.3", default-features = false }
url = "2.2"
uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"

View File

@@ -53,7 +53,6 @@ RUN set -e \
--bin pagectl \
--bin safekeeper \
--bin storage_broker \
--bin attachment_service \
--bin proxy \
--bin neon_local \
--locked --release \
@@ -81,7 +80,6 @@ COPY --from=build --chown=neon:neon /home/nonroot/target/release/pageserver
COPY --from=build --chown=neon:neon /home/nonroot/target/release/pagectl /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/safekeeper /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/storage_broker /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/attachment_service /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/proxy /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/neon_local /usr/local/bin

View File

@@ -520,7 +520,8 @@ RUN apt-get update && \
libboost-regex1.74-dev \
libboost-serialization1.74-dev \
libboost-system1.74-dev \
libeigen3-dev
libeigen3-dev \
libfreetype6-dev
ENV PATH "/usr/local/pgsql/bin/:/usr/local/pgsql/:$PATH"
RUN wget https://github.com/rdkit/rdkit/archive/refs/tags/Release_2023_03_3.tar.gz -O rdkit.tar.gz && \
@@ -546,7 +547,6 @@ RUN wget https://github.com/rdkit/rdkit/archive/refs/tags/Release_2023_03_3.tar.
-D PostgreSQL_LIBRARY_DIR=`pg_config --libdir` \
-D RDK_INSTALL_INTREE=OFF \
-D RDK_INSTALL_COMIC_FONTS=OFF \
-D RDK_BUILD_FREETYPE_SUPPORT=OFF \
-D CMAKE_BUILD_TYPE=Release \
. && \
make -j $(getconf _NPROCESSORS_ONLN) && \
@@ -901,7 +901,7 @@ COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-deb
# libgeos, libgdal, libsfcgal1, libproj and libprotobuf-c1 for PostGIS
# libxml2, libxslt1.1 for xml2
# libzstd1 for zstd
# libboost* for rdkit
# libboost*, libfreetype6, and zlib1g for rdkit
# ca-certificates for communicating with s3 by compute_ctl
RUN apt update && \
apt install --no-install-recommends -y \
@@ -914,6 +914,7 @@ RUN apt update && \
libboost-serialization1.74.0 \
libboost-system1.74.0 \
libossp-uuid16 \
libfreetype6 \
libgeos-c1v5 \
libgdal28 \
libproj19 \
@@ -925,6 +926,7 @@ RUN apt update && \
libcurl4-openssl-dev \
locales \
procps \
zlib1g \
ca-certificates && \
rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \
localedef -i en_US -c -f UTF-8 -A /usr/share/locale/locale.alias en_US.UTF-8

View File

@@ -9,10 +9,5 @@ prometheus.workspace = true
libc.workspace = true
once_cell.workspace = true
chrono.workspace = true
twox-hash.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
rand = "0.8"
rand_distr = "0.4.3"

View File

@@ -1,523 +0,0 @@
//! HyperLogLog is an algorithm for the count-distinct problem,
//! approximating the number of distinct elements in a multiset.
//! Calculating the exact cardinality of the distinct elements
//! of a multiset requires an amount of memory proportional to
//! the cardinality, which is impractical for very large data sets.
//! Probabilistic cardinality estimators, such as the HyperLogLog algorithm,
//! use significantly less memory than this, but can only approximate the cardinality.
use std::{
collections::HashMap,
hash::{BuildHasher, BuildHasherDefault, Hash, Hasher},
sync::{atomic::AtomicU8, Arc, RwLock},
};
use prometheus::{
core::{self, Describer},
proto, Opts,
};
use twox_hash::xxh3;
/// Create an [`HyperLogLogVec`] and registers to default registry.
#[macro_export(local_inner_macros)]
macro_rules! register_hll_vec {
($N:literal, $OPTS:expr, $LABELS_NAMES:expr $(,)?) => {{
let hll_vec = $crate::HyperLogLogVec::<$N>::new($OPTS, $LABELS_NAMES).unwrap();
$crate::register(Box::new(hll_vec.clone())).map(|_| hll_vec)
}};
($N:literal, $NAME:expr, $HELP:expr, $LABELS_NAMES:expr $(,)?) => {{
$crate::register_hll_vec!($N, $crate::opts!($NAME, $HELP), $LABELS_NAMES)
}};
}
/// Create an [`HyperLogLog`] and registers to default registry.
#[macro_export(local_inner_macros)]
macro_rules! register_hll {
($N:literal, $OPTS:expr $(,)?) => {{
let hll = $crate::HyperLogLog::<$N>::with_opts($OPTS).unwrap();
$crate::register(Box::new(hll.clone())).map(|_| hll)
}};
($N:literal, $NAME:expr, $HELP:expr $(,)?) => {{
$crate::register_hll!($N, $crate::opts!($NAME, $HELP), $LABELS_NAMES)
}};
}
/// HLL is a probabilistic cardinality measure.
///
/// How to use this time-series for a metric name `my_metrics_total_hll`:
///
/// ```promql
/// # harmonic mean
/// 1 / (
/// sum (
/// 2 ^ -(
/// # HLL merge operation
/// max (my_metrics_total_hll{}) by (hll_shard, other_labels...)
/// )
/// ) without (hll_shard)
/// )
/// * alpha
/// * shards_count
/// * shards_count
/// ```
///
/// If you want an estimate over time, you can use the following query:
///
/// ```promql
/// # harmonic mean
/// 1 / (
/// sum (
/// 2 ^ -(
/// # HLL merge operation
/// max (
/// max_over_time(my_metrics_total_hll{}[$__rate_interval])
/// ) by (hll_shard, other_labels...)
/// )
/// ) without (hll_shard)
/// )
/// * alpha
/// * shards_count
/// * shards_count
/// ```
///
/// In the case of low cardinality, you might want to use the linear counting approximation:
///
/// ```promql
/// # LinearCounting(m, V) = m log (m / V)
/// shards_count * ln(shards_count /
/// # calculate V = how many shards contain a 0
/// count(max (proxy_connecting_endpoints{}) by (hll_shard, protocol) == 0) without (hll_shard)
/// )
/// ```
///
/// See <https://en.wikipedia.org/wiki/HyperLogLog#Practical_considerations> for estimates on alpha
#[derive(Clone)]
pub struct HyperLogLogVec<const N: usize> {
core: Arc<HyperLogLogVecCore<N>>,
}
struct HyperLogLogVecCore<const N: usize> {
pub children: RwLock<HashMap<u64, HyperLogLog<N>, BuildHasherDefault<xxh3::Hash64>>>,
pub desc: core::Desc,
pub opts: Opts,
}
impl<const N: usize> core::Collector for HyperLogLogVec<N> {
fn desc(&self) -> Vec<&core::Desc> {
vec![&self.core.desc]
}
fn collect(&self) -> Vec<proto::MetricFamily> {
let mut m = proto::MetricFamily::default();
m.set_name(self.core.desc.fq_name.clone());
m.set_help(self.core.desc.help.clone());
m.set_field_type(proto::MetricType::GAUGE);
let mut metrics = Vec::new();
for child in self.core.children.read().unwrap().values() {
child.core.collect_into(&mut metrics);
}
m.set_metric(metrics);
vec![m]
}
}
impl<const N: usize> HyperLogLogVec<N> {
/// Create a new [`HyperLogLogVec`] based on the provided
/// [`Opts`] and partitioned by the given label names. At least one label name must be
/// provided.
pub fn new(opts: Opts, label_names: &[&str]) -> prometheus::Result<Self> {
assert!(N.is_power_of_two());
let variable_names = label_names.iter().map(|s| (*s).to_owned()).collect();
let opts = opts.variable_labels(variable_names);
let desc = opts.describe()?;
let v = HyperLogLogVecCore {
children: RwLock::new(HashMap::default()),
desc,
opts,
};
Ok(Self { core: Arc::new(v) })
}
/// `get_metric_with_label_values` returns the [`HyperLogLog<P>`] for the given slice
/// of label values (same order as the VariableLabels in Desc). If that combination of
/// label values is accessed for the first time, a new [`HyperLogLog<P>`] is created.
///
/// An error is returned if the number of label values is not the same as the
/// number of VariableLabels in Desc.
pub fn get_metric_with_label_values(
&self,
vals: &[&str],
) -> prometheus::Result<HyperLogLog<N>> {
self.core.get_metric_with_label_values(vals)
}
/// `with_label_values` works as `get_metric_with_label_values`, but panics if an error
/// occurs.
pub fn with_label_values(&self, vals: &[&str]) -> HyperLogLog<N> {
self.get_metric_with_label_values(vals).unwrap()
}
}
impl<const N: usize> HyperLogLogVecCore<N> {
pub fn get_metric_with_label_values(
&self,
vals: &[&str],
) -> prometheus::Result<HyperLogLog<N>> {
let h = self.hash_label_values(vals)?;
if let Some(metric) = self.children.read().unwrap().get(&h).cloned() {
return Ok(metric);
}
self.get_or_create_metric(h, vals)
}
pub(crate) fn hash_label_values(&self, vals: &[&str]) -> prometheus::Result<u64> {
if vals.len() != self.desc.variable_labels.len() {
return Err(prometheus::Error::InconsistentCardinality {
expect: self.desc.variable_labels.len(),
got: vals.len(),
});
}
let mut h = xxh3::Hash64::default();
for val in vals {
h.write(val.as_bytes());
}
Ok(h.finish())
}
fn get_or_create_metric(
&self,
hash: u64,
label_values: &[&str],
) -> prometheus::Result<HyperLogLog<N>> {
let mut children = self.children.write().unwrap();
// Check exist first.
if let Some(metric) = children.get(&hash).cloned() {
return Ok(metric);
}
let metric = HyperLogLog::with_opts_and_label_values(&self.opts, label_values)?;
children.insert(hash, metric.clone());
Ok(metric)
}
}
/// HLL is a probabilistic cardinality measure.
///
/// How to use this time-series for a metric name `my_metrics_total_hll`:
///
/// ```promql
/// # harmonic mean
/// 1 / (
/// sum (
/// 2 ^ -(
/// # HLL merge operation
/// max (my_metrics_total_hll{}) by (hll_shard, other_labels...)
/// )
/// ) without (hll_shard)
/// )
/// * alpha
/// * shards_count
/// * shards_count
/// ```
///
/// If you want an estimate over time, you can use the following query:
///
/// ```promql
/// # harmonic mean
/// 1 / (
/// sum (
/// 2 ^ -(
/// # HLL merge operation
/// max (
/// max_over_time(my_metrics_total_hll{}[$__rate_interval])
/// ) by (hll_shard, other_labels...)
/// )
/// ) without (hll_shard)
/// )
/// * alpha
/// * shards_count
/// * shards_count
/// ```
///
/// In the case of low cardinality, you might want to use the linear counting approximation:
///
/// ```promql
/// # LinearCounting(m, V) = m log (m / V)
/// shards_count * ln(shards_count /
/// # calculate V = how many shards contain a 0
/// count(max (proxy_connecting_endpoints{}) by (hll_shard, protocol) == 0) without (hll_shard)
/// )
/// ```
///
/// See <https://en.wikipedia.org/wiki/HyperLogLog#Practical_considerations> for estimates on alpha
#[derive(Clone)]
pub struct HyperLogLog<const N: usize> {
core: Arc<HyperLogLogCore<N>>,
}
impl<const N: usize> HyperLogLog<N> {
/// Create a [`HyperLogLog`] with the `name` and `help` arguments.
pub fn new<S1: Into<String>, S2: Into<String>>(name: S1, help: S2) -> prometheus::Result<Self> {
assert!(N.is_power_of_two());
let opts = Opts::new(name, help);
Self::with_opts(opts)
}
/// Create a [`HyperLogLog`] with the `opts` options.
pub fn with_opts(opts: Opts) -> prometheus::Result<Self> {
Self::with_opts_and_label_values(&opts, &[])
}
fn with_opts_and_label_values(opts: &Opts, label_values: &[&str]) -> prometheus::Result<Self> {
let desc = opts.describe()?;
let labels = make_label_pairs(&desc, label_values)?;
let v = HyperLogLogCore {
shards: [0; N].map(AtomicU8::new),
desc,
labels,
};
Ok(Self { core: Arc::new(v) })
}
pub fn measure(&self, item: &impl Hash) {
// changing the hasher will break compatibility with previous measurements.
self.record(BuildHasherDefault::<xxh3::Hash64>::default().hash_one(item));
}
fn record(&self, hash: u64) {
let p = N.ilog2() as u8;
let j = hash & (N as u64 - 1);
let rho = (hash >> p).leading_zeros() as u8 + 1 - p;
self.core.shards[j as usize].fetch_max(rho, std::sync::atomic::Ordering::Relaxed);
}
}
struct HyperLogLogCore<const N: usize> {
shards: [AtomicU8; N],
desc: core::Desc,
labels: Vec<proto::LabelPair>,
}
impl<const N: usize> core::Collector for HyperLogLog<N> {
fn desc(&self) -> Vec<&core::Desc> {
vec![&self.core.desc]
}
fn collect(&self) -> Vec<proto::MetricFamily> {
let mut m = proto::MetricFamily::default();
m.set_name(self.core.desc.fq_name.clone());
m.set_help(self.core.desc.help.clone());
m.set_field_type(proto::MetricType::GAUGE);
let mut metrics = Vec::new();
self.core.collect_into(&mut metrics);
m.set_metric(metrics);
vec![m]
}
}
impl<const N: usize> HyperLogLogCore<N> {
fn collect_into(&self, metrics: &mut Vec<proto::Metric>) {
self.shards.iter().enumerate().for_each(|(i, x)| {
let mut shard_label = proto::LabelPair::default();
shard_label.set_name("hll_shard".to_owned());
shard_label.set_value(format!("{i}"));
// We reset the counter to 0 so we can perform a cardinality measure over any time slice in prometheus.
// This seems like it would be a race condition,
// but HLL is not impacted by a write in one shard happening in between.
// This is because in PromQL we will be implementing a harmonic mean of all buckets.
// we will also merge samples in a time series using `max by (hll_shard)`.
// TODO: maybe we shouldn't reset this on every collect, instead, only after a time window.
// this would mean that a dev port-forwarding the metrics url won't break the sampling.
let v = x.swap(0, std::sync::atomic::Ordering::Relaxed);
let mut m = proto::Metric::default();
let mut c = proto::Gauge::default();
c.set_value(v as f64);
m.set_gauge(c);
let mut labels = Vec::with_capacity(self.labels.len() + 1);
labels.extend_from_slice(&self.labels);
labels.push(shard_label);
m.set_label(labels);
metrics.push(m);
})
}
}
fn make_label_pairs(
desc: &core::Desc,
label_values: &[&str],
) -> prometheus::Result<Vec<proto::LabelPair>> {
if desc.variable_labels.len() != label_values.len() {
return Err(prometheus::Error::InconsistentCardinality {
expect: desc.variable_labels.len(),
got: label_values.len(),
});
}
let total_len = desc.variable_labels.len() + desc.const_label_pairs.len();
if total_len == 0 {
return Ok(vec![]);
}
if desc.variable_labels.is_empty() {
return Ok(desc.const_label_pairs.clone());
}
let mut label_pairs = Vec::with_capacity(total_len);
for (i, n) in desc.variable_labels.iter().enumerate() {
let mut label_pair = proto::LabelPair::default();
label_pair.set_name(n.clone());
label_pair.set_value(label_values[i].to_owned());
label_pairs.push(label_pair);
}
for label_pair in &desc.const_label_pairs {
label_pairs.push(label_pair.clone());
}
label_pairs.sort();
Ok(label_pairs)
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use prometheus::{proto, Opts};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::{Distribution, Zipf};
use crate::HyperLogLogVec;
fn collect(hll: &HyperLogLogVec<32>) -> Vec<proto::Metric> {
let mut metrics = vec![];
hll.core
.children
.read()
.unwrap()
.values()
.for_each(|c| c.core.collect_into(&mut metrics));
metrics
}
fn get_cardinality(metrics: &[proto::Metric], filter: impl Fn(&proto::Metric) -> bool) -> f64 {
let mut buckets = [0.0; 32];
for metric in metrics.chunks_exact(32) {
if filter(&metric[0]) {
for (i, m) in metric.iter().enumerate() {
buckets[i] = f64::max(buckets[i], m.get_gauge().get_value());
}
}
}
buckets
.into_iter()
.map(|f| 2.0f64.powf(-f))
.sum::<f64>()
.recip()
* 0.697
* 32.0
* 32.0
}
fn test_cardinality(n: usize, dist: impl Distribution<f64>) -> ([usize; 3], [f64; 3]) {
let hll = HyperLogLogVec::<32>::new(Opts::new("foo", "bar"), &["x"]).unwrap();
let mut iter = StdRng::seed_from_u64(0x2024_0112).sample_iter(dist);
let mut set_a = HashSet::new();
let mut set_b = HashSet::new();
for x in iter.by_ref().take(n) {
set_a.insert(x.to_bits());
hll.with_label_values(&["a"]).measure(&x.to_bits());
}
for x in iter.by_ref().take(n) {
set_b.insert(x.to_bits());
hll.with_label_values(&["b"]).measure(&x.to_bits());
}
let merge = &set_a | &set_b;
let metrics = collect(&hll);
let len = get_cardinality(&metrics, |_| true);
let len_a = get_cardinality(&metrics, |l| l.get_label()[0].get_value() == "a");
let len_b = get_cardinality(&metrics, |l| l.get_label()[0].get_value() == "b");
([merge.len(), set_a.len(), set_b.len()], [len, len_a, len_b])
}
#[test]
fn test_cardinality_small() {
let (actual, estimate) = test_cardinality(100, Zipf::new(100, 1.2f64).unwrap());
assert_eq!(actual, [46, 30, 32]);
assert!(51.3 < estimate[0] && estimate[0] < 51.4);
assert!(44.0 < estimate[1] && estimate[1] < 44.1);
assert!(39.0 < estimate[2] && estimate[2] < 39.1);
}
#[test]
fn test_cardinality_medium() {
let (actual, estimate) = test_cardinality(10000, Zipf::new(10000, 1.2f64).unwrap());
assert_eq!(actual, [2529, 1618, 1629]);
assert!(2309.1 < estimate[0] && estimate[0] < 2309.2);
assert!(1566.6 < estimate[1] && estimate[1] < 1566.7);
assert!(1629.5 < estimate[2] && estimate[2] < 1629.6);
}
#[test]
fn test_cardinality_large() {
let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(1_000_000, 1.2f64).unwrap());
assert_eq!(actual, [129077, 79579, 79630]);
assert!(126067.2 < estimate[0] && estimate[0] < 126067.3);
assert!(83076.8 < estimate[1] && estimate[1] < 83076.9);
assert!(64251.2 < estimate[2] && estimate[2] < 64251.3);
}
#[test]
fn test_cardinality_small2() {
let (actual, estimate) = test_cardinality(100, Zipf::new(200, 0.8f64).unwrap());
assert_eq!(actual, [92, 58, 60]);
assert!(116.1 < estimate[0] && estimate[0] < 116.2);
assert!(81.7 < estimate[1] && estimate[1] < 81.8);
assert!(69.3 < estimate[2] && estimate[2] < 69.4);
}
#[test]
fn test_cardinality_medium2() {
let (actual, estimate) = test_cardinality(10000, Zipf::new(20000, 0.8f64).unwrap());
assert_eq!(actual, [8201, 5131, 5051]);
assert!(6846.4 < estimate[0] && estimate[0] < 6846.5);
assert!(5239.1 < estimate[1] && estimate[1] < 5239.2);
assert!(4292.8 < estimate[2] && estimate[2] < 4292.9);
}
#[test]
fn test_cardinality_large2() {
let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(2_000_000, 0.8f64).unwrap());
assert_eq!(actual, [777847, 482069, 482246]);
assert!(699437.4 < estimate[0] && estimate[0] < 699437.5);
assert!(374948.9 < estimate[1] && estimate[1] < 374949.0);
assert!(434609.7 < estimate[2] && estimate[2] < 434609.8);
}
}

View File

@@ -28,9 +28,7 @@ use prometheus::{Registry, Result};
pub mod launch_timestamp;
mod wrappers;
pub use wrappers::{CountedReader, CountedWriter};
mod hll;
pub mod metric_vec_duration;
pub use hll::{HyperLogLog, HyperLogLogVec};
pub type UIntGauge = GenericGauge<AtomicU64>;
pub type UIntGaugeVec = GenericGaugeVec<AtomicU64>;

View File

@@ -646,7 +646,7 @@ impl RemoteStorage for S3Bucket {
let timestamp = DateTime::from(timestamp);
let done_if_after = DateTime::from(done_if_after);
tracing::info!("Target time: {timestamp:?}, done_if_after {done_if_after:?}");
tracing::trace!("Target time: {timestamp:?}, done_if_after {done_if_after:?}");
// get the passed prefix or if it is not set use prefix_in_bucket value
let prefix = prefix
@@ -657,67 +657,40 @@ impl RemoteStorage for S3Bucket {
let max_retries = 10;
let is_permanent = |_e: &_| false;
let mut key_marker = None;
let mut version_id_marker = None;
let mut versions_and_deletes = Vec::new();
let list = backoff::retry(
|| async {
Ok(self
.client
.list_object_versions()
.bucket(self.bucket_name.clone())
.set_prefix(prefix.clone())
.send()
.await?)
},
is_permanent,
warn_threshold,
max_retries,
"listing object versions for time_travel_recover",
backoff::Cancel::new(cancel.clone(), || anyhow!("Cancelled")),
)
.await?;
loop {
let response = backoff::retry(
|| async {
Ok(self
.client
.list_object_versions()
.bucket(self.bucket_name.clone())
.set_prefix(prefix.clone())
.set_key_marker(key_marker.clone())
.set_version_id_marker(version_id_marker.clone())
.send()
.await?)
},
is_permanent,
warn_threshold,
max_retries,
"listing object versions for time_travel_recover",
backoff::Cancel::new(cancel.clone(), || anyhow!("Cancelled")),
)
.await?;
tracing::trace!(
" Got List response version_id_marker={:?}, key_marker={:?}",
response.version_id_marker,
response.key_marker
);
let versions = response.versions.unwrap_or_default();
let delete_markers = response.delete_markers.unwrap_or_default();
let new_versions = versions.into_iter().map(VerOrDelete::Version);
let new_deletes = delete_markers.into_iter().map(VerOrDelete::DeleteMarker);
let new_versions_and_deletes = new_versions.chain(new_deletes);
versions_and_deletes.extend(new_versions_and_deletes);
fn none_if_empty(v: Option<String>) -> Option<String> {
v.filter(|v| !v.is_empty())
}
version_id_marker = none_if_empty(response.next_version_id_marker);
key_marker = none_if_empty(response.next_key_marker);
if version_id_marker.is_none() {
// The final response is not supposed to be truncated
if response.is_truncated.unwrap_or_default() {
anyhow::bail!(
"Received truncated ListObjectVersions response for prefix={prefix:?}"
);
}
break;
}
if list.is_truncated().unwrap_or_default() {
anyhow::bail!("Received truncated ListObjectVersions response for prefix={prefix:?}");
}
// Work on the list of references instead of the objects directly,
// otherwise we get lifetime errors in the sort_by_key call below.
let mut versions_and_deletes = versions_and_deletes.iter().collect::<Vec<_>>();
let mut versions_deletes = list
.versions()
.iter()
.map(VerOrDelete::Version)
.chain(list.delete_markers().iter().map(VerOrDelete::DeleteMarker))
.collect::<Vec<_>>();
versions_and_deletes.sort_by_key(|vd| (vd.key(), vd.last_modified()));
versions_deletes.sort_by_key(|vd| (vd.key(), vd.last_modified()));
let mut vds_for_key = HashMap::<_, Vec<_>>::new();
for vd in &versions_and_deletes {
for vd in versions_deletes {
let last_modified = vd.last_modified();
let version_id = vd.version_id();
let key = vd.key();
@@ -838,25 +811,25 @@ fn start_measuring_requests(
})
}
enum VerOrDelete {
Version(ObjectVersion),
DeleteMarker(DeleteMarkerEntry),
enum VerOrDelete<'a> {
Version(&'a ObjectVersion),
DeleteMarker(&'a DeleteMarkerEntry),
}
impl VerOrDelete {
fn last_modified(&self) -> Option<&DateTime> {
impl<'a> VerOrDelete<'a> {
fn last_modified(&self) -> Option<&'a DateTime> {
match self {
VerOrDelete::Version(v) => v.last_modified(),
VerOrDelete::DeleteMarker(v) => v.last_modified(),
}
}
fn version_id(&self) -> Option<&str> {
fn version_id(&self) -> Option<&'a str> {
match self {
VerOrDelete::Version(v) => v.version_id(),
VerOrDelete::DeleteMarker(v) => v.version_id(),
}
}
fn key(&self) -> Option<&str> {
fn key(&self) -> Option<&'a str> {
match self {
VerOrDelete::Version(v) => v.key(),
VerOrDelete::DeleteMarker(v) => v.key(),

View File

@@ -97,86 +97,23 @@ pub enum EvictionOrder {
/// Order the layers to be evicted by how recently they have been accessed relatively within
/// the set of resident layers of a tenant.
///
/// This strategy will evict layers more fairly but is untested.
RelativeAccessed {
/// Determines if the tenant with most layers should lose first.
///
/// Having this enabled is currently the only reasonable option, because the order in which
/// we read tenants is deterministic. If we find the need to use this as `false`, we need
/// to ensure nondeterminism by adding in a random number to break the
/// `relative_last_activity==0.0` ties.
#[serde(default = "default_highest_layer_count_loses_first")]
#[serde(default)]
highest_layer_count_loses_first: bool,
},
}
fn default_highest_layer_count_loses_first() -> bool {
true
}
impl EvictionOrder {
fn sort(&self, candidates: &mut [(MinResidentSizePartition, EvictionCandidate)]) {
use EvictionOrder::*;
/// Return true, if with [`Self::RelativeAccessed`] order the tenants with the highest layer
/// counts should be the first ones to have their layers evicted.
fn highest_layer_count_loses_first(&self) -> bool {
match self {
AbsoluteAccessed => {
candidates.sort_unstable_by_key(|(partition, candidate)| {
(*partition, candidate.last_activity_ts)
});
}
RelativeAccessed { .. } => candidates.sort_unstable_by_key(|(partition, candidate)| {
(*partition, candidate.relative_last_activity)
}),
}
}
/// Called to fill in the [`EvictionCandidate::relative_last_activity`] while iterating tenants
/// layers in **most** recently used order.
fn relative_last_activity(&self, total: usize, index: usize) -> finite_f32::FiniteF32 {
use EvictionOrder::*;
match self {
AbsoluteAccessed => finite_f32::FiniteF32::ZERO,
RelativeAccessed {
EvictionOrder::AbsoluteAccessed => false,
EvictionOrder::RelativeAccessed {
highest_layer_count_loses_first,
} => {
// keeping the -1 or not decides if every tenant should lose their least recently accessed
// layer OR if this should happen in the order of having highest layer count:
let fudge = if *highest_layer_count_loses_first {
// relative_last_activity vs. tenant layer count:
// - 0.1..=1.0 (10 layers)
// - 0.01..=1.0 (100 layers)
// - 0.001..=1.0 (1000 layers)
//
// leading to evicting less of the smallest tenants.
0
} else {
// use full 0.0..=1.0 range, which means even the smallest tenants could always lose a
// layer. the actual ordering is unspecified: for 10k tenants on a pageserver it could
// be that less than 10k layer evictions is enough, so we would not need to evict from
// all tenants.
//
// as the tenant ordering is now deterministic this could hit the same tenants
// disproportionetly on multiple invocations. alternative could be to remember how many
// layers did we evict last time from this tenant, and inject that as an additional
// fudge here.
1
};
let total = total.checked_sub(fudge).filter(|&x| x > 1).unwrap_or(1);
let divider = total as f32;
// most recently used is always (total - 0) / divider == 1.0
// least recently used depends on the fudge:
// - (total - 1) - (total - 1) / total => 0 / total
// - total - (total - 1) / total => 1 / total
let distance = (total - index) as f32;
finite_f32::FiniteF32::try_from_normalized(distance / divider)
.unwrap_or_else(|val| {
tracing::warn!(%fudge, "calculated invalid relative_last_activity for i={index}, total={total}: {val}");
finite_f32::FiniteF32::ZERO
})
}
} => *highest_layer_count_loses_first,
}
}
}
@@ -452,6 +389,52 @@ pub(crate) async fn disk_usage_eviction_task_iteration_impl<U: Usage>(
let selection = select_victims(&candidates, usage_pre);
let mut candidates = candidates;
let selection = if matches!(eviction_order, EvictionOrder::RelativeAccessed { .. }) {
// we currently have the layers ordered by AbsoluteAccessed so that we can get the summary
// for comparison here. this is a temporary measure to develop alternatives.
use std::fmt::Write;
let mut summary_buf = String::with_capacity(256);
{
let absolute_summary = candidates
.iter()
.take(selection.amount)
.map(|(_, candidate)| candidate)
.collect::<summary::EvictionSummary>();
write!(summary_buf, "{absolute_summary}").expect("string grows");
info!("absolute accessed selection summary: {summary_buf}");
}
candidates.sort_unstable_by_key(|(partition, candidate)| {
(*partition, candidate.relative_last_activity)
});
let selection = select_victims(&candidates, usage_pre);
{
summary_buf.clear();
let relative_summary = candidates
.iter()
.take(selection.amount)
.map(|(_, candidate)| candidate)
.collect::<summary::EvictionSummary>();
write!(summary_buf, "{relative_summary}").expect("string grows");
info!("relative accessed selection summary: {summary_buf}");
}
selection
} else {
selection
};
let (evicted_amount, usage_planned) = selection.into_amount_and_planned();
// phase2: evict layers
@@ -852,12 +835,54 @@ async fn collect_eviction_candidates(
.sort_unstable_by_key(|layer_info| std::cmp::Reverse(layer_info.last_activity_ts));
let mut cumsum: i128 = 0;
let total = tenant_candidates.len();
// keeping the -1 or not decides if every tenant should lose their least recently accessed
// layer OR if this should happen in the order of having highest layer count:
let fudge = if eviction_order.highest_layer_count_loses_first() {
// relative_age vs. tenant layer count:
// - 0.1..=1.0 (10 layers)
// - 0.01..=1.0 (100 layers)
// - 0.001..=1.0 (1000 layers)
//
// leading to evicting less of the smallest tenants.
0
} else {
// use full 0.0..=1.0 range, which means even the smallest tenants could always lose a
// layer. the actual ordering is unspecified: for 10k tenants on a pageserver it could
// be that less than 10k layer evictions is enough, so we would not need to evict from
// all tenants.
//
// as the tenant ordering is now deterministic this could hit the same tenants
// disproportionetly on multiple invocations. alternative could be to remember how many
// layers did we evict last time from this tenant, and inject that as an additional
// fudge here.
1
};
let total = tenant_candidates
.len()
.checked_sub(fudge)
.filter(|&x| x > 0)
// support 0 or 1 resident layer tenants as well
.unwrap_or(1);
let divider = total as f32;
for (i, mut candidate) in tenant_candidates.into_iter().enumerate() {
// as we iterate this reverse sorted list, the most recently accessed layer will always
// be 1.0; this is for us to evict it last.
candidate.relative_last_activity = eviction_order.relative_last_activity(total, i);
candidate.relative_last_activity = if matches!(
eviction_order,
EvictionOrder::RelativeAccessed { .. }
) {
// another possibility: use buckets, like (256.0 * relative_last_activity) as u8 or
// similarly for u16. unsure how it would help.
finite_f32::FiniteF32::try_from_normalized((total - i) as f32 / divider)
.unwrap_or_else(|val| {
tracing::warn!(%fudge, "calculated invalid relative_last_activity for i={i}, total={total}: {val}");
finite_f32::FiniteF32::ZERO
})
} else {
finite_f32::FiniteF32::ZERO
};
let partition = if cumsum > min_resident_size as i128 {
MinResidentSizePartition::Above
@@ -902,7 +927,10 @@ async fn collect_eviction_candidates(
debug_assert!(MinResidentSizePartition::Above < MinResidentSizePartition::Below,
"as explained in the function's doc comment, layers that aren't in the tenant's min_resident_size are evicted first");
eviction_order.sort(&mut candidates);
// always behave as if AbsoluteAccessed was selected. if RelativeAccessed is in use, we
// will sort later by candidate.relative_last_activity to get compare evictions.
candidates
.sort_unstable_by_key(|(partition, candidate)| (*partition, candidate.last_activity_ts));
Ok(EvictionCandidates::Finished(candidates))
}
@@ -1042,12 +1070,6 @@ pub(crate) mod finite_f32 {
}
}
impl From<FiniteF32> for f32 {
fn from(value: FiniteF32) -> f32 {
value.0
}
}
impl FiniteF32 {
pub const ZERO: FiniteF32 = FiniteF32(0.0);
@@ -1060,9 +1082,136 @@ pub(crate) mod finite_f32 {
Err(value)
}
}
}
}
pub fn into_inner(self) -> f32 {
self.into()
mod summary {
use super::finite_f32::FiniteF32;
use super::{EvictionCandidate, LayerCount};
use pageserver_api::shard::TenantShardId;
use std::collections::{BTreeMap, HashMap};
use std::time::SystemTime;
#[derive(Debug, Default)]
pub(super) struct EvictionSummary {
evicted_per_tenant: HashMap<TenantShardId, LayerCount>,
total: LayerCount,
last_absolute: Option<SystemTime>,
last_relative: Option<FiniteF32>,
}
impl<'a> FromIterator<&'a EvictionCandidate> for EvictionSummary {
fn from_iter<T: IntoIterator<Item = &'a EvictionCandidate>>(iter: T) -> Self {
let mut summary = EvictionSummary::default();
for item in iter {
let counts = summary
.evicted_per_tenant
.entry(*item.layer.get_tenant_shard_id())
.or_default();
let sz = item.layer.get_file_size();
counts.file_sizes += sz;
counts.count += 1;
summary.total.file_sizes += sz;
summary.total.count += 1;
summary.last_absolute = Some(item.last_activity_ts);
summary.last_relative = Some(item.relative_last_activity);
}
summary
}
}
struct SiBytesAmount(u64);
impl std::fmt::Display for SiBytesAmount {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.0 < 1024 {
return write!(f, "{}B", self.0);
}
let mut tmp = self.0;
let mut ch = 0;
let suffixes = b"KMGTPE";
while tmp > 1024 * 1024 && ch < suffixes.len() - 1 {
tmp /= 1024;
ch += 1;
}
let ch = suffixes[ch] as char;
write!(f, "{:.1}{ch}iB", tmp as f64 / 1024.0)
}
}
impl std::fmt::Display for EvictionSummary {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// wasteful, but it's for testing
let mut sorted: BTreeMap<usize, Vec<(TenantShardId, u64)>> = BTreeMap::new();
for (tenant_shard_id, count) in &self.evicted_per_tenant {
sorted
.entry(count.count)
.or_default()
.push((*tenant_shard_id, count.file_sizes));
}
let total_file_sizes = SiBytesAmount(self.total.file_sizes);
writeln!(
f,
"selected {} layers of {total_file_sizes} up to ({:?}, {:.2?}):",
self.total.count, self.last_absolute, self.last_relative,
)?;
for (count, per_tenant) in sorted.iter().rev().take(10) {
write!(f, "- {count} layers: ")?;
if per_tenant.len() < 3 {
for (i, (tenant_shard_id, bytes)) in per_tenant.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
let bytes = SiBytesAmount(*bytes);
write!(f, "{tenant_shard_id} ({bytes})")?;
}
} else {
let num_tenants = per_tenant.len();
let total_bytes = per_tenant.iter().map(|(_id, bytes)| bytes).sum::<u64>();
let total_bytes = SiBytesAmount(total_bytes);
let layers = num_tenants * count;
write!(
f,
"{num_tenants} tenants {total_bytes} in total {layers} layers",
)?;
}
writeln!(f)?;
}
if sorted.len() > 10 {
let (rem_count, rem_bytes) = sorted
.iter()
.rev()
.map(|(count, per_tenant)| {
(
count,
per_tenant.iter().map(|(_id, bytes)| bytes).sum::<u64>(),
)
})
.fold((0, 0), |acc, next| (acc.0 + next.0, acc.1 + next.1));
let rem_bytes = SiBytesAmount(rem_bytes);
writeln!(f, "- rest of tenants ({}) not shown ({rem_count} layers or {:.1}%, {rem_bytes} or {:.1}% bytes)", sorted.len() - 10, 100.0 * rem_count as f64 / self.total.count as f64, 100.0 * rem_bytes.0 as f64 / self.total.file_sizes as f64)?;
}
Ok(())
}
}
}
@@ -1187,40 +1336,3 @@ mod filesystem_level_usage {
assert!(!usage.has_pressure());
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn relative_equal_bounds() {
let order = EvictionOrder::RelativeAccessed {
highest_layer_count_loses_first: false,
};
let len = 10;
let v = (0..len)
.map(|i| order.relative_last_activity(len, i).into_inner())
.collect::<Vec<_>>();
assert_eq!(v.first(), Some(&1.0));
assert_eq!(v.last(), Some(&0.0));
assert!(v.windows(2).all(|slice| slice[0] > slice[1]));
}
#[test]
fn relative_spare_bounds() {
let order = EvictionOrder::RelativeAccessed {
highest_layer_count_loses_first: true,
};
let len = 10;
let v = (0..len)
.map(|i| order.relative_last_activity(len, i).into_inner())
.collect::<Vec<_>>();
assert_eq!(v.first(), Some(&1.0));
assert_eq!(v.last(), Some(&0.1));
assert!(v.windows(2).all(|slice| slice[0] > slice[1]));
}
}

View File

@@ -51,10 +51,7 @@ use crate::keyspace::KeyPartitioning;
use crate::repository::Key;
use crate::tenant::storage_layer::InMemoryLayer;
use anyhow::Result;
use pageserver_api::keyspace::KeySpaceAccum;
use std::cmp::Ordering;
use std::collections::{BTreeMap, VecDeque};
use std::iter::Peekable;
use std::collections::VecDeque;
use std::ops::Range;
use std::sync::Arc;
use utils::lsn::Lsn;
@@ -147,221 +144,11 @@ impl Drop for BatchedUpdates<'_> {
}
/// Return value of LayerMap::search
#[derive(Eq, PartialEq, Debug)]
pub struct SearchResult {
pub layer: Arc<PersistentLayerDesc>,
pub lsn_floor: Lsn,
}
pub struct OrderedSearchResult(SearchResult);
impl Ord for OrderedSearchResult {
fn cmp(&self, other: &Self) -> Ordering {
self.0.lsn_floor.cmp(&other.0.lsn_floor)
}
}
impl PartialOrd for OrderedSearchResult {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl PartialEq for OrderedSearchResult {
fn eq(&self, other: &Self) -> bool {
self.0.lsn_floor == other.0.lsn_floor
}
}
impl Eq for OrderedSearchResult {}
pub struct RangeSearchResult {
pub found: BTreeMap<OrderedSearchResult, KeySpaceAccum>,
pub not_found: KeySpaceAccum,
}
impl RangeSearchResult {
fn new() -> Self {
Self {
found: BTreeMap::new(),
not_found: KeySpaceAccum::new(),
}
}
}
/// Collector for results of range search queries on the LayerMap.
/// It should be provided with two iterators for the delta and image coverage
/// that contain all the changes for layers which intersect the range.
struct RangeSearchCollector<Iter>
where
Iter: Iterator<Item = (i128, Option<Arc<PersistentLayerDesc>>)>,
{
delta_coverage: Peekable<Iter>,
image_coverage: Peekable<Iter>,
key_range: Range<Key>,
end_lsn: Lsn,
current_delta: Option<Arc<PersistentLayerDesc>>,
current_image: Option<Arc<PersistentLayerDesc>>,
result: RangeSearchResult,
}
#[derive(Debug)]
enum NextLayerType {
Delta(i128),
Image(i128),
Both(i128),
}
impl NextLayerType {
fn next_change_at_key(&self) -> Key {
match self {
NextLayerType::Delta(at) => Key::from_i128(*at),
NextLayerType::Image(at) => Key::from_i128(*at),
NextLayerType::Both(at) => Key::from_i128(*at),
}
}
}
impl<Iter> RangeSearchCollector<Iter>
where
Iter: Iterator<Item = (i128, Option<Arc<PersistentLayerDesc>>)>,
{
fn new(
key_range: Range<Key>,
end_lsn: Lsn,
delta_coverage: Iter,
image_coverage: Iter,
) -> Self {
Self {
delta_coverage: delta_coverage.peekable(),
image_coverage: image_coverage.peekable(),
key_range,
end_lsn,
current_delta: None,
current_image: None,
result: RangeSearchResult::new(),
}
}
/// Run the collector. Collection is implemented via a two pointer algorithm.
/// One pointer tracks the start of the current range and the other tracks
/// the beginning of the next range which will overlap with the next change
/// in coverage across both image and delta.
fn collect(mut self) -> RangeSearchResult {
let next_layer_type = self.choose_next_layer_type();
let mut current_range_start = match next_layer_type {
None => {
// No changes for the range
self.pad_range(self.key_range.clone());
return self.result;
}
Some(layer_type) if self.key_range.end <= layer_type.next_change_at_key() => {
// Changes only after the end of the range
self.pad_range(self.key_range.clone());
return self.result;
}
Some(layer_type) => {
// Changes for the range exist. Record anything before the first
// coverage change as not found.
let coverage_start = layer_type.next_change_at_key();
let range_before = self.key_range.start..coverage_start;
self.pad_range(range_before);
self.advance(&layer_type);
coverage_start
}
};
while current_range_start < self.key_range.end {
let next_layer_type = self.choose_next_layer_type();
match next_layer_type {
Some(t) => {
let current_range_end = t.next_change_at_key();
self.add_range(current_range_start..current_range_end);
current_range_start = current_range_end;
self.advance(&t);
}
None => {
self.add_range(current_range_start..self.key_range.end);
current_range_start = self.key_range.end;
}
}
}
self.result
}
/// Mark a range as not found (i.e. no layers intersect it)
fn pad_range(&mut self, key_range: Range<Key>) {
if !key_range.is_empty() {
self.result.not_found.add_range(key_range);
}
}
/// Select the appropiate layer for the given range and update
/// the collector.
fn add_range(&mut self, covered_range: Range<Key>) {
let selected = LayerMap::select_layer(
self.current_delta.clone(),
self.current_image.clone(),
self.end_lsn,
);
match selected {
Some(search_result) => self
.result
.found
.entry(OrderedSearchResult(search_result))
.or_default()
.add_range(covered_range),
None => self.pad_range(covered_range),
}
}
/// Move to the next coverage change.
fn advance(&mut self, layer_type: &NextLayerType) {
match layer_type {
NextLayerType::Delta(_) => {
let (_, layer) = self.delta_coverage.next().unwrap();
self.current_delta = layer;
}
NextLayerType::Image(_) => {
let (_, layer) = self.image_coverage.next().unwrap();
self.current_image = layer;
}
NextLayerType::Both(_) => {
let (_, image_layer) = self.image_coverage.next().unwrap();
let (_, delta_layer) = self.delta_coverage.next().unwrap();
self.current_image = image_layer;
self.current_delta = delta_layer;
}
}
}
/// Pick the next coverage change: the one at the lesser key or both if they're alligned.
fn choose_next_layer_type(&mut self) -> Option<NextLayerType> {
let next_delta_at = self.delta_coverage.peek().map(|(key, _)| key);
let next_image_at = self.image_coverage.peek().map(|(key, _)| key);
match (next_delta_at, next_image_at) {
(None, None) => None,
(Some(next_delta_at), None) => Some(NextLayerType::Delta(*next_delta_at)),
(None, Some(next_image_at)) => Some(NextLayerType::Image(*next_image_at)),
(Some(next_delta_at), Some(next_image_at)) if next_image_at < next_delta_at => {
Some(NextLayerType::Image(*next_image_at))
}
(Some(next_delta_at), Some(next_image_at)) if next_delta_at < next_image_at => {
Some(NextLayerType::Delta(*next_delta_at))
}
(Some(next_delta_at), Some(_)) => Some(NextLayerType::Both(*next_delta_at)),
}
}
}
impl LayerMap {
///
/// Find the latest layer (by lsn.end) that covers the given
@@ -399,18 +186,7 @@ impl LayerMap {
let latest_delta = version.delta_coverage.query(key.to_i128());
let latest_image = version.image_coverage.query(key.to_i128());
Self::select_layer(latest_delta, latest_image, end_lsn)
}
fn select_layer(
delta_layer: Option<Arc<PersistentLayerDesc>>,
image_layer: Option<Arc<PersistentLayerDesc>>,
end_lsn: Lsn,
) -> Option<SearchResult> {
assert!(delta_layer.as_ref().map_or(true, |l| l.is_delta()));
assert!(image_layer.as_ref().map_or(true, |l| !l.is_delta()));
match (delta_layer, image_layer) {
match (latest_delta, latest_image) {
(None, None) => None,
(None, Some(image)) => {
let lsn_floor = image.get_lsn_range().start;
@@ -447,17 +223,6 @@ impl LayerMap {
}
}
pub fn range_search(&self, key_range: Range<Key>, end_lsn: Lsn) -> Option<RangeSearchResult> {
let version = self.historic.get().unwrap().get_version(end_lsn.0 - 1)?;
let raw_range = key_range.start.to_i128()..key_range.end.to_i128();
let delta_changes = version.delta_coverage.range_overlaps(&raw_range);
let image_changes = version.image_coverage.range_overlaps(&raw_range);
let collector = RangeSearchCollector::new(key_range, end_lsn, delta_changes, image_changes);
Some(collector.collect())
}
/// Start a batch of updates, applied on drop
pub fn batch_update(&mut self) -> BatchedUpdates<'_> {
BatchedUpdates { layer_map: self }
@@ -866,126 +631,3 @@ impl LayerMap {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Clone)]
struct LayerDesc {
key_range: Range<Key>,
lsn_range: Range<Lsn>,
is_delta: bool,
}
fn create_layer_map(layers: Vec<LayerDesc>) -> LayerMap {
let mut layer_map = LayerMap::default();
for layer in layers {
layer_map.insert_historic_noflush(PersistentLayerDesc::new_test(
layer.key_range,
layer.lsn_range,
layer.is_delta,
));
}
layer_map.flush_updates();
layer_map
}
fn assert_range_search_result_eq(lhs: RangeSearchResult, rhs: RangeSearchResult) {
assert_eq!(lhs.not_found.to_keyspace(), rhs.not_found.to_keyspace());
let lhs: Vec<_> = lhs
.found
.into_iter()
.map(|(search_result, accum)| (search_result.0, accum.to_keyspace()))
.collect();
let rhs: Vec<_> = rhs
.found
.into_iter()
.map(|(search_result, accum)| (search_result.0, accum.to_keyspace()))
.collect();
assert_eq!(lhs, rhs);
}
fn brute_force_range_search(
layer_map: &LayerMap,
key_range: Range<Key>,
end_lsn: Lsn,
) -> RangeSearchResult {
let mut range_search_result = RangeSearchResult::new();
let mut key = key_range.start;
while key != key_range.end {
let res = layer_map.search(key, end_lsn);
match res {
Some(res) => {
range_search_result
.found
.entry(OrderedSearchResult(res))
.or_default()
.add_key(key);
}
None => {
range_search_result.not_found.add_key(key);
}
}
key = key.next();
}
range_search_result
}
#[test]
fn ranged_search_on_empty_layer_map() {
let layer_map = LayerMap::default();
let range = Key::from_i128(100)..Key::from_i128(200);
let res = layer_map.range_search(range, Lsn(100));
assert!(res.is_none());
}
#[test]
fn ranged_search() {
let layers = vec![
LayerDesc {
key_range: Key::from_i128(15)..Key::from_i128(50),
lsn_range: Lsn(0)..Lsn(5),
is_delta: false,
},
LayerDesc {
key_range: Key::from_i128(10)..Key::from_i128(20),
lsn_range: Lsn(5)..Lsn(20),
is_delta: true,
},
LayerDesc {
key_range: Key::from_i128(15)..Key::from_i128(25),
lsn_range: Lsn(20)..Lsn(30),
is_delta: true,
},
LayerDesc {
key_range: Key::from_i128(35)..Key::from_i128(40),
lsn_range: Lsn(25)..Lsn(35),
is_delta: true,
},
LayerDesc {
key_range: Key::from_i128(35)..Key::from_i128(40),
lsn_range: Lsn(35)..Lsn(40),
is_delta: false,
},
];
let layer_map = create_layer_map(layers.clone());
for start in 0..60 {
for end in (start + 1)..60 {
let range = Key::from_i128(start)..Key::from_i128(end);
let result = layer_map.range_search(range.clone(), Lsn(100)).unwrap();
let expected = brute_force_range_search(&layer_map, range, Lsn(100));
assert_range_search_result_eq(result, expected);
}
}
}
}

View File

@@ -129,42 +129,6 @@ impl<Value: Clone> LayerCoverage<Value> {
.map(|(k, v)| (*k, v.as_ref().map(|x| x.1.clone())))
}
/// Returns an iterator which includes all coverage changes for layers that intersect
/// with the provided range.
pub fn range_overlaps(
&self,
key_range: &Range<i128>,
) -> impl Iterator<Item = (i128, Option<Value>)> + '_
where
Value: Eq,
{
let first_change = self.query(key_range.start);
match first_change {
Some(change) => {
// If the start of the range is covered, we have to deal with two cases:
// 1. Start of the range is aligned with the start of a layer.
// In this case the return of `self.range` will contain the layer which aligns with the start of the key range.
// We advance said iterator to avoid duplicating the first change.
// 2. Start of the range is not aligned with the start of a layer.
let range = key_range.start..key_range.end;
let mut range_coverage = self.range(range).peekable();
if range_coverage
.peek()
.is_some_and(|c| c.1.as_ref() == Some(&change))
{
range_coverage.next();
}
itertools::Either::Left(
std::iter::once((key_range.start, Some(change))).chain(range_coverage),
)
}
None => {
let range = key_range.start..key_range.end;
let coverage = self.range(range);
itertools::Either::Right(coverage)
}
}
}
/// O(1) clone
pub fn clone(&self) -> Self {
Self {

View File

@@ -55,13 +55,13 @@ impl PersistentLayerDesc {
}
#[cfg(test)]
pub fn new_test(key_range: Range<Key>, lsn_range: Range<Lsn>, is_delta: bool) -> Self {
pub fn new_test(key_range: Range<Key>) -> Self {
Self {
tenant_shard_id: TenantShardId::unsharded(TenantId::generate()),
timeline_id: TimelineId::generate(),
key_range,
lsn_range,
is_delta,
lsn_range: Lsn(0)..Lsn(1),
is_delta: false,
file_size: 0,
}
}

View File

@@ -113,40 +113,52 @@ impl WalIngest {
self.checkpoint_modified = true;
}
// BEGIN ONE-OFF HACK, version 3
// BEGIN ONE-OFF HACK
//
// We had a bug where we incorrectly passed 0 to update_next_xid(). That was
// harmless as long as nextXid was < 2^31, because 0 looked like a very old
// XID. But once nextXid reaches 2^31, 0 starts to look like a very new XID, and
// we incorrectly bumped up nextXid to the next epoch, to value '1:1024'
//
// That bug was fixed in commits e4898a6e605e791a00ce21bf49d4cc0d9a10534a and
// c1148dc9acf938d912888ecb0a4e76ed40e21ef8, but we have one known timeline in
// production where that already happened. This is a one-off fix to fix that
// damage.
// We have one known timeline in production where that happened. This is a one-off
// fix to fix that damage. The last WAL record on that timeline as of this writing
// is this:
//
// So on that particular timeline, fix the incorrectly set nextXid to the XID from
// the next record we see, plus 10000 to give some safety margin.
// rmgr: Standby len (rec/tot): 50/ 50, tx: 0, lsn: 35A/E32D86D8, prev 35A/E32D86B0, desc: RUNNING_XACTS nextXid 2325447052 latestCompletedXid 2325447051 oldestRunningXid 2325447052
//
// As a safety measure, disable this hack after they have reached LSN 380/00000000.
// As of this writing, they are around LSN 36D/00000000. They should not reach
// real XID wraparound until this LSN. We should remove this hack from production
// before that happens anyway, but better safe than sorry.
//
if self.checkpoint.nextXid.value == 4294968320 && // 1::1024, the incorrect value
// only apply this on the one broken tenant
modification.tline.tenant_shard_id.tenant_id == TenantId::from_hex("df254570a4f603805528b46b0d45a76c").unwrap() &&
lsn < Lsn::from_str("380/00000000").unwrap() &&
decoded.xl_xid != pg_constants::INVALID_TRANSACTION_ID
// So on that particular timeline, before that LSN, fix the incorrectly set
// nextXid to the nextXid value from that record, plus 1000 to give some safety
// margin.
// For testing this hack, this failpoint temporarily re-introduces the bug that
// was fixed
fn reintroduce_bug_failpoint_activated() -> bool {
fail::fail_point!("reintroduce-nextxid-update-bug", |_| { true });
false
}
if decoded.xl_xid == pg_constants::INVALID_TRANSACTION_ID
&& reintroduce_bug_failpoint_activated()
&& self.checkpoint.update_next_xid(decoded.xl_xid)
{
self.checkpoint.nextXid = FullTransactionId {
value: (decoded.xl_xid + 10000) as u64,
};
self.checkpoint_modified = true;
warn!(
"nextXid fixed by one-off hack at LSN {}, nextXid is now {}",
info!(
"failpoint: Incorrectly updated nextXid at LSN {} to {}",
lsn, self.checkpoint.nextXid.value
);
self.checkpoint_modified = true;
}
if self.checkpoint.nextXid.value == 4294968320 && // 1::1024, the incorrect value
modification.tline.tenant_shard_id.tenant_id == TenantId::from_hex("df254570a4f603805528b46b0d45a76c").unwrap() &&
lsn < Lsn::from_str("367/C7409300").unwrap() &&
!reintroduce_bug_failpoint_activated()
{
// This is the last nextXid value from the last RUNNING_XACTS record, at the
// end of the WAL as of this writing.
self.checkpoint.nextXid = FullTransactionId {
value: 2399949836 + 1000,
};
self.checkpoint_modified = true;
warn!("nextXid fixed by one-off hack at LSN {}", lsn);
}
// END ONE-OFF HACK
@@ -1404,22 +1416,16 @@ impl WalIngest {
self.checkpoint.nextMultiOffset = xlrec.moff + xlrec.nmembers;
self.checkpoint_modified = true;
}
let max_mbr_xid = xlrec.members.iter().fold(None, |acc, mbr| {
if let Some(max_xid) = acc {
if mbr.xid.wrapping_sub(max_xid) as i32 > 0 {
Some(mbr.xid)
} else {
acc
}
let max_mbr_xid = xlrec.members.iter().fold(0u32, |acc, mbr| {
if mbr.xid.wrapping_sub(acc) as i32 > 0 {
mbr.xid
} else {
Some(mbr.xid)
acc
}
});
if let Some(max_xid) = max_mbr_xid {
if self.checkpoint.update_next_xid(max_xid) {
self.checkpoint_modified = true;
}
if self.checkpoint.update_next_xid(max_mbr_xid) {
self.checkpoint_modified = true;
}
Ok(())
}

View File

@@ -190,10 +190,7 @@ async fn auth_quirks(
Err(info) => {
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
.await?;
ctx.set_endpoint_id(res.info.endpoint.clone());
tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));
ctx.set_endpoint_id(Some(res.info.endpoint.clone()));
(res.info, Some(res.keys))
}
Ok(info) => (info, None),
@@ -274,12 +271,19 @@ async fn authenticate_with_secret(
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
}
/// wake a compute (or retrieve an existing compute session from cache)
async fn wake_compute(
/// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache)
/// only if authentication was successfuly.
async fn auth_and_wake_compute(
ctx: &mut RequestMonitoring,
api: &impl console::Api,
compute_credentials: ComputeCredentials<ComputeCredentialKeys>,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
let compute_credentials =
auth_quirks(ctx, api, user_info, client, allow_cleartext, config).await?;
let mut num_retries = 0;
let mut node = loop {
let wake_res = api.wake_compute(ctx, &compute_credentials.info).await;
@@ -354,16 +358,16 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
"performing authentication using the console"
);
let compute_credentials =
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
let (cache_info, user_info) = wake_compute(ctx, &*api, compute_credentials).await?;
let (cache_info, user_info) =
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
.await?;
(cache_info, BackendType::Console(api, user_info))
}
// NOTE: this auth backend doesn't use client credentials.
Link(url) => {
info!("performing link authentication");
let node_info = link::authenticate(ctx, &url, client).await?;
let node_info = link::authenticate(&url, client).await?;
(
CachedNodeInfo::new_uncached(node_info),

View File

@@ -1,7 +1,6 @@
use crate::{
auth, compute,
console::{self, provider::NodeInfo},
context::RequestMonitoring,
error::UserFacingError,
stream::PqStream,
waiters,
@@ -55,7 +54,6 @@ pub fn new_psql_session_id() -> String {
}
pub(super) async fn authenticate(
ctx: &mut RequestMonitoring,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
@@ -96,10 +94,6 @@ pub(super) async fn authenticate(
.dbname(&db_info.dbname)
.user(&db_info.user);
ctx.set_user(db_info.user.into());
ctx.set_project(db_info.aux.clone());
tracing::Span::current().record("ep", &tracing::field::display(&db_info.aux.endpoint_id));
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.

View File

@@ -2,8 +2,7 @@
use crate::{
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI,
EndpointId, RoleName,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, EndpointId, RoleName,
};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
@@ -55,10 +54,10 @@ impl ComputeUserInfoMaybeEndpoint {
}
}
pub fn endpoint_sni(
sni: &str,
pub fn endpoint_sni<'a>(
sni: &'a str,
common_names: &HashSet<String>,
) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
) -> Result<&'a str, ComputeUserInfoParseError> {
let Some((subdomain, common_name)) = sni.split_once('.') else {
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
};
@@ -67,10 +66,7 @@ pub fn endpoint_sni(
cn: common_name.into(),
});
}
if subdomain == SERVERLESS_DRIVER_SNI {
return Ok(None);
}
Ok(Some(EndpointId::from(subdomain)))
Ok(subdomain)
}
impl ComputeUserInfoMaybeEndpoint {
@@ -89,6 +85,7 @@ impl ComputeUserInfoMaybeEndpoint {
// record the values if we have them
ctx.set_application(params.get("application_name").map(SmolStr::from));
ctx.set_user(user.clone());
ctx.set_endpoint_id(sni.map(EndpointId::from));
// Project name might be passed via PG's command-line options.
let endpoint_option = params
@@ -106,7 +103,7 @@ impl ComputeUserInfoMaybeEndpoint {
let endpoint_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
endpoint_sni(sni_str, cn)?
Some(EndpointId::from(endpoint_sni(sni_str, cn)?))
} else {
None
}
@@ -120,18 +117,13 @@ impl ComputeUserInfoMaybeEndpoint {
Some(Err(InconsistentProjectNames { domain, option }))
}
// Invariant: project name may not contain certain characters.
(a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) {
(a, b) => a.or(b).map(|name| match project_name_valid(&name) {
false => Err(MalformedProjectName(name)),
true => Ok(name),
}),
}
.transpose()?;
if let Some(ep) = &endpoint {
ctx.set_endpoint_id(ep.clone());
tracing::Span::current().record("ep", &tracing::field::display(ep));
}
info!(%user, project = endpoint.as_deref(), "credentials");
if sni.is_some() {
info!("Connection with sni");
@@ -154,7 +146,7 @@ impl ComputeUserInfoMaybeEndpoint {
Ok(Self {
user,
endpoint_id: endpoint,
endpoint_id: endpoint.map(EndpointId::from),
options,
})
}

View File

@@ -272,5 +272,5 @@ async fn handle_client(
let client = tokio::net::TcpStream::connect(destination).await?;
let metrics_aux: MetricsAuxInfo = Default::default();
proxy::proxy::passthrough::proxy_pass(ctx, tls_stream, client, metrics_aux).await
proxy::proxy::proxy_pass(ctx, tls_stream, client, metrics_aux).await
}

View File

@@ -1,7 +1,7 @@
use anyhow::Context;
use anyhow::{bail, Context};
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use std::net::SocketAddr;
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
@@ -25,31 +25,39 @@ impl CancelMap {
}
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub fn get_session(self: Arc<Self>) -> Session {
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
let key = loop {
let key = rand::random();
let key = rand::random();
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => continue,
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
}
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => {
bail!("query cancellation key already exists: {key}")
}
break key;
};
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
}
}
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.remove(&key);
info!("dropped query cancellation key {key}");
}
info!("registered new query cancellation key {key}");
Session {
key,
cancel_map: self,
}
let session = Session::new(key, self);
f(session).await
}
#[cfg(test)]
@@ -90,17 +98,23 @@ impl CancelClosure {
}
/// Helper for registering query cancellation tokens.
pub struct Session {
pub struct Session<'a> {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancel_map: Arc<CancelMap>,
cancel_map: &'a CancelMap,
}
impl Session {
impl<'a> Session<'a> {
fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
Self { key, cancel_map }
}
}
impl Session<'_> {
/// Store the cancel token for the given session.
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map.0.insert(self.key, Some(cancel_closure));
@@ -108,26 +122,37 @@ impl Session {
}
}
impl Drop for Session {
fn drop(&mut self) {
self.cancel_map.0.remove(&self.key);
info!("dropped query cancellation key {}", &self.key);
}
}
#[cfg(test)]
mod tests {
use super::*;
use once_cell::sync::Lazy;
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancel_map: Arc<CancelMap> = Default::default();
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
let (tx, rx) = tokio::sync::oneshot::channel();
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
assert!(CANCEL_MAP.contains(&session));
tx.send(()).expect("failed to send");
futures::future::pending::<()>().await; // sleep forever
Ok(())
}));
// Wait until the task has been spawned.
rx.await.context("failed to hear from the task")?;
// Drop the session's entry by cancelling the task.
task.abort();
let error = task.await.expect_err("task should have failed");
if !error.is_cancelled() {
anyhow::bail!(error);
}
let session = cancel_map.clone().get_session();
assert!(cancel_map.contains(&session));
drop(session);
// Check that the session has been dropped.
assert!(cancel_map.is_empty());
assert!(CANCEL_MAP.is_empty());
Ok(())
}

View File

@@ -89,11 +89,8 @@ impl RequestMonitoring {
self.project = Some(x.project_id);
}
pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
crate::metrics::CONNECTING_ENDPOINTS
.with_label_values(&[self.protocol])
.measure(&endpoint_id);
self.endpoint_id = Some(endpoint_id);
pub fn set_endpoint_id(&mut self, endpoint_id: Option<EndpointId>) {
self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone());
}
pub fn set_application(&mut self, app: Option<SmolStr>) {

View File

@@ -1,7 +1,10 @@
use ::metrics::{
exponential_buckets, register_histogram, register_histogram_vec, register_hll_vec,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge_vec, Histogram,
HistogramVec, HyperLogLogVec, IntCounterPairVec, IntCounterVec, IntGaugeVec,
exponential_buckets, register_int_counter_pair_vec, register_int_counter_vec,
IntCounterPairVec, IntCounterVec,
};
use prometheus::{
register_histogram, register_histogram_vec, register_int_gauge_vec, Histogram, HistogramVec,
IntGaugeVec,
};
use once_cell::sync::Lazy;
@@ -233,13 +236,3 @@ pub const fn bool_to_str(x: bool) -> &'static str {
"false"
}
}
pub static CONNECTING_ENDPOINTS: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
register_hll_vec!(
32,
"proxy_connecting_endpoints",
"HLL approximate cardinality of endpoints that are connecting",
&["protocol"],
)
.unwrap()
});

View File

@@ -2,34 +2,37 @@
mod tests;
pub mod connect_compute;
pub mod handshake;
pub mod passthrough;
pub mod retry;
use crate::{
auth,
cancellation::{self, CancelMap},
compute,
config::{ProxyConfig, TlsConfig},
config::{AuthenticationConfig, ProxyConfig, TlsConfig},
console::messages::MetricsAuxInfo,
context::RequestMonitoring,
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
metrics::{
NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE,
},
protocol2::WithClientIp,
proxy::{handshake::handshake, passthrough::proxy_pass},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
usage_metrics::{Ids, USAGE_METRICS},
EndpointCacheKey,
};
use anyhow::{bail, Context};
use futures::TryFutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use utils::measured_stream::MeasuredStream;
use self::connect_compute::{connect_to_compute, TcpMechanism};
@@ -77,13 +80,6 @@ pub async fn task_main(
let cancel_map = Arc::clone(&cancel_map);
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let session_span = info_span!(
"handle_client",
?session_id,
peer_addr = tracing::field::Empty,
ep = tracing::field::Empty,
);
connections.spawn(
async move {
info!("accepted postgres client connection");
@@ -107,18 +103,22 @@ pub async fn task_main(
handle_client(
config,
&mut ctx,
cancel_map,
&cancel_map,
socket,
ClientMode::Tcp,
endpoint_rate_limiter,
)
.await
}
.instrument(info_span!(
"handle_client",
?session_id,
peer_addr = tracing::field::Empty
))
.unwrap_or_else(move |e| {
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
})
.instrument(session_span),
error!(?session_id, "per-client task finished with an error: {e:#}");
}),
);
}
@@ -171,7 +171,7 @@ impl ClientMode {
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
cancel_map: Arc<CancelMap>,
cancel_map: &CancelMap,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -192,88 +192,138 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let tls = config.tls_config.as_ref();
let pause = ctx.latency_timer.pause();
let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
drop(pause);
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();
let user_info = {
let hostname = mode.hostname(stream.get_ref());
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e).await?,
let common_names = tls.map(|tls| &tls.common_names);
let result = config
.auth_backend
.as_ref()
.map(|_| {
auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names)
})
.transpose();
match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e).await?,
}
};
// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await;
ctx.set_endpoint_id(user_info.get_endpoint());
let client = Client::new(
stream,
user_info,
&params,
mode.allow_self_signed_compute(config),
endpoint_rate_limiter,
);
cancel_map
.with_session(|session| {
client.connect_to_db(ctx, session, mode, &config.authentication_config)
})
.await
}
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take an extra care of propagating only the select handshake errors to client.
#[tracing::instrument(skip_all)]
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
info!("received {msg:?}");
use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empy.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
}
}
_ => bail!(ERR_PROTO_VIOLATION),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!(ERR_PROTO_VIOLATION),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}
info!(session_type = "normal", "successful handshake");
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
info!(session_type = "cancellation", "successful handshake");
break Ok(None);
}
}
}
let user = user_info.get_user().to_owned();
let (mut node_info, user_info) = match user_info
.authenticate(
ctx,
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
)
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return stream.throw_error(e).instrument(params_span).await;
}
};
node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);
let aux = node_info.aux.clone();
let mut node = connect_to_compute(
ctx,
&TcpMechanism { params: &params },
node_info,
&user_info,
)
.or_else(|e| stream.throw_error(e))
.await?;
let session = cancel_map.get_session();
prepare_client_connection(&node, &session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
proxy_pass(ctx, stream, node.stream, aux).await
}
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
node: &compute::PostgresConnection,
session: &cancellation::Session,
session: cancellation::Session<'_>,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<()> {
// Register compute's query cancellation token and produce a new, unique one.
@@ -299,6 +349,151 @@ async fn prepare_client_connection(
Ok(())
}
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
ctx: &mut RequestMonitoring,
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
ctx.set_success();
ctx.log();
let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id.clone(),
branch_id: aux.branch_id.clone(),
});
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
m_sent2.inc_by(cnt as u64);
usage.record_egress(cnt as u64);
},
);
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
m_recv2.inc_by(cnt as u64);
},
);
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
Ok(())
}
/// Thin connection context.
struct Client<'a, S> {
/// The underlying libpq protocol stream.
stream: PqStream<Stream<S>>,
/// Client credentials that we care about.
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
/// KV-dictionary with PostgreSQL connection params.
params: &'a StartupMessageParams,
/// Allow self-signed certificates (for testing).
allow_self_signed_compute: bool,
/// Rate limiter for endpoints
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
impl<'a, S> Client<'a, S> {
/// Construct a new connection context.
fn new(
stream: PqStream<Stream<S>>,
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
params: &'a StartupMessageParams,
allow_self_signed_compute: bool,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Self {
Self {
stream,
user_info,
params,
allow_self_signed_compute,
endpoint_rate_limiter,
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
// Instrumentation logs endpoint name everywhere. Doesn't work for link
// auth; strictly speaking we don't know endpoint name in its case.
#[tracing::instrument(name = "", fields(ep = %self.user_info.get_endpoint().unwrap_or_default()), skip_all)]
async fn connect_to_db(
self,
ctx: &mut RequestMonitoring,
session: cancellation::Session<'_>,
mode: ClientMode,
config: &'static AuthenticationConfig,
) -> anyhow::Result<()> {
let Self {
mut stream,
user_info,
params,
allow_self_signed_compute,
endpoint_rate_limiter,
} = self;
// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await;
}
}
let user = user_info.get_user().to_owned();
let auth_result = match user_info
.authenticate(ctx, &mut stream, mode.allow_cleartext(), config)
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return stream.throw_error(e).instrument(params_span).await;
}
};
let (mut node_info, user_info) = auth_result;
node_info.allow_self_signed_compute = allow_self_signed_compute;
let aux = node_info.aux.clone();
let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &user_info)
.or_else(|e| stream.throw_error(e))
.await?;
prepare_client_connection(&node, session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
proxy_pass(ctx, stream, node.stream, aux).await
}
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);

View File

@@ -1,96 +0,0 @@
use anyhow::{bail, Context};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use crate::{
cancellation::CancelMap,
config::TlsConfig,
proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
stream::{PqStream, Stream},
};
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take an extra care of propagating only the select handshake errors to client.
#[tracing::instrument(skip_all)]
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
info!("received {msg:?}");
use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empy.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
}
}
_ => bail!(ERR_PROTO_VIOLATION),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!(ERR_PROTO_VIOLATION),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}
info!(session_type = "normal", "successful handshake");
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
info!(session_type = "cancellation", "successful handshake");
break Ok(None);
}
}
}
}

View File

@@ -1,57 +0,0 @@
use crate::{
console::messages::MetricsAuxInfo,
context::RequestMonitoring,
metrics::{NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER},
usage_metrics::{Ids, USAGE_METRICS},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
ctx: &mut RequestMonitoring,
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
ctx.set_success();
ctx.log();
let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id.clone(),
branch_id: aux.branch_id.clone(),
});
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
m_sent2.inc_by(cnt as u64);
usage.record_egress(cnt as u64);
},
);
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
m_recv2.inc_by(cnt as u64);
},
);
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
Ok(())
}

View File

@@ -41,8 +41,6 @@ use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use utils::http::{error::ApiError, json::json_response};
pub const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(
config: &'static ProxyConfig,
ws_listener: TcpListener,
@@ -230,7 +228,7 @@ async fn request_handler(
config,
&mut ctx,
websocket,
cancel_map,
&cancel_map,
host,
endpoint_rate_limiter,
)

View File

@@ -1,7 +1,6 @@
use std::sync::Arc;
use anyhow::bail;
use anyhow::Context;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
@@ -36,11 +35,11 @@ use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use crate::EndpointId;
use crate::RoleName;
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
use super::SERVERLESS_DRIVER_SNI;
#[derive(serde::Deserialize)]
struct QueryData {
@@ -62,6 +61,7 @@ enum Payload {
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
const SERVERLESS_DRIVER_SNI_HOSTNAME_FIRST_PART: &str = "api";
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
@@ -188,8 +188,10 @@ fn get_conn_info(
}
}
let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
ctx.set_endpoint_id(endpoint.clone());
let endpoint = endpoint_sni(hostname, &tls.common_names)?;
let endpoint: EndpointId = endpoint.into();
ctx.set_endpoint_id(Some(endpoint.clone()));
let pairs = connection_url.query_pairs();
@@ -225,7 +227,8 @@ fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Err
let (_, hostname_rest) = hostname
.split_once('.')
.ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
Ok(sni_hostname_rest == hostname_rest
&& sni_hostname_first == SERVERLESS_DRIVER_SNI_HOSTNAME_FIRST_PART)
}
// TODO: return different http error codes

View File

@@ -133,7 +133,7 @@ pub async fn serve_websocket(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
websocket: HyperWebsocket,
cancel_map: Arc<CancelMap>,
cancel_map: &CancelMap,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {

View File

@@ -28,7 +28,7 @@ use crate::safekeeper::Term;
use crate::safekeeper::{ServerInfo, TermLsn};
use crate::send_wal::WalSenderState;
use crate::timeline::PeerInfo;
use crate::{copy_timeline, debug_dump, patch_control_file, pull_timeline};
use crate::{copy_timeline, debug_dump, pull_timeline};
use crate::timelines_global_map::TimelineDeleteForceResult;
use crate::GlobalTimelines;
@@ -465,26 +465,6 @@ async fn dump_debug_handler(mut request: Request<Body>) -> Result<Response<Body>
Ok(response)
}
async fn patch_control_file_handler(
mut request: Request<Body>,
) -> Result<Response<Body>, ApiError> {
check_permission(&request, None)?;
let ttid = TenantTimelineId::new(
parse_request_param(&request, "tenant_id")?,
parse_request_param(&request, "timeline_id")?,
);
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
let patch_request: patch_control_file::Request = json_request(&mut request).await?;
let response = patch_control_file::handle_request(tli, patch_request)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, response)
}
/// Safekeeper http router.
pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError> {
let mut router = endpoint::make_router();
@@ -546,10 +526,6 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError>
"/v1/tenant/:tenant_id/timeline/:source_timeline_id/copy",
|r| request_span(r, timeline_copy_handler),
)
.patch(
"/v1/tenant/:tenant_id/timeline/:timeline_id/control_file",
|r| request_span(r, patch_control_file_handler),
)
// for tests
.post("/v1/record_safekeeper_info/:tenant_id/:timeline_id", |r| {
request_span(r, record_safekeeper_info)

View File

@@ -22,7 +22,6 @@ pub mod handler;
pub mod http;
pub mod json_ctrl;
pub mod metrics;
pub mod patch_control_file;
pub mod pull_timeline;
pub mod receive_wal;
pub mod recovery;

View File

@@ -1,85 +0,0 @@
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::info;
use crate::{state::TimelinePersistentState, timeline::Timeline};
#[derive(Deserialize, Debug, Clone)]
pub struct Request {
/// JSON object with fields to update
pub updates: serde_json::Value,
/// List of fields to apply
pub apply_fields: Vec<String>,
}
#[derive(Serialize)]
pub struct Response {
pub old_control_file: TimelinePersistentState,
pub new_control_file: TimelinePersistentState,
}
/// Patch control file with given request. Will update the persistent state using
/// fields from the request and persist the new state on disk.
pub async fn handle_request(tli: Arc<Timeline>, request: Request) -> anyhow::Result<Response> {
let response = tli
.map_control_file(|state| {
let old_control_file = state.clone();
let new_control_file = state_apply_diff(&old_control_file, &request)?;
info!(
"patching control file, old: {:?}, new: {:?}, patch: {:?}",
old_control_file, new_control_file, request
);
*state = new_control_file.clone();
Ok(Response {
old_control_file,
new_control_file,
})
})
.await?;
Ok(response)
}
fn state_apply_diff(
state: &TimelinePersistentState,
request: &Request,
) -> anyhow::Result<TimelinePersistentState> {
let mut json_value = serde_json::to_value(state)?;
if let Value::Object(a) = &mut json_value {
if let Value::Object(b) = &request.updates {
json_apply_diff(a, b, &request.apply_fields)?;
} else {
anyhow::bail!("request.updates is not a json object")
}
} else {
anyhow::bail!("TimelinePersistentState is not a json object")
}
let new_state: TimelinePersistentState = serde_json::from_value(json_value)?;
Ok(new_state)
}
fn json_apply_diff(
object: &mut serde_json::Map<String, Value>,
updates: &serde_json::Map<String, Value>,
apply_keys: &Vec<String>,
) -> anyhow::Result<()> {
for key in apply_keys {
if let Some(new_value) = updates.get(key) {
if let Some(existing_value) = object.get_mut(key) {
*existing_value = new_value.clone();
} else {
anyhow::bail!("key not found in original object: {}", key);
}
} else {
anyhow::bail!("key not found in request.updates: {}", key);
}
}
Ok(())
}

View File

@@ -901,20 +901,6 @@ impl Timeline {
file_open,
}
}
/// Apply a function to the control file state and persist it.
pub async fn map_control_file<T>(
&self,
f: impl FnOnce(&mut TimelinePersistentState) -> Result<T>,
) -> Result<T> {
let mut state = self.write_shared_state().await;
let mut persistent_state = state.sk.state.start_change();
// If f returns error, we abort the change and don't persist anything.
let res = f(&mut persistent_state)?;
// If persisting fails, we abort the change and return error.
state.sk.state.finish_change(&persistent_state).await?;
Ok(res)
}
}
/// Deletes directory and it's contents. Returns false if directory does not exist.

View File

@@ -3160,6 +3160,23 @@ class Endpoint(PgProtocol):
):
self.stop()
def log_contains(self, pattern: str) -> Optional[str]:
"""Check that the compute log contains a line that matches the given regex"""
logfile = self.endpoint_path() / "compute.log"
if not logfile.exists():
log.warning(f"Skipping log check: {logfile} does not exist")
return None
contains_re = re.compile(pattern)
with logfile.open("r") as f:
for line in f:
if contains_re.search(line):
# found it!
return line
return None
# Checkpoints running endpoint and returns pg_wal size in MB.
def get_pg_wal_size(self):
log.info(f'checkpointing at LSN {self.safe_psql("select pg_current_wal_lsn()")[0][0]}')
@@ -3443,24 +3460,6 @@ class SafekeeperHttpClient(requests.Session):
assert isinstance(res_json, dict)
return res_json
def patch_control_file(
self,
tenant_id: TenantId,
timeline_id: TimelineId,
patch: Dict[str, Any],
) -> Dict[str, Any]:
res = self.patch(
f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/control_file",
json={
"updates": patch,
"apply_fields": list(patch.keys()),
},
)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json
def pull_timeline(self, body: Dict[str, Any]) -> Dict[str, Any]:
res = self.post(f"http://localhost:{self.port}/v1/pull_timeline", json=body)
res.raise_for_status()

View File

@@ -3,10 +3,12 @@ import os
import time
from pathlib import Path
import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnvBuilder, wait_for_wal_insert_lsn
from fixtures.pageserver.utils import (
wait_for_last_record_lsn,
wait_for_upload,
)
from fixtures.remote_storage import RemoteStorageKind
from fixtures.types import Lsn, TenantId, TimelineId
@@ -98,7 +100,7 @@ def test_import_at_2bil(
vanilla_pg.safe_psql("CREATE TABLE t (t text);")
vanilla_pg.safe_psql("INSERT INTO t VALUES ('inserted in vanilla')")
endpoint_id = "ep-import_from_vanilla"
branch_name = "import_from_vanilla"
tenant = TenantId.generate()
timeline = TimelineId.generate()
@@ -138,7 +140,7 @@ def test_import_at_2bil(
"--timeline-id",
str(timeline),
"--node-name",
endpoint_id,
branch_name,
"--base-lsn",
start_lsn,
"--base-tarfile",
@@ -157,7 +159,8 @@ def test_import_at_2bil(
wait_for_last_record_lsn(ps_http, tenant, timeline, Lsn(end_lsn))
endpoint = env.endpoints.create_start(
endpoint_id,
branch_name,
endpoint_id="ep-import_from_vanilla",
tenant_id=tenant,
config_lines=[
"log_autovacuum_min_duration = 0",
@@ -166,7 +169,6 @@ def test_import_at_2bil(
)
assert endpoint.safe_psql("select count(*) from t") == [(1,)]
# Ok, consume
conn = endpoint.connect()
cur = conn.cursor()
@@ -203,16 +205,6 @@ def test_import_at_2bil(
$$;
"""
)
# Also create a multi-XID with members past the 2 billion mark
conn2 = endpoint.connect()
cur2 = conn2.cursor()
cur.execute("INSERT INTO t VALUES ('x')")
cur.execute("BEGIN; select * from t WHERE t = 'x' FOR SHARE;")
cur2.execute("BEGIN; select * from t WHERE t = 'x' FOR SHARE;")
cur.execute("COMMIT")
cur2.execute("COMMIT")
# A checkpoint writes a WAL record with xl_xid=0. Many other WAL
# records would have the same effect.
cur.execute("checkpoint")
@@ -227,4 +219,213 @@ def test_import_at_2bil(
conn = endpoint.connect()
cur = conn.cursor()
cur.execute("SELECT count(*) from t")
assert cur.fetchone() == (10000 + 1 + 1,)
assert cur.fetchone() == (10000 + 1,)
# This is a followup to the test_import_at_2bil test.
#
# Use a failpoint to reintroduce the bug that test_import_at_2bil also
# tests. Then, after the damage has been done, clear the failpoint to
# fix the bug. Check that the one-off hack that we added for a particular
# timeline that hit this in production fixes the broken timeline.
def test_one_off_hack_for_nextxid_bug(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
pg_distrib_dir: Path,
pg_bin,
vanilla_pg,
):
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
env = neon_env_builder.init_start()
ps_http = env.pageserver.http_client()
env.pageserver.allowed_errors.append(".*nextXid fixed by one-off hack.*")
# We begin with the old bug still present, to create a broken timeline
ps_http.configure_failpoints(("reintroduce-nextxid-update-bug", "return(true)"))
# Set LD_LIBRARY_PATH in the env properly, otherwise we may use the wrong libpq.
# PgBin sets it automatically, but here we need to pipe psql output to the tar command.
psql_env = {"LD_LIBRARY_PATH": str(pg_distrib_dir / "lib")}
# Reset the vanilla Postgres instance to somewhat before 2 billion transactions,
# and around the same LSN as with the production timeline.
pg_resetwal_path = os.path.join(pg_bin.pg_bin_path, "pg_resetwal")
cmd = [
pg_resetwal_path,
"--next-transaction-id=2129920000",
"-l",
"000000010000035A000000E0",
"-D",
str(vanilla_pg.pgdatadir),
]
pg_bin.run_capture(cmd, env=psql_env)
vanilla_pg.start()
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
vanilla_pg.safe_psql(
"""create table tt as select 'long string to consume some space' || g
from generate_series(1,300000) g"""
)
assert vanilla_pg.safe_psql("select count(*) from tt") == [(300000,)]
vanilla_pg.safe_psql("CREATE TABLE t (t text);")
vanilla_pg.safe_psql("INSERT INTO t VALUES ('inserted in vanilla')")
branch_name = "import_from_vanilla"
# This is the tenant/timeline that the one-off hack targets
tenant = "df254570a4f603805528b46b0d45a76c"
timeline = TimelineId.generate()
env.pageserver.tenant_create(tenant)
# Take basebackup
basebackup_dir = os.path.join(test_output_dir, "basebackup")
base_tar = os.path.join(basebackup_dir, "base.tar")
wal_tar = os.path.join(basebackup_dir, "pg_wal.tar")
os.mkdir(basebackup_dir)
vanilla_pg.safe_psql("CHECKPOINT")
pg_bin.run(
[
"pg_basebackup",
"-F",
"tar",
"-d",
vanilla_pg.connstr(),
"-D",
basebackup_dir,
]
)
# Get start_lsn and end_lsn
with open(os.path.join(basebackup_dir, "backup_manifest")) as f:
manifest = json.load(f)
start_lsn = manifest["WAL-Ranges"][0]["Start-LSN"]
end_lsn = manifest["WAL-Ranges"][0]["End-LSN"]
def import_tar(base, wal):
env.neon_cli.raw_cli(
[
"timeline",
"import",
"--tenant-id",
str(tenant),
"--timeline-id",
str(timeline),
"--node-name",
branch_name,
"--base-lsn",
start_lsn,
"--base-tarfile",
base,
"--end-lsn",
end_lsn,
"--wal-tarfile",
wal,
"--pg-version",
env.pg_version,
]
)
# Importing correct backup works
import_tar(base_tar, wal_tar)
wait_for_last_record_lsn(ps_http, tenant, timeline, Lsn(end_lsn))
endpoint = env.endpoints.create_start(
branch_name,
endpoint_id="ep-import_from_vanilla",
tenant_id=tenant,
config_lines=[
"log_autovacuum_min_duration = 0",
"autovacuum_naptime='5 s'",
],
)
assert endpoint.safe_psql("select count(*) from t") == [(1,)]
conn = endpoint.connect()
cur = conn.cursor()
# Install extension containing function needed for test
cur.execute("CREATE EXTENSION neon_test_utils")
# Advance nextXid to the target XID, which is somewhat above the 2
# billion mark.
while True:
xid = int(query_scalar(cur, "SELECT txid_current()"))
log.info(f"xid now {xid}")
# Consume 10k transactons at a time until we get to 2^31 - 200k
if xid < (2325447052 - 100000):
cur.execute("select test_consume_xids(50000);")
elif xid < 2325447052 - 10000:
cur.execute("select test_consume_xids(5000);")
else:
break
# Run a bunch of real INSERTs to cross over the 2 billion mark
# Use a begin-exception block to have a separate sub-XID for each insert.
cur.execute(
"""
do $$
begin
for i in 1..10000 loop
-- Use a begin-exception block to generate a new subtransaction on each iteration
begin
insert into t values (i);
exception when others then
raise 'not expected %', sqlerrm;
end;
end loop;
end;
$$;
"""
)
# A checkpoint writes a WAL record with xl_xid=0. Many other WAL
# records would have the same effect.
cur.execute("checkpoint")
# Ok, the nextXid in the pageserver at this LSN should now be incorrectly
# set to 1:1024. Remember this LSN.
broken_lsn = Lsn(query_scalar(cur, "SELECT pg_current_wal_insert_lsn()"))
# Ensure that the broken checkpoint data has reached permanent storage
ps_http.timeline_checkpoint(tenant, timeline)
wait_for_upload(ps_http, tenant, timeline, broken_lsn)
# Now fix the bug, and generate some WAL with XIDs
ps_http.configure_failpoints(("reintroduce-nextxid-update-bug", "off"))
cur.execute("INSERT INTO t VALUES ('after fix')")
fixed_lsn = Lsn(query_scalar(cur, "SELECT pg_current_wal_insert_lsn()"))
log.info(f"nextXid was broken by {broken_lsn}, and fixed again by {fixed_lsn}")
# Stop the original endpoint, we don't need it anymore.
endpoint.stop()
# Test that we cannot start a new endpoint at the broken LSN.
env.neon_cli.create_branch(
"at-broken-lsn", branch_name, ancestor_start_lsn=broken_lsn, tenant_id=tenant
)
endpoint_broken = env.endpoints.create(
"at-broken-lsn",
endpoint_id="ep-at-broken-lsn",
tenant_id=tenant,
)
with pytest.raises(RuntimeError, match="Postgres exited unexpectedly with code 1"):
endpoint_broken.start()
assert endpoint_broken.log_contains(
'Could not open file "pg_xact/0000": No such file or directory'
)
# But after the bug was fixed, the one-off hack fixed the timeline,
# and a later LSN works.
env.neon_cli.create_branch(
"at-fixed-lsn", branch_name, ancestor_start_lsn=fixed_lsn, tenant_id=tenant
)
endpoint_fixed = env.endpoints.create_start(
"at-fixed-lsn", endpoint_id="ep-at-fixed-lsn", tenant_id=tenant
)
conn = endpoint_fixed.connect()
cur = conn.cursor()
cur.execute("SELECT count(*) from t")
# One "inserted in vanilla" row, 10000 in the DO-loop, and one "after fix" row
assert cur.fetchone() == (1 + 10000 + 1,)

View File

@@ -1946,51 +1946,3 @@ def test_timeline_copy(neon_env_builder: NeonEnvBuilder, insert_rows: int):
assert orig_digest == new_digest
# TODO: test timelines can start after copy
def test_patch_control_file(neon_env_builder: NeonEnvBuilder):
neon_env_builder.num_safekeepers = 1
env = neon_env_builder.init_start()
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline
endpoint = env.endpoints.create_start("main")
# initialize safekeeper
endpoint.safe_psql("create table t(key int, value text)")
# update control file
res = (
env.safekeepers[0]
.http_client()
.patch_control_file(
tenant_id,
timeline_id,
{
"timeline_start_lsn": "0/1",
},
)
)
timeline_start_lsn_before = res["old_control_file"]["timeline_start_lsn"]
timeline_start_lsn_after = res["new_control_file"]["timeline_start_lsn"]
log.info(f"patch_control_file response: {res}")
log.info(
f"updated control file timeline_start_lsn, before {timeline_start_lsn_before}, after {timeline_start_lsn_after}"
)
assert timeline_start_lsn_after == "0/1"
env.safekeepers[0].stop().start()
# wait/check that safekeeper is alive
endpoint.safe_psql("insert into t values (1, 'payload')")
# check that timeline_start_lsn is updated
res = (
env.safekeepers[0]
.http_client()
.debug_dump({"dump_control_file": "true", "timeline_id": str(timeline_id)})
)
log.info(f"dump_control_file response: {res}")
assert res["timelines"][0]["control_file"]["timeline_start_lsn"] == "0/1"

View File

@@ -51,7 +51,7 @@ memchr = { version = "2" }
nom = { version = "7" }
num-bigint = { version = "0.4" }
num-integer = { version = "0.1", features = ["i128"] }
num-traits = { version = "0.2", features = ["i128", "libm"] }
num-traits = { version = "0.2", features = ["i128"] }
once_cell = { version = "1" }
parquet = { git = "https://github.com/neondatabase/arrow-rs", branch = "neon-fix-bugs", default-features = false, features = ["zstd"] }
prost = { version = "0.11" }
@@ -100,7 +100,7 @@ memchr = { version = "2" }
nom = { version = "7" }
num-bigint = { version = "0.4" }
num-integer = { version = "0.1", features = ["i128"] }
num-traits = { version = "0.2", features = ["i128", "libm"] }
num-traits = { version = "0.2", features = ["i128"] }
once_cell = { version = "1" }
parquet = { git = "https://github.com/neondatabase/arrow-rs", branch = "neon-fix-bugs", default-features = false, features = ["zstd"] }
prost = { version = "0.11" }