mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-22 23:50:39 +00:00
Compare commits
160 Commits
release-co
...
release-pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b147439d6b | ||
|
|
87179e26b3 | ||
|
|
f05df409bd | ||
|
|
f6c0f6c4ec | ||
|
|
62cd3b8d3d | ||
|
|
8d26978ed9 | ||
|
|
35372a8f12 | ||
|
|
6d95a3fe2d | ||
|
|
99726495c7 | ||
|
|
4a4a457312 | ||
|
|
e78d1e2ec6 | ||
|
|
af429b4a62 | ||
|
|
54433c0839 | ||
|
|
40bb9ff62a | ||
|
|
4688b815b1 | ||
|
|
0982ca4636 | ||
|
|
7272d9f7b3 | ||
|
|
37d555aa59 | ||
|
|
cae3e2976b | ||
|
|
51ecd1bb37 | ||
|
|
1e6bb48076 | ||
|
|
1470af0b42 | ||
|
|
f92f92b91b | ||
|
|
dbb205ae92 | ||
|
|
85072b715f | ||
|
|
6c86fe7143 | ||
|
|
66d5fe7f5b | ||
|
|
a1b9528757 | ||
|
|
1423bb8aa2 | ||
|
|
332f064a42 | ||
|
|
c962f2b447 | ||
|
|
446b3f9d28 | ||
|
|
23352dc2e9 | ||
|
|
c65fc5a955 | ||
|
|
3e624581cd | ||
|
|
fedf4f169c | ||
|
|
86d5798108 | ||
|
|
8b4088dd8a | ||
|
|
c91905e643 | ||
|
|
44b4e355a2 | ||
|
|
03666a1f37 | ||
|
|
9c92242ca0 | ||
|
|
a354071dd0 | ||
|
|
758680d4f8 | ||
|
|
1738fd0a96 | ||
|
|
87b7edfc72 | ||
|
|
def05700d5 | ||
|
|
b547681e08 | ||
|
|
0fd211537b | ||
|
|
a83bd4e81c | ||
|
|
ecdad5e6d5 | ||
|
|
d028929945 | ||
|
|
7b0e3db868 | ||
|
|
088eb72dd7 | ||
|
|
d550e3f626 | ||
|
|
8c6b41daf5 | ||
|
|
bbb050459b | ||
|
|
cab498c787 | ||
|
|
6359342ffb | ||
|
|
13285c2a5e | ||
|
|
33790d14a3 | ||
|
|
709b8cd371 | ||
|
|
1c9bbf1a92 | ||
|
|
16163fb850 | ||
|
|
73ccc2b08c | ||
|
|
c719be6474 | ||
|
|
718645e56c | ||
|
|
fbc8c36983 | ||
|
|
5519e42612 | ||
|
|
4157eaf4c5 | ||
|
|
60241127e2 | ||
|
|
f7d5322e8b | ||
|
|
41bb9c5280 | ||
|
|
69c0d61c5c | ||
|
|
63cb8ce975 | ||
|
|
907e4aa3c4 | ||
|
|
0a2a84b766 | ||
|
|
85b12ddd52 | ||
|
|
dd76f1eeee | ||
|
|
8963ac85f9 | ||
|
|
4a488b3e24 | ||
|
|
c4987b0b13 | ||
|
|
84b4821118 | ||
|
|
32ba9811f9 | ||
|
|
a0cd64c4d3 | ||
|
|
84687b743d | ||
|
|
b6f93dcec9 | ||
|
|
4f6c594973 | ||
|
|
a750c14735 | ||
|
|
9ce0dd4e55 | ||
|
|
0e1a336607 | ||
|
|
7fc2912d06 | ||
|
|
fdf231c237 | ||
|
|
1e08b5dccc | ||
|
|
030810ed3e | ||
|
|
62b74bdc2c | ||
|
|
8b7e9ed820 | ||
|
|
5dad89acd4 | ||
|
|
547b2d2827 | ||
|
|
93f29a0065 | ||
|
|
4f36494615 | ||
|
|
0a550f3e7d | ||
|
|
4bb9554e4a | ||
|
|
008616cfe6 | ||
|
|
e61ec94fbc | ||
|
|
e5152551ad | ||
|
|
b0822a5499 | ||
|
|
1fb6ab59e8 | ||
|
|
e16439400d | ||
|
|
e401f66698 | ||
|
|
2fa461b668 | ||
|
|
03d90bc0b3 | ||
|
|
268bc890ea | ||
|
|
8a6ee79f6f | ||
|
|
9052c32b46 | ||
|
|
995e729ebe | ||
|
|
76077e1ddf | ||
|
|
0467d88f06 | ||
|
|
f5eec194e7 | ||
|
|
7e00be391d | ||
|
|
d56599df2a | ||
|
|
9d9aab3680 | ||
|
|
a202b1b5cc | ||
|
|
90f731f3b1 | ||
|
|
7736b748d3 | ||
|
|
9c23333cb3 | ||
|
|
66a99009ba | ||
|
|
5d4c57491f | ||
|
|
73935ea3a2 | ||
|
|
32e595d4dd | ||
|
|
b0d69acb07 | ||
|
|
98355a419a | ||
|
|
cfb03d6cf0 | ||
|
|
d81ef3f962 | ||
|
|
5d62c67e75 | ||
|
|
53d53d5b1e | ||
|
|
29fe6ea47a | ||
|
|
640327ccb3 | ||
|
|
7cf0f6b37e | ||
|
|
03c2c569be | ||
|
|
eff6d4538a | ||
|
|
5ef7782e9c | ||
|
|
73101db8c4 | ||
|
|
bccdfc6d39 | ||
|
|
99595813bb | ||
|
|
fe07b54758 | ||
|
|
a42d173e7b | ||
|
|
e07f689238 | ||
|
|
7831eddc88 | ||
|
|
943b1bc80c | ||
|
|
95a184e9b7 | ||
|
|
3fa17e9d17 | ||
|
|
55e0fd9789 | ||
|
|
2a88889f44 | ||
|
|
5bad8126dc | ||
|
|
27bc242085 | ||
|
|
192b49cc6d | ||
|
|
e1b60f3693 | ||
|
|
2804f5323b | ||
|
|
676adc6b32 |
@@ -310,13 +310,13 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
|
||||
. "$HOME/.cargo/env" && \
|
||||
cargo --version && rustup --version && \
|
||||
rustup component add llvm-tools rustfmt clippy && \
|
||||
cargo install rustfilt --version ${RUSTFILT_VERSION} && \
|
||||
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} && \
|
||||
cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \
|
||||
cargo install cargo-hack --version ${CARGO_HACK_VERSION} && \
|
||||
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} && \
|
||||
cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \
|
||||
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} \
|
||||
cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \
|
||||
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \
|
||||
cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \
|
||||
cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \
|
||||
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \
|
||||
cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \
|
||||
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \
|
||||
--features postgres-bundled --no-default-features && \
|
||||
rm -rf /home/nonroot/.cargo/registry && \
|
||||
rm -rf /home/nonroot/.cargo/git
|
||||
|
||||
@@ -57,21 +57,6 @@ use tracing::{error, info};
|
||||
use url::Url;
|
||||
use utils::failpoint_support;
|
||||
|
||||
// Compatibility hack: if the control plane specified any remote-ext-config
|
||||
// use the default value for extension storage proxy gateway.
|
||||
// Remove this once the control plane is updated to pass the gateway URL
|
||||
fn parse_remote_ext_base_url(arg: &str) -> Result<String> {
|
||||
const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str =
|
||||
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local";
|
||||
|
||||
Ok(if arg.starts_with("http") {
|
||||
arg
|
||||
} else {
|
||||
FALLBACK_PG_EXT_GATEWAY_BASE_URL
|
||||
}
|
||||
.to_owned())
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(rename_all = "kebab-case")]
|
||||
struct Cli {
|
||||
@@ -79,9 +64,8 @@ struct Cli {
|
||||
pub pgbin: String,
|
||||
|
||||
/// The base URL for the remote extension storage proxy gateway.
|
||||
/// Should be in the form of `http(s)://<gateway-hostname>[:<port>]`.
|
||||
#[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")]
|
||||
pub remote_ext_base_url: Option<String>,
|
||||
#[arg(short = 'r', long)]
|
||||
pub remote_ext_base_url: Option<Url>,
|
||||
|
||||
/// The port to bind the external listening HTTP server to. Clients running
|
||||
/// outside the compute will talk to the compute through this port. Keep
|
||||
@@ -276,18 +260,4 @@ mod test {
|
||||
fn verify_cli() {
|
||||
Cli::command().debug_assert()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_pg_ext_gateway_base_url() {
|
||||
let arg = "http://pg-ext-s3-gateway2";
|
||||
let result = super::parse_remote_ext_base_url(arg).unwrap();
|
||||
assert_eq!(result, arg);
|
||||
|
||||
let arg = "pg-ext-s3-gateway";
|
||||
let result = super::parse_remote_ext_base_url(arg).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ use std::time::{Duration, Instant};
|
||||
use std::{env, fs};
|
||||
use tokio::spawn;
|
||||
use tracing::{Instrument, debug, error, info, instrument, warn};
|
||||
use url::Url;
|
||||
use utils::id::{TenantId, TimelineId};
|
||||
use utils::lsn::Lsn;
|
||||
use utils::measured_stream::MeasuredReader;
|
||||
@@ -96,7 +97,7 @@ pub struct ComputeNodeParams {
|
||||
pub internal_http_port: u16,
|
||||
|
||||
/// the address of extension storage proxy gateway
|
||||
pub remote_ext_base_url: Option<String>,
|
||||
pub remote_ext_base_url: Option<Url>,
|
||||
|
||||
/// Interval for installed extensions collection
|
||||
pub installed_extensions_collection_interval: u64,
|
||||
|
||||
@@ -83,6 +83,7 @@ use reqwest::StatusCode;
|
||||
use tar::Archive;
|
||||
use tracing::info;
|
||||
use tracing::log::warn;
|
||||
use url::Url;
|
||||
use zstd::stream::read::Decoder;
|
||||
|
||||
use crate::metrics::{REMOTE_EXT_REQUESTS_TOTAL, UNKNOWN_HTTP_STATUS};
|
||||
@@ -158,14 +159,14 @@ fn parse_pg_version(human_version: &str) -> PostgresMajorVersion {
|
||||
pub async fn download_extension(
|
||||
ext_name: &str,
|
||||
ext_path: &RemotePath,
|
||||
remote_ext_base_url: &str,
|
||||
remote_ext_base_url: &Url,
|
||||
pgbin: &str,
|
||||
) -> Result<u64> {
|
||||
info!("Download extension {:?} from {:?}", ext_name, ext_path);
|
||||
|
||||
// TODO add retry logic
|
||||
let download_buffer =
|
||||
match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await {
|
||||
match download_extension_tar(remote_ext_base_url.as_str(), &ext_path.to_string()).await {
|
||||
Ok(buffer) => buffer,
|
||||
Err(error_message) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
|
||||
@@ -27,6 +27,7 @@ pub use prometheus::{
|
||||
|
||||
pub mod launch_timestamp;
|
||||
mod wrappers;
|
||||
pub use prometheus;
|
||||
pub use wrappers::{CountedReader, CountedWriter};
|
||||
mod hll;
|
||||
pub use hll::{HyperLogLog, HyperLogLogState, HyperLogLogVec};
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::{sync::Arc, time::Duration};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, info_span};
|
||||
|
||||
use crate::{FeatureStore, PostHogClient, PostHogClientConfig};
|
||||
|
||||
@@ -26,31 +27,35 @@ impl FeatureResolverBackgroundLoop {
|
||||
pub fn spawn(self: Arc<Self>, handle: &tokio::runtime::Handle, refresh_period: Duration) {
|
||||
let this = self.clone();
|
||||
let cancel = self.cancel.clone();
|
||||
handle.spawn(async move {
|
||||
tracing::info!("Starting PostHog feature resolver");
|
||||
let mut ticker = tokio::time::interval(refresh_period);
|
||||
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = ticker.tick() => {}
|
||||
_ = cancel.cancelled() => break
|
||||
}
|
||||
let resp = match this
|
||||
.posthog_client
|
||||
.get_feature_flags_local_evaluation()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
tracing::warn!("Cannot get feature flags: {}", e);
|
||||
continue;
|
||||
handle.spawn(
|
||||
async move {
|
||||
tracing::info!("Starting PostHog feature resolver");
|
||||
let mut ticker = tokio::time::interval(refresh_period);
|
||||
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = ticker.tick() => {}
|
||||
_ = cancel.cancelled() => break
|
||||
}
|
||||
};
|
||||
let feature_store = FeatureStore::new_with_flags(resp.flags);
|
||||
this.feature_store.store(Arc::new(feature_store));
|
||||
let resp = match this
|
||||
.posthog_client
|
||||
.get_feature_flags_local_evaluation()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
tracing::warn!("Cannot get feature flags: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let feature_store = FeatureStore::new_with_flags(resp.flags);
|
||||
this.feature_store.store(Arc::new(feature_store));
|
||||
tracing::info!("Feature flag updated");
|
||||
}
|
||||
tracing::info!("PostHog feature resolver stopped");
|
||||
}
|
||||
tracing::info!("PostHog feature resolver stopped");
|
||||
});
|
||||
.instrument(info_span!("posthog_feature_resolver")),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn feature_store(&self) -> Arc<FeatureStore> {
|
||||
|
||||
@@ -448,6 +448,18 @@ impl FeatureStore {
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Infer whether a feature flag is a boolean flag by checking if it has a multivariate filter.
|
||||
pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result<bool, PostHogEvaluationError> {
|
||||
if let Some(flag_config) = self.flags.get(flag_key) {
|
||||
Ok(flag_config.filters.multivariate.is_none())
|
||||
} else {
|
||||
Err(PostHogEvaluationError::NotAvailable(format!(
|
||||
"Not found in the local evaluation spec: {}",
|
||||
flag_key
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostHogClientConfig {
|
||||
@@ -528,7 +540,15 @@ impl PostHogClient {
|
||||
.bearer_auth(&self.config.server_api_key)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let body = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to get feature flags: {}, {}",
|
||||
status,
|
||||
body
|
||||
));
|
||||
}
|
||||
Ok(serde_json::from_str(&body)?)
|
||||
}
|
||||
|
||||
|
||||
@@ -264,10 +264,56 @@ mod propagation_of_cached_label_value {
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(histograms, histograms::bench_bucket_scalability);
|
||||
mod histograms {
|
||||
use std::time::Instant;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
use metrics::core::Collector;
|
||||
|
||||
pub fn bench_bucket_scalability(c: &mut Criterion) {
|
||||
let mut g = c.benchmark_group("bucket_scalability");
|
||||
|
||||
for n in [1, 4, 8, 16, 32, 64, 128, 256] {
|
||||
g.bench_with_input(BenchmarkId::new("nbuckets", n), &n, |b, n| {
|
||||
b.iter_custom(|iters| {
|
||||
let buckets: Vec<f64> = (0..*n).map(|i| i as f64 * 100.0).collect();
|
||||
let histo = metrics::Histogram::with_opts(
|
||||
metrics::prometheus::HistogramOpts::new("name", "help")
|
||||
.buckets(buckets.clone()),
|
||||
)
|
||||
.unwrap();
|
||||
let start = Instant::now();
|
||||
for i in 0..usize::try_from(iters).unwrap() {
|
||||
histo.observe(buckets[i % buckets.len()]);
|
||||
}
|
||||
let elapsed = start.elapsed();
|
||||
// self-test
|
||||
let mfs = histo.collect();
|
||||
assert_eq!(mfs.len(), 1);
|
||||
let metrics = mfs[0].get_metric();
|
||||
assert_eq!(metrics.len(), 1);
|
||||
let histo = metrics[0].get_histogram();
|
||||
let buckets = histo.get_bucket();
|
||||
assert!(
|
||||
buckets
|
||||
.iter()
|
||||
.enumerate()
|
||||
.all(|(i, b)| b.get_cumulative_count()
|
||||
>= i as u64 * (iters / buckets.len() as u64))
|
||||
);
|
||||
elapsed
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
criterion_main!(
|
||||
label_values,
|
||||
single_metric_multicore_scalability,
|
||||
propagation_of_cached_label_value
|
||||
propagation_of_cached_label_value,
|
||||
histograms,
|
||||
);
|
||||
|
||||
/*
|
||||
@@ -290,6 +336,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [211.50 ns 214.44 ns
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [14.135 ns 14.147 ns 14.160 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [14.243 ns 14.255 ns 14.268 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [14.470 ns 14.682 ns 14.895 ns]
|
||||
bucket_scalability/nbuckets/1 time: [30.352 ns 30.353 ns 30.354 ns]
|
||||
bucket_scalability/nbuckets/4 time: [30.464 ns 30.465 ns 30.467 ns]
|
||||
bucket_scalability/nbuckets/8 time: [30.569 ns 30.575 ns 30.584 ns]
|
||||
bucket_scalability/nbuckets/16 time: [30.961 ns 30.965 ns 30.969 ns]
|
||||
bucket_scalability/nbuckets/32 time: [35.691 ns 35.707 ns 35.722 ns]
|
||||
bucket_scalability/nbuckets/64 time: [47.829 ns 47.898 ns 47.974 ns]
|
||||
bucket_scalability/nbuckets/128 time: [73.479 ns 73.512 ns 73.545 ns]
|
||||
bucket_scalability/nbuckets/256 time: [127.92 ns 127.94 ns 127.96 ns]
|
||||
|
||||
Results on an i3en.3xlarge instance
|
||||
|
||||
@@ -344,6 +398,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [434.87 ns 456.4
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [3.3767 ns 3.3974 ns 3.4220 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [3.6105 ns 4.2355 ns 5.1463 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [4.0889 ns 4.9714 ns 6.0779 ns]
|
||||
bucket_scalability/nbuckets/1 time: [4.8455 ns 4.8542 ns 4.8646 ns]
|
||||
bucket_scalability/nbuckets/4 time: [4.5663 ns 4.5722 ns 4.5787 ns]
|
||||
bucket_scalability/nbuckets/8 time: [4.5531 ns 4.5670 ns 4.5842 ns]
|
||||
bucket_scalability/nbuckets/16 time: [4.6392 ns 4.6524 ns 4.6685 ns]
|
||||
bucket_scalability/nbuckets/32 time: [6.0302 ns 6.0439 ns 6.0589 ns]
|
||||
bucket_scalability/nbuckets/64 time: [10.608 ns 10.644 ns 10.691 ns]
|
||||
bucket_scalability/nbuckets/128 time: [22.178 ns 22.316 ns 22.483 ns]
|
||||
bucket_scalability/nbuckets/256 time: [42.190 ns 42.328 ns 42.492 ns]
|
||||
|
||||
Results on a Hetzner AX102 AMD Ryzen 9 7950X3D 16-Core Processor
|
||||
|
||||
@@ -362,5 +424,13 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [164.24 ns 170.1
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [2.2915 ns 2.2960 ns 2.3012 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [2.5726 ns 2.6158 ns 2.6624 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [2.7068 ns 2.8243 ns 2.9824 ns]
|
||||
bucket_scalability/nbuckets/1 time: [6.3998 ns 6.4288 ns 6.4684 ns]
|
||||
bucket_scalability/nbuckets/4 time: [6.3603 ns 6.3620 ns 6.3637 ns]
|
||||
bucket_scalability/nbuckets/8 time: [6.1646 ns 6.1654 ns 6.1667 ns]
|
||||
bucket_scalability/nbuckets/16 time: [6.1341 ns 6.1391 ns 6.1454 ns]
|
||||
bucket_scalability/nbuckets/32 time: [8.2206 ns 8.2254 ns 8.2301 ns]
|
||||
bucket_scalability/nbuckets/64 time: [13.988 ns 13.994 ns 14.000 ns]
|
||||
bucket_scalability/nbuckets/128 time: [28.180 ns 28.216 ns 28.251 ns]
|
||||
bucket_scalability/nbuckets/256 time: [54.914 ns 54.931 ns 54.951 ns]
|
||||
|
||||
*/
|
||||
|
||||
@@ -91,4 +91,14 @@ impl FeatureResolver {
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result<bool, PostHogEvaluationError> {
|
||||
if let Some(inner) = &self.inner {
|
||||
inner.feature_store().is_feature_flag_boolean(flag_key)
|
||||
} else {
|
||||
Err(PostHogEvaluationError::NotAvailable(
|
||||
"PostHog integration is not enabled".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3663,6 +3663,46 @@ async fn read_tar_eof(mut reader: (impl tokio::io::AsyncRead + Unpin)) -> anyhow
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn tenant_evaluate_feature_flag(
|
||||
request: Request<Body>,
|
||||
_cancel: CancellationToken,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
|
||||
check_permission(&request, Some(tenant_shard_id.tenant_id))?;
|
||||
|
||||
let flag: String = must_parse_query_param(&request, "flag")?;
|
||||
let as_type: String = must_parse_query_param(&request, "as")?;
|
||||
|
||||
let state = get_state(&request);
|
||||
|
||||
async {
|
||||
let tenant = state
|
||||
.tenant_manager
|
||||
.get_attached_tenant_shard(tenant_shard_id)?;
|
||||
if as_type == "boolean" {
|
||||
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
|
||||
let result = result.map(|_| true).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
} else if as_type == "multivariate" {
|
||||
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
} else {
|
||||
// Auto infer the type of the feature flag.
|
||||
let is_boolean = tenant.feature_resolver.is_feature_flag_boolean(&flag).map_err(|e| ApiError::InternalServerError(anyhow::anyhow!("{e}")))?;
|
||||
if is_boolean {
|
||||
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
|
||||
let result = result.map(|_| true).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
} else {
|
||||
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(info_span!("tenant_evaluate_feature_flag", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug()))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Common functionality of all the HTTP API handlers.
|
||||
///
|
||||
/// - Adds a tracing span to each request (by `request_span`)
|
||||
@@ -4039,5 +4079,8 @@ pub fn make_router(
|
||||
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/activate_post_import",
|
||||
|r| api_handler(r, activate_post_import_handler),
|
||||
)
|
||||
.get("/v1/tenant/:tenant_shard_id/feature_flag", |r| {
|
||||
api_handler(r, tenant_evaluate_feature_flag)
|
||||
})
|
||||
.any(handler_404))
|
||||
}
|
||||
|
||||
@@ -1312,11 +1312,44 @@ impl EvictionsWithLowResidenceDuration {
|
||||
//
|
||||
// Roughly logarithmic scale.
|
||||
const STORAGE_IO_TIME_BUCKETS: &[f64] = &[
|
||||
0.000030, // 30 usec
|
||||
0.001000, // 1000 usec
|
||||
0.030, // 30 ms
|
||||
1.000, // 1000 ms
|
||||
30.000, // 30000 ms
|
||||
0.00005, // 50us
|
||||
0.00006, // 60us
|
||||
0.00007, // 70us
|
||||
0.00008, // 80us
|
||||
0.00009, // 90us
|
||||
0.0001, // 100us
|
||||
0.000110, // 110us
|
||||
0.000120, // 120us
|
||||
0.000130, // 130us
|
||||
0.000140, // 140us
|
||||
0.000150, // 150us
|
||||
0.000160, // 160us
|
||||
0.000170, // 170us
|
||||
0.000180, // 180us
|
||||
0.000190, // 190us
|
||||
0.000200, // 200us
|
||||
0.000210, // 210us
|
||||
0.000220, // 220us
|
||||
0.000230, // 230us
|
||||
0.000240, // 240us
|
||||
0.000250, // 250us
|
||||
0.000300, // 300us
|
||||
0.000350, // 350us
|
||||
0.000400, // 400us
|
||||
0.000450, // 450us
|
||||
0.000500, // 500us
|
||||
0.000600, // 600us
|
||||
0.000700, // 700us
|
||||
0.000800, // 800us
|
||||
0.000900, // 900us
|
||||
0.001000, // 1ms
|
||||
0.002000, // 2ms
|
||||
0.003000, // 3ms
|
||||
0.004000, // 4ms
|
||||
0.005000, // 5ms
|
||||
0.01000, // 10ms
|
||||
0.02000, // 20ms
|
||||
0.05000, // 50ms
|
||||
];
|
||||
|
||||
/// VirtualFile fs operation variants.
|
||||
|
||||
@@ -383,7 +383,7 @@ pub struct TenantShard {
|
||||
|
||||
l0_flush_global_state: L0FlushGlobalState,
|
||||
|
||||
feature_resolver: FeatureResolver,
|
||||
pub(crate) feature_resolver: FeatureResolver,
|
||||
}
|
||||
impl std::fmt::Debug for TenantShard {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
@@ -5832,6 +5832,7 @@ pub(crate) mod harness {
|
||||
pub conf: &'static PageServerConf,
|
||||
pub tenant_conf: pageserver_api::models::TenantConfig,
|
||||
pub tenant_shard_id: TenantShardId,
|
||||
pub shard_identity: ShardIdentity,
|
||||
pub generation: Generation,
|
||||
pub shard: ShardIndex,
|
||||
pub remote_storage: GenericRemoteStorage,
|
||||
@@ -5899,6 +5900,7 @@ pub(crate) mod harness {
|
||||
conf,
|
||||
tenant_conf,
|
||||
tenant_shard_id,
|
||||
shard_identity,
|
||||
generation,
|
||||
shard,
|
||||
remote_storage,
|
||||
@@ -5960,8 +5962,7 @@ pub(crate) mod harness {
|
||||
&ShardParameters::default(),
|
||||
))
|
||||
.unwrap(),
|
||||
// This is a legacy/test code path: sharding isn't supported here.
|
||||
ShardIdentity::unsharded(),
|
||||
self.shard_identity,
|
||||
Some(walredo_mgr),
|
||||
self.tenant_shard_id,
|
||||
self.remote_storage.clone(),
|
||||
@@ -6083,6 +6084,7 @@ mod tests {
|
||||
use timeline::compaction::{KeyHistoryRetention, KeyLogAtLsn};
|
||||
use timeline::{CompactOptions, DeltaLayerTestDesc, VersionedKeySpaceQuery};
|
||||
use utils::id::TenantId;
|
||||
use utils::shard::{ShardCount, ShardNumber};
|
||||
|
||||
use super::*;
|
||||
use crate::DEFAULT_PG_VERSION;
|
||||
@@ -9418,6 +9420,77 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_failed_flush_should_not_update_disk_consistent_lsn() -> anyhow::Result<()> {
|
||||
//
|
||||
// Setup
|
||||
//
|
||||
let harness = TenantHarness::create_custom(
|
||||
"test_failed_flush_should_not_upload_disk_consistent_lsn",
|
||||
pageserver_api::models::TenantConfig::default(),
|
||||
TenantId::generate(),
|
||||
ShardIdentity::new(ShardNumber(0), ShardCount(4), ShardStripeSize(128)).unwrap(),
|
||||
Generation::new(1),
|
||||
)
|
||||
.await?;
|
||||
let (tenant, ctx) = harness.load().await;
|
||||
|
||||
let timeline = tenant
|
||||
.create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx)
|
||||
.await?;
|
||||
assert_eq!(timeline.get_shard_identity().count, ShardCount(4));
|
||||
let mut writer = timeline.writer().await;
|
||||
writer
|
||||
.put(
|
||||
*TEST_KEY,
|
||||
Lsn(0x20),
|
||||
&Value::Image(test_img("foo at 0x20")),
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
writer.finish_write(Lsn(0x20));
|
||||
drop(writer);
|
||||
timeline.freeze_and_flush().await.unwrap();
|
||||
|
||||
timeline.remote_client.wait_completion().await.unwrap();
|
||||
let disk_consistent_lsn = timeline.get_disk_consistent_lsn();
|
||||
let remote_consistent_lsn = timeline.get_remote_consistent_lsn_projected();
|
||||
assert_eq!(Some(disk_consistent_lsn), remote_consistent_lsn);
|
||||
|
||||
//
|
||||
// Test
|
||||
//
|
||||
|
||||
let mut writer = timeline.writer().await;
|
||||
writer
|
||||
.put(
|
||||
*TEST_KEY,
|
||||
Lsn(0x30),
|
||||
&Value::Image(test_img("foo at 0x30")),
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
writer.finish_write(Lsn(0x30));
|
||||
drop(writer);
|
||||
|
||||
fail::cfg(
|
||||
"flush-layer-before-update-remote-consistent-lsn",
|
||||
"return()",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let flush_res = timeline.freeze_and_flush().await;
|
||||
// if flush failed, the disk/remote consistent LSN should not be updated
|
||||
assert!(flush_res.is_err());
|
||||
assert_eq!(disk_consistent_lsn, timeline.get_disk_consistent_lsn());
|
||||
assert_eq!(
|
||||
remote_consistent_lsn,
|
||||
timeline.get_remote_consistent_lsn_projected()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
#[tokio::test]
|
||||
async fn test_simple_bottom_most_compaction_deltas_1() -> anyhow::Result<()> {
|
||||
|
||||
@@ -4767,7 +4767,10 @@ impl Timeline {
|
||||
|| !flushed_to_lsn.is_valid()
|
||||
);
|
||||
|
||||
if flushed_to_lsn < frozen_to_lsn && self.shard_identity.count.count() > 1 {
|
||||
if flushed_to_lsn < frozen_to_lsn
|
||||
&& self.shard_identity.count.count() > 1
|
||||
&& result.is_ok()
|
||||
{
|
||||
// If our layer flushes didn't carry disk_consistent_lsn up to the `to_lsn` advertised
|
||||
// to us via layer_flush_start_rx, then advance it here.
|
||||
//
|
||||
@@ -4946,6 +4949,10 @@ impl Timeline {
|
||||
return Err(FlushLayerError::Cancelled);
|
||||
}
|
||||
|
||||
fail_point!("flush-layer-before-update-remote-consistent-lsn", |_| {
|
||||
Err(FlushLayerError::Other(anyhow!("failpoint").into()))
|
||||
});
|
||||
|
||||
let disk_consistent_lsn = Lsn(lsn_range.end.0 - 1);
|
||||
|
||||
// The new on-disk layers are now in the layer map. We can remove the
|
||||
|
||||
@@ -11,19 +11,7 @@
|
||||
//! - => S3 as the source for the PGDATA instead of local filesystem
|
||||
//!
|
||||
//! TODOs before productionization:
|
||||
//! - ChunkProcessingJob size / ImportJob::total_size does not account for sharding.
|
||||
//! => produced image layers likely too small.
|
||||
//! - ChunkProcessingJob should cut up an ImportJob to hit exactly target image layer size.
|
||||
//! - asserts / unwraps need to be replaced with errors
|
||||
//! - don't trust remote objects will be small (=prevent OOMs in those cases)
|
||||
//! - limit all in-memory buffers in size, or download to disk and read from there
|
||||
//! - limit task concurrency
|
||||
//! - generally play nice with other tenants in the system
|
||||
//! - importbucket is different bucket than main pageserver storage, so, should be fine wrt S3 rate limits
|
||||
//! - but concerns like network bandwidth, local disk write bandwidth, local disk capacity, etc
|
||||
//! - integrate with layer eviction system
|
||||
//! - audit for Tenant::cancel nor Timeline::cancel responsivity
|
||||
//! - audit for Tenant/Timeline gate holding (we spawn tokio tasks during this flow!)
|
||||
//!
|
||||
//! An incomplete set of TODOs from the Hackathon:
|
||||
//! - version-specific CheckPointData (=> pgv abstraction, already exists for regular walingest)
|
||||
@@ -44,7 +32,7 @@ use pageserver_api::key::{
|
||||
rel_dir_to_key, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key,
|
||||
slru_segment_size_to_key,
|
||||
};
|
||||
use pageserver_api::keyspace::{contiguous_range_len, is_contiguous_range, singleton_range};
|
||||
use pageserver_api::keyspace::{ShardedRange, singleton_range};
|
||||
use pageserver_api::models::{ShardImportProgress, ShardImportProgressV1, ShardImportStatus};
|
||||
use pageserver_api::reltag::{RelTag, SlruKind};
|
||||
use pageserver_api::shard::ShardIdentity;
|
||||
@@ -167,6 +155,7 @@ impl Planner {
|
||||
/// This function is and must remain pure: given the same input, it will generate the same import plan.
|
||||
async fn plan(mut self, import_config: &TimelineImportConfig) -> anyhow::Result<Plan> {
|
||||
let pgdata_lsn = Lsn(self.control_file.control_file_data().checkPoint).align();
|
||||
anyhow::ensure!(pgdata_lsn.is_valid());
|
||||
|
||||
let datadir = PgDataDir::new(&self.storage).await?;
|
||||
|
||||
@@ -249,14 +238,22 @@ impl Planner {
|
||||
});
|
||||
|
||||
// Assigns parts of key space to later parallel jobs
|
||||
// Note: The image layers produced here may have gaps, meaning,
|
||||
// there is not an image for each key in the layer's key range.
|
||||
// The read path stops traversal at the first image layer, regardless
|
||||
// of whether a base image has been found for a key or not.
|
||||
// (Concept of sparse image layers doesn't exist.)
|
||||
// This behavior is exactly right for the base image layers we're producing here.
|
||||
// But, since no other place in the code currently produces image layers with gaps,
|
||||
// it seems noteworthy.
|
||||
let mut last_end_key = Key::MIN;
|
||||
let mut current_chunk = Vec::new();
|
||||
let mut current_chunk_size: usize = 0;
|
||||
let mut jobs = Vec::new();
|
||||
for task in std::mem::take(&mut self.tasks).into_iter() {
|
||||
if current_chunk_size + task.total_size()
|
||||
> import_config.import_job_soft_size_limit.into()
|
||||
{
|
||||
let task_size = task.total_size(&self.shard);
|
||||
let projected_chunk_size = current_chunk_size.saturating_add(task_size);
|
||||
if projected_chunk_size > import_config.import_job_soft_size_limit.into() {
|
||||
let key_range = last_end_key..task.key_range().start;
|
||||
jobs.push(ChunkProcessingJob::new(
|
||||
key_range.clone(),
|
||||
@@ -266,7 +263,7 @@ impl Planner {
|
||||
last_end_key = key_range.end;
|
||||
current_chunk_size = 0;
|
||||
}
|
||||
current_chunk_size += task.total_size();
|
||||
current_chunk_size = current_chunk_size.saturating_add(task_size);
|
||||
current_chunk.push(task);
|
||||
}
|
||||
jobs.push(ChunkProcessingJob::new(
|
||||
@@ -604,18 +601,18 @@ impl PgDataDirDb {
|
||||
};
|
||||
|
||||
let path = datadir_path.join(rel_tag.to_segfile_name(segno));
|
||||
assert!(filesize % BLCKSZ as usize == 0); // TODO: this should result in an error
|
||||
anyhow::ensure!(filesize % BLCKSZ as usize == 0);
|
||||
let nblocks = filesize / BLCKSZ as usize;
|
||||
|
||||
PgDataDirDbFile {
|
||||
Ok(PgDataDirDbFile {
|
||||
path,
|
||||
filesize,
|
||||
rel_tag,
|
||||
segno,
|
||||
nblocks: Some(nblocks), // first non-cummulative sizes
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
.collect::<anyhow::Result<_, _>>()?;
|
||||
|
||||
// Set cummulative sizes. Do all of that math here, so that later we could easier
|
||||
// parallelize over segments and know with which segments we need to write relsize
|
||||
@@ -650,12 +647,22 @@ impl PgDataDirDb {
|
||||
trait ImportTask {
|
||||
fn key_range(&self) -> Range<Key>;
|
||||
|
||||
fn total_size(&self) -> usize {
|
||||
// TODO: revisit this
|
||||
if is_contiguous_range(&self.key_range()) {
|
||||
contiguous_range_len(&self.key_range()) as usize * 8192
|
||||
fn total_size(&self, shard_identity: &ShardIdentity) -> usize {
|
||||
let range = ShardedRange::new(self.key_range(), shard_identity);
|
||||
let page_count = range.page_count();
|
||||
if page_count == u32::MAX {
|
||||
tracing::warn!(
|
||||
"Import task has non contiguous key range: {}..{}",
|
||||
self.key_range().start,
|
||||
self.key_range().end
|
||||
);
|
||||
|
||||
// Tasks should operate on contiguous ranges. It is unexpected for
|
||||
// ranges to violate this assumption. Calling code handles this by mapping
|
||||
// any task on a non contiguous range to its own image layer.
|
||||
usize::MAX
|
||||
} else {
|
||||
u32::MAX as usize
|
||||
page_count as usize * 8192
|
||||
}
|
||||
}
|
||||
|
||||
@@ -753,6 +760,8 @@ impl ImportTask for ImportRelBlocksTask {
|
||||
layer_writer: &mut ImageLayerWriter,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<usize> {
|
||||
const MAX_BYTE_RANGE_SIZE: usize = 128 * 1024 * 1024;
|
||||
|
||||
debug!("Importing relation file");
|
||||
|
||||
let (rel_tag, start_blk) = self.key_range.start.to_rel_block()?;
|
||||
@@ -777,7 +786,7 @@ impl ImportTask for ImportRelBlocksTask {
|
||||
assert_eq!(key.len(), 1);
|
||||
assert!(!acc.is_empty());
|
||||
assert!(acc_end > acc_start);
|
||||
if acc_end == start /* TODO additional max range check here, to limit memory consumption per task to X */ {
|
||||
if acc_end == start && end - acc_start <= MAX_BYTE_RANGE_SIZE {
|
||||
acc.push(key.pop().unwrap());
|
||||
Ok((acc, acc_start, end))
|
||||
} else {
|
||||
@@ -792,8 +801,8 @@ impl ImportTask for ImportRelBlocksTask {
|
||||
.get_range(&self.path, range_start.into_u64(), range_end.into_u64())
|
||||
.await?;
|
||||
let mut buf = Bytes::from(range_buf);
|
||||
// TODO: batched writes
|
||||
for key in keys {
|
||||
// The writer buffers writes internally
|
||||
let image = buf.split_to(8192);
|
||||
layer_writer.put_image(key, image, ctx).await?;
|
||||
nimages += 1;
|
||||
@@ -846,6 +855,9 @@ impl ImportTask for ImportSlruBlocksTask {
|
||||
debug!("Importing SLRU segment file {}", self.path);
|
||||
let buf = self.storage.get(&self.path).await?;
|
||||
|
||||
// TODO(vlad): Does timestamp to LSN work for imported timelines?
|
||||
// Probably not since we don't append the `xact_time` to it as in
|
||||
// [`WalIngest::ingest_xact_record`].
|
||||
let (kind, segno, start_blk) = self.key_range.start.to_slru_block()?;
|
||||
let (_kind, _segno, end_blk) = self.key_range.end.to_slru_block()?;
|
||||
let mut blknum = start_blk;
|
||||
|
||||
@@ -6,7 +6,7 @@ use bytes::Bytes;
|
||||
use postgres_ffi::ControlFileData;
|
||||
use remote_storage::{
|
||||
Download, DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, Listing,
|
||||
ListingObject, RemotePath,
|
||||
ListingObject, RemotePath, RemoteStorageConfig,
|
||||
};
|
||||
use serde::de::DeserializeOwned;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -22,11 +22,9 @@ pub async fn new(
|
||||
location: &index_part_format::Location,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<RemoteStorageWrapper, anyhow::Error> {
|
||||
// FIXME: we probably want some timeout, and we might be able to assume the max file
|
||||
// size on S3 is 1GiB (postgres segment size). But the problem is that the individual
|
||||
// downloaders don't know enough about concurrent downloads to make a guess on the
|
||||
// expected bandwidth and resulting best timeout.
|
||||
let timeout = std::time::Duration::from_secs(24 * 60 * 60);
|
||||
// Downloads should be reasonably sized. We do ranged reads for relblock raw data
|
||||
// and full reads for SLRU segments which are bounded by Postgres.
|
||||
let timeout = RemoteStorageConfig::DEFAULT_TIMEOUT;
|
||||
let location_storage = match location {
|
||||
#[cfg(feature = "testing")]
|
||||
index_part_format::Location::LocalFs { path } => {
|
||||
@@ -50,9 +48,12 @@ pub async fn new(
|
||||
.import_pgdata_aws_endpoint_url
|
||||
.clone()
|
||||
.map(|url| url.to_string()), // by specifying None here, remote_storage/aws-sdk-rust will infer from env
|
||||
concurrency_limit: 100.try_into().unwrap(), // TODO: think about this
|
||||
max_keys_per_list_response: Some(1000), // TODO: think about this
|
||||
upload_storage_class: None, // irrelevant
|
||||
// This matches the default import job concurrency. This is managed
|
||||
// separately from the usual S3 client, but the concern here is bandwidth
|
||||
// usage.
|
||||
concurrency_limit: 128.try_into().unwrap(),
|
||||
max_keys_per_list_response: Some(1000),
|
||||
upload_storage_class: None, // irrelevant
|
||||
},
|
||||
timeout,
|
||||
)
|
||||
|
||||
@@ -17,35 +17,27 @@ pub(super) async fn authenticate(
|
||||
config: &'static AuthenticationConfig,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => {
|
||||
debug!("auth endpoint chooses MD5");
|
||||
return Err(auth::AuthError::bad_auth_method("MD5"));
|
||||
return Err(auth::AuthError::MalformedPassword("MD5 not supported"));
|
||||
}
|
||||
AuthSecret::Scram(secret) => {
|
||||
debug!("auth endpoint chooses SCRAM");
|
||||
let scram = auth::Scram(&secret, ctx);
|
||||
|
||||
let auth_outcome = tokio::time::timeout(
|
||||
config.scram_protocol_timeout,
|
||||
async {
|
||||
|
||||
flow.begin(scram).await.map_err(|error| {
|
||||
warn!(?error, "error sending scram acknowledgement");
|
||||
error
|
||||
})?.authenticate().await.map_err(|error| {
|
||||
let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async {
|
||||
AuthFlow::new(client, scram)
|
||||
.authenticate()
|
||||
.await
|
||||
.inspect_err(|error| {
|
||||
warn!(?error, "error processing scram messages");
|
||||
error
|
||||
})
|
||||
}
|
||||
)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs());
|
||||
auth::AuthError::user_timeout(e)
|
||||
})??;
|
||||
.inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()))
|
||||
.map_err(auth::AuthError::user_timeout)??;
|
||||
|
||||
let client_key = match auth_outcome {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
|
||||
@@ -2,7 +2,6 @@ use std::fmt;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use postgres_client::config::SslMode;
|
||||
use pq_proto::BeMessage as Be;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, info_span};
|
||||
@@ -16,6 +15,7 @@ use crate::context::RequestContext;
|
||||
use crate::control_plane::client::cplane_proxy_v1;
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::stream::PqStream;
|
||||
@@ -154,11 +154,13 @@ async fn authenticate(
|
||||
|
||||
// Give user a URL to spawn a new database.
|
||||
info!(parent: &span, "sending the auth URL to the user");
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
||||
.write_message(&Be::NoticeResponse(&greeting))
|
||||
.await?;
|
||||
client.write_message(BeMessage::AuthenticationOk);
|
||||
client.write_message(BeMessage::ParameterStatus {
|
||||
name: b"client_encoding",
|
||||
value: b"UTF8",
|
||||
});
|
||||
client.write_message(BeMessage::NoticeResponse(&greeting));
|
||||
client.flush().await?;
|
||||
|
||||
// Wait for console response via control plane (see `mgmt`).
|
||||
info!(parent: &span, "waiting for console's reply...");
|
||||
@@ -188,7 +190,7 @@ async fn authenticate(
|
||||
}
|
||||
}
|
||||
|
||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
|
||||
client.write_message(BeMessage::NoticeResponse("Connecting to database."));
|
||||
|
||||
// This config should be self-contained, because we won't
|
||||
// take username or dbname from client's startup message.
|
||||
|
||||
@@ -24,23 +24,25 @@ pub(crate) async fn authenticate_cleartext(
|
||||
debug!("cleartext auth flow override is enabled, proceeding");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
let ep = EndpointIdInt::from(&info.endpoint);
|
||||
|
||||
let auth_flow = AuthFlow::new(client)
|
||||
.begin(auth::CleartextPassword {
|
||||
let auth_flow = AuthFlow::new(
|
||||
client,
|
||||
auth::CleartextPassword {
|
||||
secret,
|
||||
endpoint: ep,
|
||||
pool: config.thread_pool.clone(),
|
||||
})
|
||||
.await?;
|
||||
drop(paused);
|
||||
// cleartext auth is only allowed to the ws/http protocol.
|
||||
// If we're here, we already received the password in the first message.
|
||||
// Scram protocol will be executed on the proxy side.
|
||||
let auth_outcome = auth_flow.authenticate().await?;
|
||||
},
|
||||
);
|
||||
let auth_outcome = {
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// cleartext auth is only allowed to the ws/http protocol.
|
||||
// If we're here, we already received the password in the first message.
|
||||
// Scram protocol will be executed on the proxy side.
|
||||
auth_flow.authenticate().await?
|
||||
};
|
||||
|
||||
let keys = match auth_outcome {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
@@ -67,9 +69,7 @@ pub(crate) async fn password_hack_no_authentication(
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
.await?
|
||||
let payload = AuthFlow::new(client, auth::PasswordHack)
|
||||
.get_password()
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ use crate::control_plane::{
|
||||
};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
@@ -402,7 +403,7 @@ async fn authenticate_with_secret(
|
||||
};
|
||||
|
||||
// we have authenticated the password
|
||||
client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
|
||||
client.write_message(BeMessage::AuthenticationOk);
|
||||
|
||||
return Ok(ComputeCredentials { info, keys });
|
||||
}
|
||||
@@ -702,7 +703,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_scram() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -784,7 +785,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_cleartext() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -838,7 +839,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_password_hack() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
|
||||
@@ -5,7 +5,6 @@ use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
use itertools::Itertools;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
@@ -13,6 +12,7 @@ use crate::auth::password_hack::parse_endpoint_param;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI};
|
||||
use crate::types::{EndpointId, RoleName};
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
//! Main authentication flow.
|
||||
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
@@ -13,35 +11,26 @@ use super::{AuthError, PasswordHackPayload};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::AuthSecret;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
use crate::sasl;
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::scram::{self};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
/// Every authentication selector is supposed to implement this trait.
|
||||
pub(crate) trait AuthMethod {
|
||||
/// Any authentication selector should provide initial backend message
|
||||
/// containing auth method name and parameters, e.g. md5 salt.
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
|
||||
}
|
||||
|
||||
/// Initial state of [`AuthFlow`].
|
||||
pub(crate) struct Begin;
|
||||
|
||||
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
||||
pub(crate) struct Scram<'a>(
|
||||
pub(crate) &'a scram::ServerSecret,
|
||||
pub(crate) &'a RequestContext,
|
||||
);
|
||||
|
||||
impl AuthMethod for Scram<'_> {
|
||||
impl Scram<'_> {
|
||||
#[inline(always)]
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
|
||||
if channel_binding {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
} else {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
|
||||
scram::METHODS_WITHOUT_PLUS,
|
||||
))
|
||||
}
|
||||
@@ -52,13 +41,6 @@ impl AuthMethod for Scram<'_> {
|
||||
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
|
||||
pub(crate) struct PasswordHack;
|
||||
|
||||
impl AuthMethod for PasswordHack {
|
||||
#[inline(always)]
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// Use clear-text password auth called `password` in docs
|
||||
/// <https://www.postgresql.org/docs/current/auth-password.html>
|
||||
pub(crate) struct CleartextPassword {
|
||||
@@ -67,53 +49,37 @@ pub(crate) struct CleartextPassword {
|
||||
pub(crate) secret: AuthSecret,
|
||||
}
|
||||
|
||||
impl AuthMethod for CleartextPassword {
|
||||
#[inline(always)]
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub(crate) struct AuthFlow<'a, S, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream<S>>,
|
||||
/// State might contain ancillary data (see [`Self::begin`]).
|
||||
/// State might contain ancillary data.
|
||||
state: State,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
}
|
||||
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
|
||||
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
|
||||
let tls_server_end_point = stream.get_ref().tls_server_end_point();
|
||||
|
||||
Self {
|
||||
stream,
|
||||
state: Begin,
|
||||
state: method,
|
||||
tls_server_end_point,
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to the next step by sending auth method's name & params to client.
|
||||
pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
|
||||
self.stream
|
||||
.write_message(&method.first_message(self.tls_server_end_point.supported()))
|
||||
.await?;
|
||||
|
||||
Ok(AuthFlow {
|
||||
stream: self.stream,
|
||||
state: method,
|
||||
tls_server_end_point: self.tls_server_end_point,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
|
||||
self.stream
|
||||
.write_message(BeMessage::AuthenticationCleartextPassword);
|
||||
self.stream.flush().await?;
|
||||
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
@@ -133,6 +99,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
self.stream
|
||||
.write_message(BeMessage::AuthenticationCleartextPassword);
|
||||
self.stream.flush().await?;
|
||||
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
@@ -147,7 +117,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
.await?;
|
||||
|
||||
if let sasl::Outcome::Success(_) = &outcome {
|
||||
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
self.stream.write_message(BeMessage::AuthenticationOk);
|
||||
}
|
||||
|
||||
Ok(outcome)
|
||||
@@ -159,42 +129,36 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
|
||||
let Scram(secret, ctx) = self.state;
|
||||
let channel_binding = self.tls_server_end_point;
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
// send sasl message.
|
||||
{
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let sasl = sasl::FirstMessage::parse(&msg)
|
||||
.ok_or(AuthError::MalformedPassword("bad sasl message"))?;
|
||||
|
||||
// Currently, the only supported SASL method is SCRAM.
|
||||
if !scram::METHODS.contains(&sasl.method) {
|
||||
return Err(super::AuthError::bad_auth_method(sasl.method));
|
||||
let sasl = self.state.first_message(channel_binding.supported());
|
||||
self.stream.write_message(sasl);
|
||||
self.stream.flush().await?;
|
||||
}
|
||||
|
||||
match sasl.method {
|
||||
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
|
||||
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
|
||||
_ => {}
|
||||
}
|
||||
// complete sasl handshake.
|
||||
sasl::authenticate(ctx, self.stream, |method| {
|
||||
// Currently, the only supported SASL method is SCRAM.
|
||||
match method {
|
||||
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
|
||||
SCRAM_SHA_256_PLUS => {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus);
|
||||
}
|
||||
method => return Err(sasl::Error::BadAuthMethod(method.into())),
|
||||
}
|
||||
|
||||
// TODO: make this a metric instead
|
||||
info!("client chooses {}", sasl.method);
|
||||
// TODO: make this a metric instead
|
||||
info!("client chooses {}", method);
|
||||
|
||||
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
|
||||
.authenticate(scram::Exchange::new(
|
||||
secret,
|
||||
rand::random,
|
||||
self.tls_server_end_point,
|
||||
))
|
||||
.await?;
|
||||
|
||||
if let sasl::Outcome::Success(_) = &outcome {
|
||||
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
}
|
||||
|
||||
Ok(outcome)
|
||||
Ok(scram::Exchange::new(secret, rand::random, channel_binding))
|
||||
})
|
||||
.await
|
||||
.map_err(AuthError::Sasl)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
//! This allows connecting to pods/services running in the same Kubernetes cluster from
|
||||
//! the outside. Similar to an ingress controller for HTTPS.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::path::Path;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, anyhow, bail, ensure};
|
||||
use clap::Arg;
|
||||
@@ -17,6 +18,7 @@ use rustls::pki_types::{DnsName, PrivateKeyDer};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, error, info};
|
||||
use utils::project_git_version;
|
||||
@@ -24,10 +26,12 @@ use utils::sentry_init::init_sentry;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled};
|
||||
use crate::proxy::{
|
||||
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
|
||||
};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
@@ -84,7 +88,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
.parse()?;
|
||||
|
||||
// Configure TLS
|
||||
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
|
||||
let tls_config = match (
|
||||
args.get_one::<String>("tls-key"),
|
||||
args.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
@@ -117,7 +121,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest.clone(),
|
||||
tls_config.clone(),
|
||||
None,
|
||||
tls_server_end_point,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
))
|
||||
@@ -127,7 +130,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest,
|
||||
tls_config,
|
||||
Some(compute_tls_config),
|
||||
tls_server_end_point,
|
||||
proxy_listener_compute_tls,
|
||||
cancellation_token.clone(),
|
||||
))
|
||||
@@ -154,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
pub(super) fn parse_tls(
|
||||
key_path: &Path,
|
||||
cert_path: &Path,
|
||||
) -> anyhow::Result<(Arc<rustls::ServerConfig>, TlsServerEndPoint)> {
|
||||
) -> anyhow::Result<Arc<rustls::ServerConfig>> {
|
||||
let key = {
|
||||
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
||||
|
||||
@@ -187,10 +189,6 @@ pub(super) fn parse_tls(
|
||||
})?
|
||||
};
|
||||
|
||||
// needed for channel bindings
|
||||
let first_cert = cert_chain.first().context("missing certificate")?;
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
|
||||
let tls_config =
|
||||
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
|
||||
@@ -199,14 +197,13 @@ pub(super) fn parse_tls(
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.into();
|
||||
|
||||
Ok((tls_config, tls_server_end_point))
|
||||
Ok(tls_config)
|
||||
}
|
||||
|
||||
pub(super) async fn task_main(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -242,15 +239,7 @@ pub(super) async fn task_main(
|
||||
crate::metrics::Protocol::SniRouter,
|
||||
"sni",
|
||||
);
|
||||
handle_client(
|
||||
ctx,
|
||||
dest_suffix,
|
||||
tls_config,
|
||||
compute_tls_config,
|
||||
tls_server_end_point,
|
||||
socket,
|
||||
)
|
||||
.await
|
||||
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
@@ -269,55 +258,26 @@ pub(super) async fn task_main(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestContext,
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
) -> anyhow::Result<Stream<S>> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
|
||||
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
use pq_proto::FeStartupPacket::SslRequest;
|
||||
|
||||
) -> anyhow::Result<TlsStream<S>> {
|
||||
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
|
||||
match msg {
|
||||
SslRequest { direct: false } => {
|
||||
stream
|
||||
.write_message(&pq_proto::BeMessage::EncryptionResponse(true))
|
||||
.await?;
|
||||
FeStartupPacket::SslRequest { direct: None } => {
|
||||
let raw = stream.accept_tls().await?;
|
||||
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let (raw, read_buf) = stream.into_inner();
|
||||
// TODO: Normally, client doesn't send any data before
|
||||
// server says TLS handshake is ok and read_buf is empty.
|
||||
// However, you could imagine pipelining of postgres
|
||||
// SSLRequest + TLS ClientHello in one hunk similar to
|
||||
// pipelining in our node js driver. We should probably
|
||||
// support that by chaining read_buf with the stream.
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
|
||||
Ok(Stream::Tls {
|
||||
tls: Box::new(
|
||||
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?,
|
||||
),
|
||||
tls_server_end_point,
|
||||
})
|
||||
Ok(raw
|
||||
.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?)
|
||||
}
|
||||
unexpected => {
|
||||
info!(
|
||||
?unexpected,
|
||||
"unexpected startup packet, rejecting connection"
|
||||
);
|
||||
stream
|
||||
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None)
|
||||
.await?
|
||||
Err(stream.throw_error(TlsRequired, None).await)?
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -327,15 +287,18 @@ async fn handle_client(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
|
||||
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
|
||||
let sni = tls_stream
|
||||
.get_ref()
|
||||
.1
|
||||
.server_name()
|
||||
.ok_or(anyhow!("SNI missing"))?;
|
||||
let dest: Vec<&str> = sni
|
||||
.split_once('.')
|
||||
.context("invalid SNI")?
|
||||
|
||||
@@ -476,8 +476,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let key_path = args.tls_key.expect("already asserted it is set");
|
||||
let cert_path = args.tls_cert.expect("already asserted it is set");
|
||||
|
||||
let (tls_config, tls_server_end_point) =
|
||||
super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
|
||||
let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
|
||||
|
||||
let dest = Arc::new(dest);
|
||||
|
||||
@@ -485,7 +484,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest.clone(),
|
||||
tls_config.clone(),
|
||||
None,
|
||||
tls_server_end_point,
|
||||
listen,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
@@ -494,7 +492,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest,
|
||||
tls_config,
|
||||
Some(config.connect_to_compute.tls.clone()),
|
||||
tls_server_end_point,
|
||||
listen_tls,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
|
||||
@@ -5,7 +5,6 @@ use anyhow::{Context, anyhow};
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use postgres_client::CancelToken;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use pq_proto::CancelKeyData;
|
||||
use redis::{Cmd, FromRedisValue, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
@@ -21,6 +20,7 @@ use crate::control_plane::ControlPlaneApi;
|
||||
use crate::error::ReportableError;
|
||||
use crate::ext::LockExt;
|
||||
use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
|
||||
use crate::pqproto::CancelKeyData;
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::rate_limiter::LeakyBucketRateLimiter;
|
||||
use crate::redis::keys::KeyPrefix;
|
||||
|
||||
@@ -8,7 +8,6 @@ use itertools::Itertools;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{CancelToken, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
@@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumDbConnectionsGuard};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::neon_option;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::types::Host;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info};
|
||||
|
||||
@@ -221,12 +221,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => {
|
||||
return stream.throw_error(e, Some(ctx)).await?;
|
||||
}
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let mut node = connect_to_compute(
|
||||
let node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info,
|
||||
@@ -238,7 +236,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e, Some(ctx)))
|
||||
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
|
||||
.await?;
|
||||
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
@@ -246,14 +244,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
|
||||
@@ -4,7 +4,6 @@ use std::net::IpAddr;
|
||||
|
||||
use chrono::Utc;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use smol_str::SmolStr;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::field::display;
|
||||
@@ -20,6 +19,7 @@ use crate::metrics::{
|
||||
ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol,
|
||||
Waiting,
|
||||
};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra};
|
||||
use crate::types::{DbName, EndpointId, RoleName};
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ use parquet::file::metadata::RowGroupMetaDataPtr;
|
||||
use parquet::file::properties::{DEFAULT_PAGE_SIZE, WriterProperties, WriterPropertiesPtr};
|
||||
use parquet::file::writer::SerializedFileWriter;
|
||||
use parquet::record::RecordWriter;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel};
|
||||
use serde::ser::SerializeMap;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -24,6 +23,7 @@ use super::{LOG_CHAN, RequestContextInner};
|
||||
use crate::config::remote_storage_from_toml;
|
||||
use crate::context::LOG_CHAN_DISCONNECT;
|
||||
use crate::ext::TaskExt;
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
|
||||
#[derive(clap::Args, Clone, Debug)]
|
||||
pub struct ParquetUploadArgs {
|
||||
|
||||
@@ -92,6 +92,7 @@ mod logging;
|
||||
mod metrics;
|
||||
mod parse;
|
||||
mod pglb;
|
||||
mod pqproto;
|
||||
mod protocol2;
|
||||
mod proxy;
|
||||
mod rate_limiter;
|
||||
|
||||
693
proxy/src/pqproto.rs
Normal file
693
proxy/src/pqproto.rs
Normal file
@@ -0,0 +1,693 @@
|
||||
//! Postgres protocol codec
|
||||
//!
|
||||
//! <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
|
||||
use std::fmt;
|
||||
use std::io::{self, Cursor};
|
||||
|
||||
use bytes::{Buf, BufMut};
|
||||
use itertools::Itertools;
|
||||
use rand::distributions::{Distribution, Standard};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
|
||||
|
||||
pub type ErrorCode = [u8; 5];
|
||||
|
||||
pub const FE_PASSWORD_MESSAGE: u8 = b'p';
|
||||
|
||||
pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000";
|
||||
|
||||
/// The protocol version number.
|
||||
///
|
||||
/// The most significant 16 bits are the major version number (3 for the protocol described here).
|
||||
/// The least significant 16 bits are the minor version number (0 for the protocol described here).
|
||||
/// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-STARTUPMESSAGE>
|
||||
#[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)]
|
||||
#[repr(C)]
|
||||
pub struct ProtocolVersion {
|
||||
major: big_endian::U16,
|
||||
minor: big_endian::U16,
|
||||
}
|
||||
|
||||
impl ProtocolVersion {
|
||||
pub const fn new(major: u16, minor: u16) -> Self {
|
||||
Self {
|
||||
major: big_endian::U16::new(major),
|
||||
minor: big_endian::U16::new(minor),
|
||||
}
|
||||
}
|
||||
pub const fn minor(self) -> u16 {
|
||||
self.minor.get()
|
||||
}
|
||||
pub const fn major(self) -> u16 {
|
||||
self.major.get()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for ProtocolVersion {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_list()
|
||||
.entry(&self.major())
|
||||
.entry(&self.minor())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// read the type from the stream using zerocopy.
|
||||
///
|
||||
/// not cancel safe.
|
||||
macro_rules! read {
|
||||
($s:expr => $t:ty) => {{
|
||||
// cannot be implemented as a function due to lack of const-generic-expr
|
||||
let mut buf = [0; size_of::<$t>()];
|
||||
$s.read_exact(&mut buf).await?;
|
||||
let res: $t = zerocopy::transmute!(buf);
|
||||
res
|
||||
}};
|
||||
}
|
||||
|
||||
pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
|
||||
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
|
||||
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
|
||||
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
|
||||
|
||||
/// This first reads the startup message header, is 8 bytes.
|
||||
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
|
||||
///
|
||||
/// The length value is inclusive of the header. For example,
|
||||
/// an empty message will always have length 8.
|
||||
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
|
||||
#[repr(C)]
|
||||
struct StartupHeader {
|
||||
len: big_endian::U32,
|
||||
version: ProtocolVersion,
|
||||
}
|
||||
|
||||
let header = read!(stream => StartupHeader);
|
||||
|
||||
// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
|
||||
// First byte indicates standard SSL handshake message
|
||||
// (It can't be a Postgres startup length because in network byte order
|
||||
// that would be a startup packet hundreds of megabytes long)
|
||||
if header.as_bytes()[0] == 0x16 {
|
||||
return Ok(FeStartupPacket::SslRequest {
|
||||
// The bytes we read for the header are actually part of a TLS ClientHello.
|
||||
// In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here.
|
||||
// In practice though, I see no world where a ClientHello is less than 8 bytes
|
||||
// since it includes ephemeral keys etc.
|
||||
direct: Some(zerocopy::transmute!(header)),
|
||||
});
|
||||
}
|
||||
|
||||
let Some(len) = (header.len.get() as usize).checked_sub(8) else {
|
||||
return Err(io::Error::other(format!(
|
||||
"invalid startup message length {}, must be at least 8.",
|
||||
header.len,
|
||||
)));
|
||||
};
|
||||
|
||||
// TODO: add a histogram for startup packet lengths
|
||||
if len > MAX_STARTUP_PACKET_LENGTH {
|
||||
tracing::warn!("large startup message detected: {len} bytes");
|
||||
return Err(io::Error::other(format!(
|
||||
"invalid startup message length {len}"
|
||||
)));
|
||||
}
|
||||
|
||||
match header.version {
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-CANCELREQUEST>
|
||||
CANCEL_REQUEST_CODE => {
|
||||
if len != 8 {
|
||||
return Err(io::Error::other(
|
||||
"CancelRequest message is malformed, backend PID / secret key missing",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(FeStartupPacket::CancelRequest(
|
||||
read!(stream => CancelKeyData),
|
||||
))
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST>
|
||||
NEGOTIATE_SSL_CODE => {
|
||||
// Requested upgrade to SSL (aka TLS)
|
||||
Ok(FeStartupPacket::SslRequest { direct: None })
|
||||
}
|
||||
NEGOTIATE_GSS_CODE => {
|
||||
// Requested upgrade to GSSAPI
|
||||
Ok(FeStartupPacket::GssEncRequest)
|
||||
}
|
||||
version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other(
|
||||
format!("Unrecognized request code {version:?}"),
|
||||
)),
|
||||
// StartupMessage
|
||||
version => {
|
||||
// The protocol version number is followed by one or more pairs of parameter name and value strings.
|
||||
// A zero byte is required as a terminator after the last name/value pair.
|
||||
// Parameters can appear in any order. user is required, others are optional.
|
||||
|
||||
let mut buf = vec![0; len];
|
||||
stream.read_exact(&mut buf).await?;
|
||||
|
||||
if buf.pop() != Some(b'\0') {
|
||||
return Err(io::Error::other(
|
||||
"StartupMessage params: missing null terminator",
|
||||
));
|
||||
}
|
||||
|
||||
// TODO: Don't do this.
|
||||
// There's no guarantee that these messages are utf8,
|
||||
// but they usually happen to be simple ascii.
|
||||
let params = String::from_utf8(buf)
|
||||
.map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?;
|
||||
|
||||
Ok(FeStartupPacket::StartupMessage {
|
||||
version,
|
||||
params: StartupMessageParams { params },
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
///
|
||||
/// This returns the message tag, as well as the message body. The message
|
||||
/// body is written into `buf`, and it is otherwise completely overwritten.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_message<'a, S>(
|
||||
stream: &mut S,
|
||||
buf: &'a mut Vec<u8>,
|
||||
max: usize,
|
||||
) -> io::Result<(u8, &'a mut [u8])>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
/// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes.
|
||||
/// The first byte is a message tag, and the next 4 bytes is a big-endian length.
|
||||
///
|
||||
/// Awkwardly, the length value is inclusive of itself, but not of the tag. For example,
|
||||
/// an empty message will always have length 4.
|
||||
#[derive(Clone, Copy, FromBytes)]
|
||||
#[repr(C)]
|
||||
struct Header {
|
||||
tag: u8,
|
||||
len: big_endian::U32,
|
||||
}
|
||||
|
||||
let header = read!(stream => Header);
|
||||
|
||||
// as described above, the length must be at least 4.
|
||||
let Some(len) = (header.len.get() as usize).checked_sub(4) else {
|
||||
return Err(io::Error::other(format!(
|
||||
"invalid startup message length {}, must be at least 4.",
|
||||
header.len,
|
||||
)));
|
||||
};
|
||||
|
||||
// TODO: add a histogram for message lengths
|
||||
|
||||
// check if the message exceeds our desired max.
|
||||
if len > max {
|
||||
tracing::warn!("large postgres message detected: {len} bytes");
|
||||
return Err(io::Error::other(format!("invalid message length {len}")));
|
||||
}
|
||||
|
||||
// read in our entire message.
|
||||
buf.resize(len, 0);
|
||||
stream.read_exact(buf).await?;
|
||||
|
||||
Ok((header.tag, buf))
|
||||
}
|
||||
|
||||
pub struct WriteBuf(Cursor<Vec<u8>>);
|
||||
|
||||
impl Buf for WriteBuf {
|
||||
#[inline]
|
||||
fn remaining(&self) -> usize {
|
||||
self.0.remaining()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn chunk(&self) -> &[u8] {
|
||||
self.0.chunk()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn advance(&mut self, cnt: usize) {
|
||||
self.0.advance(cnt);
|
||||
}
|
||||
}
|
||||
|
||||
impl WriteBuf {
|
||||
pub const fn new() -> Self {
|
||||
Self(Cursor::new(Vec::new()))
|
||||
}
|
||||
|
||||
/// Use a heuristic to determine if we should shrink the write buffer.
|
||||
#[inline]
|
||||
fn should_shrink(&self) -> bool {
|
||||
let n = self.0.position() as usize;
|
||||
let len = self.0.get_ref().len();
|
||||
|
||||
// the unused space at the front of our buffer is 2x the size of our filled portion.
|
||||
n + n > len
|
||||
}
|
||||
|
||||
/// Shrink the write buffer so that subsequent writes have more spare capacity.
|
||||
#[cold]
|
||||
fn shrink(&mut self) {
|
||||
let n = self.0.position() as usize;
|
||||
let buf = self.0.get_mut();
|
||||
|
||||
// buf repr:
|
||||
// [----unused------|-----filled-----|-----uninit-----]
|
||||
// ^ n ^ buf.len() ^ buf.capacity()
|
||||
let filled = n..buf.len();
|
||||
let filled_len = filled.len();
|
||||
buf.copy_within(filled, 0);
|
||||
buf.truncate(filled_len);
|
||||
self.0.set_position(0);
|
||||
}
|
||||
|
||||
/// clear the write buffer.
|
||||
pub fn reset(&mut self) {
|
||||
let buf = self.0.get_mut();
|
||||
buf.clear();
|
||||
self.0.set_position(0);
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
///
|
||||
/// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
|
||||
/// we calculate the length after the fact.
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
if self.should_shrink() {
|
||||
self.shrink();
|
||||
}
|
||||
|
||||
let buf = self.0.get_mut();
|
||||
buf.reserve(5 + size_hint);
|
||||
|
||||
buf.push(tag);
|
||||
let start = buf.len();
|
||||
buf.extend_from_slice(&[0, 0, 0, 0]);
|
||||
|
||||
f(buf);
|
||||
|
||||
let end = buf.len();
|
||||
let len = (end - start) as u32;
|
||||
buf[start..start + 4].copy_from_slice(&len.to_be_bytes());
|
||||
}
|
||||
|
||||
/// Write an encryption response message.
|
||||
pub fn encryption(&mut self, m: u8) {
|
||||
self.0.get_mut().push(m);
|
||||
}
|
||||
|
||||
pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) {
|
||||
self.shrink();
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-ERRORRESPONSE>
|
||||
// <https://www.postgresql.org/docs/current/protocol-error-fields.html>
|
||||
// "SERROR\0CXXXXX\0M\0\0".len() == 17
|
||||
self.write_raw(17 + msg.len(), b'E', |buf| {
|
||||
// Severity: ERROR
|
||||
buf.put_slice(b"SERROR\0");
|
||||
|
||||
// Code: error_code
|
||||
buf.put_u8(b'C');
|
||||
buf.put_slice(&error_code);
|
||||
buf.put_u8(0);
|
||||
|
||||
// Message: msg
|
||||
buf.put_u8(b'M');
|
||||
buf.put_slice(msg.as_bytes());
|
||||
buf.put_u8(0);
|
||||
|
||||
// End.
|
||||
buf.put_u8(0);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FeStartupPacket {
|
||||
CancelRequest(CancelKeyData),
|
||||
SslRequest {
|
||||
direct: Option<[u8; 8]>,
|
||||
},
|
||||
GssEncRequest,
|
||||
StartupMessage {
|
||||
version: ProtocolVersion,
|
||||
params: StartupMessageParams,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct StartupMessageParams {
|
||||
pub params: String,
|
||||
}
|
||||
|
||||
impl StartupMessageParams {
|
||||
/// Get parameter's value by its name.
|
||||
pub fn get(&self, name: &str) -> Option<&str> {
|
||||
self.iter().find_map(|(k, v)| (k == name).then_some(v))
|
||||
}
|
||||
|
||||
/// Split command-line options according to PostgreSQL's logic,
|
||||
/// taking into account all escape sequences but leaving them as-is.
|
||||
/// [`None`] means that there's no `options` in [`Self`].
|
||||
pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
|
||||
self.get("options").map(Self::parse_options_raw)
|
||||
}
|
||||
|
||||
/// Split command-line options according to PostgreSQL's logic,
|
||||
/// taking into account all escape sequences but leaving them as-is.
|
||||
pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
|
||||
// See `postgres: pg_split_opts`.
|
||||
let mut last_was_escape = false;
|
||||
input
|
||||
.split(move |c: char| {
|
||||
// We split by non-escaped whitespace symbols.
|
||||
let should_split = c.is_ascii_whitespace() && !last_was_escape;
|
||||
last_was_escape = c == '\\' && !last_was_escape;
|
||||
should_split
|
||||
})
|
||||
.filter(|s| !s.is_empty())
|
||||
}
|
||||
|
||||
/// Iterate through key-value pairs in an arbitrary order.
|
||||
pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
|
||||
self.params.split_terminator('\0').tuples()
|
||||
}
|
||||
|
||||
// This function is mostly useful in tests.
|
||||
#[cfg(test)]
|
||||
pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
|
||||
let mut b = Self {
|
||||
params: String::new(),
|
||||
};
|
||||
for (k, v) in pairs {
|
||||
b.insert(k, v);
|
||||
}
|
||||
b
|
||||
}
|
||||
|
||||
/// Set parameter's value by its name.
|
||||
/// name and value must not contain a \0 byte
|
||||
pub fn insert(&mut self, name: &str, value: &str) {
|
||||
self.params.reserve(name.len() + value.len() + 2);
|
||||
self.params.push_str(name);
|
||||
self.params.push('\0');
|
||||
self.params.push_str(value);
|
||||
self.params.push('\0');
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just
|
||||
/// opaque bytes.
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)]
|
||||
pub struct CancelKeyData(pub big_endian::U64);
|
||||
|
||||
pub fn id_to_cancel_key(id: u64) -> CancelKeyData {
|
||||
CancelKeyData(big_endian::U64::new(id))
|
||||
}
|
||||
|
||||
impl fmt::Display for CancelKeyData {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let id = self.0;
|
||||
f.debug_tuple("CancelKeyData")
|
||||
.field(&format_args!("{id:x}"))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
impl Distribution<CancelKeyData> for Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
|
||||
id_to_cancel_key(rng.r#gen())
|
||||
}
|
||||
}
|
||||
|
||||
pub enum BeMessage<'a> {
|
||||
AuthenticationOk,
|
||||
AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
|
||||
AuthenticationCleartextPassword,
|
||||
BackendKeyData(CancelKeyData),
|
||||
ParameterStatus {
|
||||
name: &'a [u8],
|
||||
value: &'a [u8],
|
||||
},
|
||||
ReadyForQuery,
|
||||
NoticeResponse(&'a str),
|
||||
NegotiateProtocolVersion {
|
||||
version: ProtocolVersion,
|
||||
options: &'a [&'a str],
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum BeAuthenticationSaslMessage<'a> {
|
||||
Methods(&'a [&'a str]),
|
||||
Continue(&'a [u8]),
|
||||
Final(&'a [u8]),
|
||||
}
|
||||
|
||||
impl BeMessage<'_> {
|
||||
/// Write the message into an internal buffer
|
||||
pub fn write_message(self, buf: &mut WriteBuf) {
|
||||
match self {
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
|
||||
BeMessage::AuthenticationOk => {
|
||||
buf.write_raw(1, b'R', |buf| buf.put_i32(0));
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
|
||||
BeMessage::AuthenticationCleartextPassword => {
|
||||
buf.write_raw(1, b'R', |buf| buf.put_i32(3));
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => {
|
||||
let len: usize = methods.iter().map(|m| m.len() + 1).sum();
|
||||
buf.write_raw(len + 2, b'R', |buf| {
|
||||
buf.put_i32(10); // Specifies that SASL auth method is used.
|
||||
for method in methods {
|
||||
buf.put_slice(method.as_bytes());
|
||||
buf.put_u8(0);
|
||||
}
|
||||
buf.put_u8(0); // zero terminator for the list
|
||||
});
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => {
|
||||
buf.write_raw(extra.len() + 1, b'R', |buf| {
|
||||
buf.put_i32(11); // Continue SASL auth.
|
||||
buf.put_slice(extra);
|
||||
});
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => {
|
||||
buf.write_raw(extra.len() + 1, b'R', |buf| {
|
||||
buf.put_i32(12); // Send final SASL message.
|
||||
buf.put_slice(extra);
|
||||
});
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
|
||||
BeMessage::BackendKeyData(key_data) => {
|
||||
buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes()));
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
|
||||
// <https://www.postgresql.org/docs/current/protocol-error-fields.html>
|
||||
BeMessage::NoticeResponse(msg) => {
|
||||
// 'N' signalizes NoticeResponse messages
|
||||
buf.write_raw(18 + msg.len(), b'N', |buf| {
|
||||
// Severity: NOTICE
|
||||
buf.put_slice(b"SNOTICE\0");
|
||||
|
||||
// Code: XX000 (ignored for notice, but still required)
|
||||
buf.put_slice(b"CXX000\0");
|
||||
|
||||
// Message: msg
|
||||
buf.put_u8(b'M');
|
||||
buf.put_slice(msg.as_bytes());
|
||||
buf.put_u8(0);
|
||||
|
||||
// End notice.
|
||||
buf.put_u8(0);
|
||||
});
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-PARAMETERSTATUS>
|
||||
BeMessage::ParameterStatus { name, value } => {
|
||||
buf.write_raw(name.len() + value.len() + 2, b'S', |buf| {
|
||||
buf.put_slice(name.as_bytes());
|
||||
buf.put_u8(0);
|
||||
buf.put_slice(value.as_bytes());
|
||||
buf.put_u8(0);
|
||||
});
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
|
||||
BeMessage::ReadyForQuery => {
|
||||
buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I'));
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
|
||||
BeMessage::NegotiateProtocolVersion { version, options } => {
|
||||
let len: usize = options.iter().map(|o| o.len() + 1).sum();
|
||||
buf.write_raw(8 + len, b'v', |buf| {
|
||||
buf.put_slice(version.as_bytes());
|
||||
buf.put_u32(options.len() as u32);
|
||||
for option in options {
|
||||
buf.put_slice(option.as_bytes());
|
||||
buf.put_u8(0);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::io::Cursor;
|
||||
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use zerocopy::IntoBytes;
|
||||
|
||||
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
|
||||
|
||||
use super::ProtocolVersion;
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_large_startup() {
|
||||
// we're going to define a v3.0 startup message with far too many parameters.
|
||||
let mut payload = vec![];
|
||||
// 10001 + 8 bytes.
|
||||
payload.extend_from_slice(&10009_u32.to_be_bytes());
|
||||
payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
|
||||
payload.resize(10009, b'a');
|
||||
|
||||
let (mut server, mut client) = duplex(128);
|
||||
#[rustfmt::skip]
|
||||
let (server, client) = tokio::join!(
|
||||
async move { read_startup(&mut server).await.unwrap_err() },
|
||||
async move { client.write_all(&payload).await.unwrap_err() },
|
||||
);
|
||||
|
||||
assert_eq!(server.to_string(), "invalid startup message length 10001");
|
||||
assert_eq!(client.to_string(), "broken pipe");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_large_password() {
|
||||
// we're going to define a password message that is far too long.
|
||||
let mut payload = vec![];
|
||||
payload.push(b'p');
|
||||
payload.extend_from_slice(&517_u32.to_be_bytes());
|
||||
payload.resize(518, b'a');
|
||||
|
||||
let (mut server, mut client) = duplex(128);
|
||||
#[rustfmt::skip]
|
||||
let (server, client) = tokio::join!(
|
||||
async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() },
|
||||
async move { client.write_all(&payload).await.unwrap_err() },
|
||||
);
|
||||
|
||||
assert_eq!(server.to_string(), "invalid message length 513");
|
||||
assert_eq!(client.to_string(), "broken pipe");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_startup_message() {
|
||||
let mut payload = vec![];
|
||||
payload.extend_from_slice(&17_u32.to_be_bytes());
|
||||
payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
|
||||
payload.extend_from_slice(b"abc\0def\0\0");
|
||||
|
||||
let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
|
||||
let FeStartupPacket::StartupMessage { version, params } = startup else {
|
||||
panic!("unexpected startup message: {startup:?}");
|
||||
};
|
||||
|
||||
assert_eq!(version.major(), 3);
|
||||
assert_eq!(version.minor(), 0);
|
||||
assert_eq!(params.params, "abc\0def\0");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_ssl_message() {
|
||||
let mut payload = vec![];
|
||||
payload.extend_from_slice(&8_u32.to_be_bytes());
|
||||
payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes());
|
||||
|
||||
let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
|
||||
let FeStartupPacket::SslRequest { direct: None } = startup else {
|
||||
panic!("unexpected startup message: {startup:?}");
|
||||
};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_tls_message() {
|
||||
// sample client hello taken from <https://tls13.xargs.org/#client-hello>
|
||||
let client_hello = [
|
||||
0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02,
|
||||
0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
|
||||
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e,
|
||||
0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb,
|
||||
0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9,
|
||||
0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01,
|
||||
0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00,
|
||||
0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65,
|
||||
0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02,
|
||||
0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19,
|
||||
0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23,
|
||||
0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e,
|
||||
0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09,
|
||||
0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01,
|
||||
0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01,
|
||||
0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72,
|
||||
0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38,
|
||||
0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62,
|
||||
0x54,
|
||||
];
|
||||
|
||||
let mut cursor = Cursor::new(&client_hello);
|
||||
|
||||
let startup = read_startup(&mut cursor).await.unwrap();
|
||||
let FeStartupPacket::SslRequest {
|
||||
direct: Some(prefix),
|
||||
} = startup
|
||||
else {
|
||||
panic!("unexpected startup message: {startup:?}");
|
||||
};
|
||||
|
||||
// check that no data is lost.
|
||||
assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]);
|
||||
assert_eq!(cursor.position(), 8);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_message_success() {
|
||||
let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2";
|
||||
let mut cursor = Cursor::new(&query);
|
||||
|
||||
let mut buf = vec![];
|
||||
let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
|
||||
assert_eq!(tag, b'Q');
|
||||
assert_eq!(message, b"SELECT 1");
|
||||
|
||||
let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
|
||||
assert_eq!(tag, b'Q');
|
||||
assert_eq!(message, b"SELECT 2");
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
use async_trait::async_trait;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use tokio::time;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
@@ -15,6 +14,7 @@ use crate::error::ReportableError;
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
|
||||
};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::retry::{CouldRetry, retry_after, should_retry};
|
||||
use crate::proxy::wake_compute::wake_compute;
|
||||
use crate::types::Host;
|
||||
|
||||
@@ -1,8 +1,3 @@
|
||||
use bytes::Buf;
|
||||
use pq_proto::framed::Framed;
|
||||
use pq_proto::{
|
||||
BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{debug, info, warn};
|
||||
@@ -12,7 +7,10 @@ use crate::config::TlsConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::ERR_INSECURE_CONNECTION;
|
||||
use crate::pqproto::{
|
||||
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
|
||||
};
|
||||
use crate::proxy::TlsRequired;
|
||||
use crate::stream::{PqStream, Stream, StreamUpgradeError};
|
||||
use crate::tls::PG_ALPN_PROTOCOL;
|
||||
|
||||
@@ -71,33 +69,25 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?;
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
match msg {
|
||||
FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_ssl => {
|
||||
tried_ssl = true;
|
||||
|
||||
// We can't perform TLS handshake without a config
|
||||
let have_tls = tls.is_some();
|
||||
if !direct {
|
||||
stream
|
||||
.write_message(&Be::EncryptionResponse(have_tls))
|
||||
.await?;
|
||||
} else if !have_tls {
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
|
||||
if let Some(tls) = tls.take() {
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let Framed {
|
||||
stream: raw,
|
||||
read_buf,
|
||||
write_buf,
|
||||
} = stream.framed;
|
||||
let mut read_buf;
|
||||
let raw = if let Some(direct) = &direct {
|
||||
read_buf = &direct[..];
|
||||
stream.accept_direct_tls()
|
||||
} else {
|
||||
read_buf = &[];
|
||||
stream.accept_tls().await?
|
||||
};
|
||||
|
||||
let Stream::Raw { raw } = raw else {
|
||||
return Err(HandshakeError::StreamUpgradeError(
|
||||
@@ -105,12 +95,11 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
));
|
||||
};
|
||||
|
||||
let mut read_buf = read_buf.reader();
|
||||
let mut res = Ok(());
|
||||
let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone())
|
||||
.accept_with(raw, |session| {
|
||||
// push the early data to the tls session
|
||||
while !read_buf.get_ref().is_empty() {
|
||||
while !read_buf.is_empty() {
|
||||
match session.read_tls(&mut read_buf) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
@@ -123,7 +112,6 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
res?;
|
||||
|
||||
let read_buf = read_buf.into_inner();
|
||||
if !read_buf.is_empty() {
|
||||
return Err(HandshakeError::EarlyData);
|
||||
}
|
||||
@@ -157,16 +145,17 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let (_, tls_server_end_point) =
|
||||
tls.cert_resolver.resolve(conn_info.server_name());
|
||||
|
||||
stream = PqStream {
|
||||
framed: Framed {
|
||||
stream: Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
},
|
||||
read_buf,
|
||||
write_buf,
|
||||
},
|
||||
let tls = Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
};
|
||||
(stream, msg) = PqStream::parse_startup(tls).await?;
|
||||
} else {
|
||||
if direct.is_some() {
|
||||
// client sent us a ClientHello already, we can't do anything with it.
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
msg = stream.reject_encryption().await?;
|
||||
}
|
||||
}
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
@@ -176,7 +165,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
tried_gss = true;
|
||||
|
||||
// Currently, we don't support GSSAPI
|
||||
stream.write_message(&Be::EncryptionResponse(false)).await?;
|
||||
msg = stream.reject_encryption().await?;
|
||||
}
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
},
|
||||
@@ -186,13 +175,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// Check that the config has been consumed during upgrade
|
||||
// OR we didn't provide it at all (for dev purposes).
|
||||
if tls.is_some() {
|
||||
return stream
|
||||
.throw_error_str(
|
||||
ERR_INSECURE_CONNECTION,
|
||||
crate::error::ErrorKind::User,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
Err(stream.throw_error(TlsRequired, None).await)?;
|
||||
}
|
||||
|
||||
// This log highlights the start of the connection.
|
||||
@@ -214,20 +197,21 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// no protocol extensions are supported.
|
||||
// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
|
||||
let mut unsupported = vec![];
|
||||
for (k, _) in params.iter() {
|
||||
let mut supported = StartupMessageParams::default();
|
||||
|
||||
for (k, v) in params.iter() {
|
||||
if k.starts_with("_pq_.") {
|
||||
unsupported.push(k);
|
||||
} else {
|
||||
supported.insert(k, v);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: remove unsupported options so we don't send them to compute.
|
||||
|
||||
stream
|
||||
.write_message(&Be::NegotiateProtocolVersion {
|
||||
version: PG_PROTOCOL_LATEST,
|
||||
options: &unsupported,
|
||||
})
|
||||
.await?;
|
||||
stream.write_message(BeMessage::NegotiateProtocolVersion {
|
||||
version: PG_PROTOCOL_LATEST,
|
||||
options: &unsupported,
|
||||
});
|
||||
stream.flush().await?;
|
||||
|
||||
info!(
|
||||
?version,
|
||||
@@ -235,7 +219,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
session_type = "normal",
|
||||
"successful handshake; unsupported minor version requested"
|
||||
);
|
||||
break Ok(HandshakeData::Startup(stream, params));
|
||||
break Ok(HandshakeData::Startup(stream, supported));
|
||||
}
|
||||
FeStartupPacket::StartupMessage { version, params } => {
|
||||
warn!(
|
||||
|
||||
@@ -10,15 +10,14 @@ pub(crate) mod wake_compute;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use futures::FutureExt;
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
|
||||
@@ -27,8 +26,9 @@ use self::passthrough::ProxyPassthrough;
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::handshake::{HandshakeData, handshake};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
@@ -38,6 +38,18 @@ use crate::{auth, compute};
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("{ERR_INSECURE_CONNECTION}")]
|
||||
pub struct TlsRequired;
|
||||
|
||||
impl ReportableError for TlsRequired {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for TlsRequired {}
|
||||
|
||||
pub async fn run_until_cancelled<F: std::future::Future>(
|
||||
f: F,
|
||||
cancellation_token: &CancellationToken,
|
||||
@@ -329,7 +341,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => stream.throw_error(e, Some(ctx)).await?,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
@@ -349,10 +361,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
|
||||
return stream
|
||||
return Err(stream
|
||||
.throw_error(e, Some(ctx))
|
||||
.instrument(params_span)
|
||||
.await?;
|
||||
.await)?;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -365,7 +377,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
.get(NeonOptions::PARAMS_COMPAT)
|
||||
.is_some();
|
||||
|
||||
let mut node = connect_to_compute(
|
||||
let res = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info: compute_user_info.clone(),
|
||||
@@ -377,22 +389,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e, Some(ctx)))
|
||||
.await?;
|
||||
.await;
|
||||
|
||||
let node = match res {
|
||||
Ok(node) => node,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
@@ -413,31 +422,28 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn prepare_client_connection(
|
||||
pub(crate) fn prepare_client_connection(
|
||||
node: &compute::PostgresConnection,
|
||||
cancel_key_data: CancelKeyData,
|
||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
) {
|
||||
// Forward all deferred notices to the client.
|
||||
for notice in &node.delayed_notice {
|
||||
stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?;
|
||||
stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
|
||||
buf.extend_from_slice(notice.as_bytes());
|
||||
});
|
||||
}
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
for (name, value) in &node.params {
|
||||
stream.write_message_noflush(&Be::ParameterStatus {
|
||||
stream.write_message(BeMessage::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
})?;
|
||||
});
|
||||
}
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
|
||||
.write_message(&Be::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
|
||||
stream.write_message(BeMessage::ReadyForQuery);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
|
||||
@@ -125,9 +125,10 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ShouldRetryWakeCompute;
|
||||
use postgres_client::error::{DbError, SqlState};
|
||||
|
||||
use super::ShouldRetryWakeCompute;
|
||||
|
||||
#[test]
|
||||
fn should_retry_wake_compute_for_db_error() {
|
||||
// These SQLStates should NOT trigger a wake_compute retry.
|
||||
|
||||
@@ -10,7 +10,7 @@ use bytes::{Bytes, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use postgres_client::tls::TlsConnect;
|
||||
use postgres_protocol::message::frontend;
|
||||
use tokio::io::{AsyncReadExt, DuplexStream};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
use super::*;
|
||||
@@ -49,15 +49,14 @@ async fn proxy_mitm(
|
||||
};
|
||||
|
||||
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
|
||||
let (end_client, buf) = end_client.framed.into_inner();
|
||||
assert!(buf.is_empty());
|
||||
let end_client = end_client.flush_and_into_inner().await.unwrap();
|
||||
let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
|
||||
|
||||
// give the end_server the startup parameters
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(
|
||||
&postgres_protocol::message::frontend::StartupMessageParams {
|
||||
params: startup.params.into(),
|
||||
params: startup.params.as_bytes().into(),
|
||||
},
|
||||
&mut buf,
|
||||
)
|
||||
|
||||
@@ -128,7 +128,7 @@ trait TestAuth: Sized {
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
stream.write_message(BeMessage::AuthenticationOk);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -157,9 +157,7 @@ impl TestAuth for Scram {
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let outcome = auth::AuthFlow::new(stream)
|
||||
.begin(auth::Scram(&self.0, &RequestContext::test()))
|
||||
.await?
|
||||
let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test()))
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
@@ -185,10 +183,12 @@ async fn dummy_proxy(
|
||||
|
||||
auth.authenticate(&mut stream).await?;
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
||||
.write_message(&Be::ReadyForQuery)
|
||||
.await?;
|
||||
stream.write_message(BeMessage::ParameterStatus {
|
||||
name: b"client_encoding",
|
||||
value: b"UTF8",
|
||||
});
|
||||
stream.write_message(BeMessage::ReadyForQuery);
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use core::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use pq_proto::CancelKeyData;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::pqproto::CancelKeyData;
|
||||
|
||||
pub trait CancellationPublisherMut: Send + Sync + 'static {
|
||||
#[allow(async_fn_in_trait)]
|
||||
async fn try_publish(
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
use std::io::ErrorKind;
|
||||
|
||||
use anyhow::Ok;
|
||||
use pq_proto::{CancelKeyData, id_to_cancel_key};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::pqproto::{CancelKeyData, id_to_cancel_key};
|
||||
|
||||
pub mod keyspace {
|
||||
pub const CANCEL_PREFIX: &str = "cancel";
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub(crate) enum KeyPrefix {
|
||||
#[serde(untagged)]
|
||||
Cancel(CancelKeyData),
|
||||
}
|
||||
|
||||
@@ -18,9 +17,7 @@ impl KeyPrefix {
|
||||
pub(crate) fn build_redis_key(&self) -> String {
|
||||
match self {
|
||||
KeyPrefix::Cancel(key) => {
|
||||
let hi = (key.backend_pid as u64) << 32;
|
||||
let lo = (key.cancel_key as u64) & 0xffff_ffff;
|
||||
let id = hi | lo;
|
||||
let id = key.0.get();
|
||||
let keyspace = keyspace::CANCEL_PREFIX;
|
||||
format!("{keyspace}:{id:x}")
|
||||
}
|
||||
@@ -63,10 +60,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_build_redis_key() {
|
||||
let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData {
|
||||
backend_pid: 12345,
|
||||
cancel_key: 54321,
|
||||
});
|
||||
let cancel_key: KeyPrefix = KeyPrefix::Cancel(id_to_cancel_key(12345 << 32 | 54321));
|
||||
|
||||
let redis_key = cancel_key.build_redis_key();
|
||||
assert_eq!(redis_key, "cancel:30390000d431");
|
||||
@@ -77,10 +71,7 @@ mod tests {
|
||||
let redis_key = "cancel:30390000d431";
|
||||
let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key");
|
||||
|
||||
let ref_key = CancelKeyData {
|
||||
backend_pid: 12345,
|
||||
cancel_key: 54321,
|
||||
};
|
||||
let ref_key = id_to_cancel_key(12345 << 32 | 54321);
|
||||
|
||||
assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str());
|
||||
let KeyPrefix::Cancel(cancel_key) = key;
|
||||
|
||||
@@ -2,11 +2,9 @@ use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
use pq_proto::CancelKeyData;
|
||||
use redis::aio::PubSub;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::cache::project_info::ProjectInfoCache;
|
||||
@@ -100,14 +98,6 @@ pub(crate) struct PasswordUpdate {
|
||||
role_name: RoleNameInt,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct CancelSession {
|
||||
pub(crate) region_id: Option<String>,
|
||||
pub(crate) cancel_key_data: CancelKeyData,
|
||||
pub(crate) session_id: Uuid,
|
||||
pub(crate) peer_addr: Option<std::net::IpAddr>,
|
||||
}
|
||||
|
||||
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
T: for<'de2> serde::Deserialize<'de2>,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
//! Definitions for SASL messages.
|
||||
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
|
||||
use crate::parse::split_cstr;
|
||||
|
||||
/// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
|
||||
@@ -30,26 +28,6 @@ impl<'a> FirstMessage<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A single SASL message.
|
||||
/// This struct is deliberately decoupled from lower-level
|
||||
/// [`BeAuthenticationSaslMessage`].
|
||||
#[derive(Debug)]
|
||||
pub(super) enum ServerMessage<T> {
|
||||
/// We expect to see more steps.
|
||||
Continue(T),
|
||||
/// This is the final step.
|
||||
Final(T),
|
||||
}
|
||||
|
||||
impl<'a> ServerMessage<&'a str> {
|
||||
pub(super) fn to_reply(&self) -> BeMessage<'a> {
|
||||
BeMessage::AuthenticationSasl(match self {
|
||||
ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()),
|
||||
ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -14,7 +14,7 @@ use std::io;
|
||||
|
||||
pub(crate) use channel_binding::ChannelBinding;
|
||||
pub(crate) use messages::FirstMessage;
|
||||
pub(crate) use stream::{Outcome, SaslStream};
|
||||
pub(crate) use stream::{Outcome, authenticate};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
@@ -22,6 +22,9 @@ use crate::error::{ReportableError, UserFacingError};
|
||||
/// Fine-grained auth errors help in writing tests.
|
||||
#[derive(Error, Debug)]
|
||||
pub(crate) enum Error {
|
||||
#[error("Unsupported authentication method: {0}")]
|
||||
BadAuthMethod(Box<str>),
|
||||
|
||||
#[error("Channel binding failed: {0}")]
|
||||
ChannelBindingFailed(&'static str),
|
||||
|
||||
@@ -54,6 +57,7 @@ impl UserFacingError for Error {
|
||||
impl ReportableError for Error {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Error::BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
|
||||
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
|
||||
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
|
||||
|
||||
@@ -3,61 +3,12 @@
|
||||
use std::io;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
use super::Mechanism;
|
||||
use super::messages::ServerMessage;
|
||||
use super::{Mechanism, Step};
|
||||
use crate::context::RequestContext;
|
||||
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
use crate::stream::PqStream;
|
||||
|
||||
/// Abstracts away all peculiarities of the libpq's protocol.
|
||||
pub(crate) struct SaslStream<'a, S> {
|
||||
/// The underlying stream.
|
||||
stream: &'a mut PqStream<S>,
|
||||
/// Current password message we received from client.
|
||||
current: bytes::Bytes,
|
||||
/// First SASL message produced by client.
|
||||
first: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a, S> SaslStream<'a, S> {
|
||||
pub(crate) fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
current: bytes::Bytes::new(),
|
||||
first: Some(first),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
|
||||
// Receive a new SASL message from the client.
|
||||
async fn recv(&mut self) -> io::Result<&str> {
|
||||
if let Some(first) = self.first.take() {
|
||||
return Ok(first);
|
||||
}
|
||||
|
||||
self.current = self.stream.read_password_message().await?;
|
||||
let s = std::str::from_utf8(&self.current)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
|
||||
Ok(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
// Send a SASL message to the client.
|
||||
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
|
||||
self.stream.write_message(&msg.to_reply()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Queue a SASL message for the client.
|
||||
fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
|
||||
self.stream.write_message_noflush(&msg.to_reply())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// SASL authentication outcome.
|
||||
/// It's much easier to match on those two variants
|
||||
/// than to peek into a noisy protocol error type.
|
||||
@@ -69,33 +20,62 @@ pub(crate) enum Outcome<R> {
|
||||
Failure(&'static str),
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
/// Perform SASL message exchange according to the underlying algorithm
|
||||
/// until user is either authenticated or denied access.
|
||||
pub(crate) async fn authenticate<M: Mechanism>(
|
||||
mut self,
|
||||
mut mechanism: M,
|
||||
) -> super::Result<Outcome<M::Output>> {
|
||||
loop {
|
||||
let input = self.recv().await?;
|
||||
let step = mechanism.exchange(input).map_err(|error| {
|
||||
info!(?error, "error during SASL exchange");
|
||||
error
|
||||
})?;
|
||||
pub async fn authenticate<S, F, M>(
|
||||
ctx: &RequestContext,
|
||||
stream: &mut PqStream<S>,
|
||||
mechanism: F,
|
||||
) -> super::Result<Outcome<M::Output>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
F: FnOnce(&str) -> super::Result<M>,
|
||||
M: Mechanism,
|
||||
{
|
||||
let sasl = {
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
use super::Step;
|
||||
return Ok(match step {
|
||||
Step::Continue(moved_mechanism, reply) => {
|
||||
self.send(&ServerMessage::Continue(&reply)).await?;
|
||||
mechanism = moved_mechanism;
|
||||
continue;
|
||||
}
|
||||
Step::Success(result, reply) => {
|
||||
self.send_noflush(&ServerMessage::Final(&reply))?;
|
||||
Outcome::Success(result)
|
||||
}
|
||||
Step::Failure(reason) => Outcome::Failure(reason),
|
||||
});
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = stream.read_password_message().await?;
|
||||
super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))?
|
||||
};
|
||||
|
||||
let mut mechanism = mechanism(sasl.method)?;
|
||||
let mut input = sasl.message;
|
||||
loop {
|
||||
let step = mechanism
|
||||
.exchange(input)
|
||||
.inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?;
|
||||
|
||||
match step {
|
||||
Step::Continue(moved_mechanism, reply) => {
|
||||
mechanism = moved_mechanism;
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// write reply
|
||||
let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes());
|
||||
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
|
||||
|
||||
// get next input
|
||||
stream.flush().await?;
|
||||
let msg = stream.read_password_message().await?;
|
||||
input = std::str::from_utf8(msg)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
}
|
||||
Step::Success(result, reply) => {
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// write reply
|
||||
let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes());
|
||||
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
|
||||
stream.write_message(BeMessage::AuthenticationOk);
|
||||
// exit with success
|
||||
break Ok(Outcome::Success(result));
|
||||
}
|
||||
// exit with failure
|
||||
Step::Failure(reason) => break Ok(Outcome::Failure(reason)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ use postgres_client::error::{DbError, ErrorPosition, SqlState};
|
||||
use postgres_client::{
|
||||
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
|
||||
};
|
||||
use pq_proto::StartupMessageParamsBuilder;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::value::RawValue;
|
||||
@@ -41,6 +40,7 @@ use crate::context::RequestContext;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::{ReadBodyError, read_body_with_limit};
|
||||
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::{NeonOptions, run_until_cancelled};
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::types::{DbName, RoleName};
|
||||
@@ -219,7 +219,7 @@ fn get_conn_info(
|
||||
|
||||
let mut options = Option::None;
|
||||
|
||||
let mut params = StartupMessageParamsBuilder::default();
|
||||
let mut params = StartupMessageParams::default();
|
||||
params.insert("user", &username);
|
||||
params.insert("database", &dbname);
|
||||
for (key, value) in pairs {
|
||||
|
||||
@@ -2,19 +2,17 @@ use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{io, task};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use pq_proto::framed::{ConnectionError, Framed};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use rustls::ServerConfig;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::control_plane::messages::ColdStartInfo;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::pqproto::{
|
||||
BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf,
|
||||
read_message, read_startup,
|
||||
};
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
@@ -23,58 +21,77 @@ use crate::tls::TlsServerEndPoint;
|
||||
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
pub(crate) framed: Framed<S>,
|
||||
stream: S,
|
||||
read: Vec<u8>,
|
||||
write: WriteBuf,
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper.
|
||||
pub fn new(stream: S) -> Self {
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
}
|
||||
|
||||
/// Construct a new libpq protocol wrapper over a stream without the first startup message.
|
||||
#[cfg(test)]
|
||||
pub fn new_skip_handshake(stream: S) -> Self {
|
||||
Self {
|
||||
framed: Framed::new(stream),
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the underlying stream and read buffer.
|
||||
pub fn into_inner(self) -> (S, BytesMut) {
|
||||
self.framed.into_inner()
|
||||
}
|
||||
|
||||
/// Get a shared reference to the underlying stream.
|
||||
pub(crate) fn get_ref(&self) -> &S {
|
||||
self.framed.get_ref()
|
||||
}
|
||||
}
|
||||
|
||||
fn err_connection() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper and read the first startup message.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
|
||||
let startup = read_startup(&mut stream).await?;
|
||||
Ok((
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
},
|
||||
startup,
|
||||
))
|
||||
}
|
||||
|
||||
/// Tell the client that encryption is not supported.
|
||||
///
|
||||
/// This is not cancel safe
|
||||
pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
|
||||
// N for No.
|
||||
self.write.encryption(b'N');
|
||||
self.flush().await?;
|
||||
read_startup(&mut self.stream).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
|
||||
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
|
||||
self.framed
|
||||
.read_startup_message()
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
|
||||
async fn read_message(&mut self) -> io::Result<FeMessage> {
|
||||
self.framed
|
||||
.read_message()
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
|
||||
pub(crate) async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
|
||||
match self.read_message().await? {
|
||||
FeMessage::PasswordMessage(msg) => Ok(msg),
|
||||
bad => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("unexpected message type: {bad:?}"),
|
||||
)),
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> {
|
||||
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
|
||||
if actual_tag != tag {
|
||||
return Err(io::Error::other(format!(
|
||||
"incorrect message tag, expected {:?}, got {:?}",
|
||||
tag as char, actual_tag as char,
|
||||
)));
|
||||
}
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
/// Read a postgres password message, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
|
||||
// passwords are usually pretty short
|
||||
// and SASL SCRAM messages are no longer than 256 bytes in my testing
|
||||
// (a few hashes and random bytes, encoded into base64).
|
||||
const MAX_PASSWORD_LENGTH: usize = 512;
|
||||
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,6 +101,16 @@ pub struct ReportedError {
|
||||
error_kind: ErrorKind,
|
||||
}
|
||||
|
||||
impl ReportedError {
|
||||
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
|
||||
let error_kind = e.get_error_kind();
|
||||
Self {
|
||||
source: e.into(),
|
||||
error_kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ReportedError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.source.fmt(f)
|
||||
@@ -102,109 +129,65 @@ impl ReportableError for ReportedError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
enum ErrorTag {
|
||||
#[serde(rename = "proxy")]
|
||||
Proxy,
|
||||
#[serde(rename = "compute")]
|
||||
Compute,
|
||||
#[serde(rename = "client")]
|
||||
Client,
|
||||
#[serde(rename = "controlplane")]
|
||||
ControlPlane,
|
||||
#[serde(rename = "other")]
|
||||
Other,
|
||||
}
|
||||
|
||||
impl From<ErrorKind> for ErrorTag {
|
||||
fn from(error_kind: ErrorKind) -> Self {
|
||||
match error_kind {
|
||||
ErrorKind::User => Self::Client,
|
||||
ErrorKind::ClientDisconnect => Self::Client,
|
||||
ErrorKind::RateLimit => Self::Proxy,
|
||||
ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI
|
||||
ErrorKind::Quota => Self::Proxy,
|
||||
ErrorKind::Service => Self::Proxy,
|
||||
ErrorKind::ControlPlane => Self::ControlPlane,
|
||||
ErrorKind::Postgres => Self::Other,
|
||||
ErrorKind::Compute => Self::Compute,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
struct ProbeErrorData {
|
||||
tag: ErrorTag,
|
||||
msg: String,
|
||||
cold_start_info: Option<ColdStartInfo>,
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Write the message into an internal buffer, but don't flush the underlying stream.
|
||||
pub(crate) fn write_message_noflush(
|
||||
&mut self,
|
||||
message: &BeMessage<'_>,
|
||||
) -> io::Result<&mut Self> {
|
||||
self.framed
|
||||
.write_message(message)
|
||||
.map_err(ProtocolError::into_io_error)?;
|
||||
Ok(self)
|
||||
/// Tell the client that we are willing to accept SSL.
|
||||
/// This is not cancel safe
|
||||
pub async fn accept_tls(mut self) -> io::Result<S> {
|
||||
// S for SSL.
|
||||
self.write.encryption(b'S');
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer and flush it.
|
||||
pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
self.write_message_noflush(message)?;
|
||||
self.flush().await?;
|
||||
Ok(self)
|
||||
/// Assert that we are using direct TLS.
|
||||
pub fn accept_direct_tls(self) -> S {
|
||||
self.stream
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
self.write.write_raw(size_hint, tag, f);
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer
|
||||
pub fn write_message(&mut self, message: BeMessage<'_>) {
|
||||
message.write_message(&mut self.write);
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> {
|
||||
self.framed.flush().await?;
|
||||
Ok(self)
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
self.stream.write_all_buf(&mut self.write).await?;
|
||||
self.write.reset();
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Writes message with the given error kind to the stream.
|
||||
/// Used only for probe queries
|
||||
async fn write_format_message(
|
||||
&mut self,
|
||||
msg: &str,
|
||||
error_kind: ErrorKind,
|
||||
ctx: Option<&crate::context::RequestContext>,
|
||||
) -> String {
|
||||
let formatted_msg = match ctx {
|
||||
Some(ctx) if ctx.get_testodrome_id().is_some() => {
|
||||
serde_json::to_string(&ProbeErrorData {
|
||||
tag: ErrorTag::from(error_kind),
|
||||
msg: msg.to_string(),
|
||||
cold_start_info: Some(ctx.cold_start_info()),
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
_ => msg.to_string(),
|
||||
};
|
||||
|
||||
// already error case, ignore client IO error
|
||||
self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None))
|
||||
.await
|
||||
.inspect_err(|e| debug!("write_message failed: {e}"))
|
||||
.ok();
|
||||
|
||||
formatted_msg
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Write the error message using [`Self::write_format_message`], then re-throw it.
|
||||
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
|
||||
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
|
||||
/// Write the error message to the client, then re-throw it.
|
||||
///
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
|
||||
pub async fn throw_error_str<T>(
|
||||
pub(crate) async fn throw_error<E>(
|
||||
&mut self,
|
||||
msg: &'static str,
|
||||
error_kind: ErrorKind,
|
||||
error: E,
|
||||
ctx: Option<&crate::context::RequestContext>,
|
||||
) -> Result<T, ReportedError> {
|
||||
self.write_format_message(msg, error_kind, ctx).await;
|
||||
) -> ReportedError
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
{
|
||||
let error_kind = error.get_error_kind();
|
||||
let msg = error.to_string_client();
|
||||
|
||||
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
|
||||
tracing::info!(
|
||||
@@ -214,39 +197,39 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
);
|
||||
}
|
||||
|
||||
Err(ReportedError {
|
||||
source: anyhow::anyhow!(msg),
|
||||
error_kind,
|
||||
})
|
||||
}
|
||||
|
||||
/// Write the error message using [`Self::write_format_message`], then re-throw it.
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
|
||||
pub(crate) async fn throw_error<T, E>(
|
||||
&mut self,
|
||||
error: E,
|
||||
ctx: Option<&crate::context::RequestContext>,
|
||||
) -> Result<T, ReportedError>
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
{
|
||||
let error_kind = error.get_error_kind();
|
||||
let msg = error.to_string_client();
|
||||
self.write_format_message(&msg, error_kind, ctx).await;
|
||||
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
|
||||
tracing::info!(
|
||||
kind=error_kind.to_metric_label(),
|
||||
error=%error,
|
||||
msg,
|
||||
"forwarding error to user",
|
||||
);
|
||||
let probe_msg;
|
||||
let mut msg = &*msg;
|
||||
if let Some(ctx) = ctx {
|
||||
if ctx.get_testodrome_id().is_some() {
|
||||
let tag = match error_kind {
|
||||
ErrorKind::User => "client",
|
||||
ErrorKind::ClientDisconnect => "client",
|
||||
ErrorKind::RateLimit => "proxy",
|
||||
ErrorKind::ServiceRateLimit => "proxy",
|
||||
ErrorKind::Quota => "proxy",
|
||||
ErrorKind::Service => "proxy",
|
||||
ErrorKind::ControlPlane => "controlplane",
|
||||
ErrorKind::Postgres => "other",
|
||||
ErrorKind::Compute => "compute",
|
||||
};
|
||||
probe_msg = typed_json::json!({
|
||||
"tag": tag,
|
||||
"msg": msg,
|
||||
"cold_start_info": ctx.cold_start_info(),
|
||||
})
|
||||
.to_string();
|
||||
msg = &probe_msg;
|
||||
}
|
||||
}
|
||||
|
||||
Err(ReportedError {
|
||||
source: anyhow::anyhow!(error),
|
||||
error_kind,
|
||||
})
|
||||
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
|
||||
self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
|
||||
|
||||
self.flush()
|
||||
.await
|
||||
.unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
|
||||
|
||||
ReportedError::new(error)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -287,6 +287,17 @@ def test_pgdata_import_smoke(
|
||||
with pytest.raises(psycopg2.errors.UndefinedTable):
|
||||
br_initdb_endpoint.safe_psql(f"select * from {workload.table}")
|
||||
|
||||
# The storage controller might be overly eager and attempt to finalize
|
||||
# the import before the task got a chance to exit.
|
||||
env.storage_controller.allowed_errors.extend(
|
||||
[
|
||||
".*Call to node.*management API.*failed.*Import task still running.*",
|
||||
]
|
||||
)
|
||||
|
||||
for ps in env.pageservers:
|
||||
ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"])
|
||||
|
||||
|
||||
@run_only_on_default_postgres(reason="PG version is irrelevant here")
|
||||
def test_import_completion_on_restart(
|
||||
@@ -471,6 +482,17 @@ def test_import_respects_timeline_lifecycle(
|
||||
else:
|
||||
raise RuntimeError(f"{action} param not recognized")
|
||||
|
||||
# The storage controller might be overly eager and attempt to finalize
|
||||
# the import before the task got a chance to exit.
|
||||
env.storage_controller.allowed_errors.extend(
|
||||
[
|
||||
".*Call to node.*management API.*failed.*Import task still running.*",
|
||||
]
|
||||
)
|
||||
|
||||
for ps in env.pageservers:
|
||||
ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"])
|
||||
|
||||
|
||||
@skip_in_debug_build("Validation query takes too long in debug builds")
|
||||
def test_import_chaos(
|
||||
|
||||
@@ -124,6 +124,9 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
".*downloading failed, possibly for shutdown",
|
||||
# {tenant_id=... timeline_id=...}:handle_pagerequests:handle_get_page_at_lsn_request{rel=1664/0/1260 blkno=0 req_lsn=0/149F0D8}: error reading relation or page version: Not found: will not become active. Current state: Stopping\n'
|
||||
".*page_service.*will not become active.*",
|
||||
# the following errors are possible when pageserver tries to ingest wal records despite being in unreadable state
|
||||
".*wal_connection_manager.*layer file download failed: No file found.*",
|
||||
".*wal_connection_manager.*could not ingest record.*",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -156,6 +159,45 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
env.pageservers[2].id: ("Detached", None),
|
||||
}
|
||||
|
||||
# Track all the attached locations with mode and generation
|
||||
history: list[tuple[int, str, int | None]] = []
|
||||
|
||||
def may_read(pageserver: NeonPageserver, mode: str, generation: int | None) -> bool:
|
||||
# Rules for when a pageserver may read:
|
||||
# - our generation is higher than any previous
|
||||
# - our generation is equal to previous, but no other pageserver
|
||||
# in that generation has been AttachedSingle (i.e. allowed to compact/GC)
|
||||
# - our generation is equal to previous, and the previous holder of this
|
||||
# generation was the same node as we're attaching now.
|
||||
#
|
||||
# If these conditions are not met, then a read _might_ work, but the pageserver might
|
||||
# also hit errors trying to download layers.
|
||||
highest_historic_generation = max([i[2] for i in history if i[2] is not None], default=None)
|
||||
|
||||
if generation is None:
|
||||
# We're not in an attached state, we may not read
|
||||
return False
|
||||
elif highest_historic_generation is not None and generation < highest_historic_generation:
|
||||
# We are in an outdated generation, we may not read
|
||||
return False
|
||||
elif highest_historic_generation is not None and generation == highest_historic_generation:
|
||||
# We are re-using a generation: if any pageserver other than this one
|
||||
# has held AttachedSingle mode, this node may not read (because some other
|
||||
# node may be doing GC/compaction).
|
||||
if any(
|
||||
i[1] == "AttachedSingle"
|
||||
and i[2] == highest_historic_generation
|
||||
and i[0] != pageserver.id
|
||||
for i in history
|
||||
):
|
||||
log.info(
|
||||
f"Skipping read on {pageserver.id} because other pageserver has been in AttachedSingle mode in generation {highest_historic_generation}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Fall through: we have passed conditions for readability
|
||||
return True
|
||||
|
||||
latest_attached = env.pageservers[0].id
|
||||
|
||||
for _i in range(0, 64):
|
||||
@@ -199,9 +241,10 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
assert len(tenants) == 1
|
||||
assert tenants[0]["generation"] == new_generation
|
||||
|
||||
log.info("Entering postgres...")
|
||||
workload.churn_rows(rng.randint(128, 256), pageserver.id)
|
||||
workload.validate(pageserver.id)
|
||||
if may_read(pageserver, last_state_ps[0], last_state_ps[1]):
|
||||
log.info("Entering postgres...")
|
||||
workload.churn_rows(rng.randint(128, 256), pageserver.id)
|
||||
workload.validate(pageserver.id)
|
||||
elif last_state_ps[0].startswith("Attached"):
|
||||
# The `storage_controller` will only re-attach on startup when a pageserver was the
|
||||
# holder of the latest generation: otherwise the pageserver will revert to detached
|
||||
@@ -241,18 +284,16 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
location_conf["generation"] = generation
|
||||
|
||||
pageserver.tenant_location_configure(tenant_id, location_conf)
|
||||
|
||||
last_state[pageserver.id] = (mode, generation)
|
||||
|
||||
# It's only valid to connect to the last generation. Newer generations may yank layer
|
||||
# files used in older generations.
|
||||
last_generation = max(
|
||||
[s[1] for s in last_state.values() if s[1] is not None], default=None
|
||||
)
|
||||
may_read_this_generation = may_read(pageserver, mode, generation)
|
||||
history.append((pageserver.id, mode, generation))
|
||||
|
||||
if mode.startswith("Attached") and generation == last_generation:
|
||||
# This is a basic test: we are validating that he endpoint works properly _between_
|
||||
# configuration changes. A stronger test would be to validate that clients see
|
||||
# no errors while we are making the changes.
|
||||
# This is a basic test: we are validating that he endpoint works properly _between_
|
||||
# configuration changes. A stronger test would be to validate that clients see
|
||||
# no errors while we are making the changes.
|
||||
if may_read_this_generation:
|
||||
workload.churn_rows(
|
||||
rng.randint(128, 256), pageserver.id, upload=mode != "AttachedStale"
|
||||
)
|
||||
@@ -265,9 +306,16 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
assert gc_summary["remote_storage_errors"] == 0
|
||||
assert gc_summary["indices_deleted"] > 0
|
||||
|
||||
# Attach all pageservers
|
||||
# Attach all pageservers, in a higher generation than any previous. We will use the same
|
||||
# gen for all, and AttachedMulti mode so that they do not interfere with one another.
|
||||
generation = env.storage_controller.attach_hook_issue(tenant_id, env.pageservers[0].id)
|
||||
for ps in env.pageservers:
|
||||
location_conf = {"mode": "AttachedMulti", "secondary_conf": None, "tenant_conf": {}}
|
||||
location_conf = {
|
||||
"mode": "AttachedMulti",
|
||||
"secondary_conf": None,
|
||||
"tenant_conf": {},
|
||||
"generation": generation,
|
||||
}
|
||||
ps.tenant_location_configure(tenant_id, location_conf)
|
||||
|
||||
# Confirm that all are readable
|
||||
|
||||
Reference in New Issue
Block a user