mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-17 05:00:38 +00:00
Compare commits
10 Commits
problame/p
...
build_info
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a769a43767 | ||
|
|
286f34dfce | ||
|
|
f290b27378 | ||
|
|
4cd18fcebd | ||
|
|
4c29e0594e | ||
|
|
3c56a4dd18 | ||
|
|
316309c85b | ||
|
|
e09bb9974c | ||
|
|
5289f341ce | ||
|
|
683ec2417c |
3
.github/workflows/build_and_test.yml
vendored
3
.github/workflows/build_and_test.yml
vendored
@@ -404,7 +404,7 @@ jobs:
|
||||
uses: ./.github/actions/save-coverage-data
|
||||
|
||||
regress-tests:
|
||||
needs: [ check-permissions, build-neon ]
|
||||
needs: [ check-permissions, build-neon, tag ]
|
||||
runs-on: [ self-hosted, gen3, large ]
|
||||
container:
|
||||
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
|
||||
@@ -436,6 +436,7 @@ jobs:
|
||||
env:
|
||||
TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
|
||||
CHECK_ONDISK_DATA_COMPATIBILITY: nonempty
|
||||
BUILD_TAG: ${{ needs.tag.outputs.build-tag }}
|
||||
|
||||
- name: Merge and upload coverage data
|
||||
if: matrix.build_type == 'debug' && matrix.pg_version == 'v14'
|
||||
|
||||
15
Cargo.lock
generated
15
Cargo.lock
generated
@@ -1133,7 +1133,9 @@ dependencies = [
|
||||
"compute_api",
|
||||
"flate2",
|
||||
"futures",
|
||||
"git-version",
|
||||
"hyper",
|
||||
"metrics",
|
||||
"notify",
|
||||
"num_cpus",
|
||||
"opentelemetry",
|
||||
@@ -2610,17 +2612,6 @@ dependencies = [
|
||||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nostarve_queue"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"rand 0.8.5",
|
||||
"scopeguard",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "notify"
|
||||
version = "5.2.0"
|
||||
@@ -2962,7 +2953,6 @@ dependencies = [
|
||||
"itertools",
|
||||
"metrics",
|
||||
"nix 0.26.2",
|
||||
"nostarve_queue",
|
||||
"num-traits",
|
||||
"num_cpus",
|
||||
"once_cell",
|
||||
@@ -3517,6 +3507,7 @@ dependencies = [
|
||||
"pbkdf2",
|
||||
"pin-project-lite",
|
||||
"postgres-native-tls",
|
||||
"postgres-protocol",
|
||||
"postgres_backend",
|
||||
"pq_proto",
|
||||
"prometheus",
|
||||
|
||||
@@ -27,7 +27,6 @@ members = [
|
||||
"libs/postgres_ffi/wal_craft",
|
||||
"libs/vm_monitor",
|
||||
"libs/walproposer",
|
||||
"libs/nostarve_queue",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -38,7 +37,6 @@ license = "Apache-2.0"
|
||||
[workspace.dependencies]
|
||||
anyhow = { version = "1.0", features = ["backtrace"] }
|
||||
arc-swap = "1.6"
|
||||
async-channel = "1.9.0"
|
||||
async-compression = { version = "0.4.0", features = ["tokio", "gzip", "zstd"] }
|
||||
azure_core = "0.16"
|
||||
azure_identity = "0.16"
|
||||
@@ -193,7 +191,6 @@ tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
|
||||
utils = { version = "0.1", path = "./libs/utils/" }
|
||||
vm_monitor = { version = "0.1", path = "./libs/vm_monitor/" }
|
||||
walproposer = { version = "0.1", path = "./libs/walproposer/" }
|
||||
nostarve_queue = { path = "./libs/nostarve_queue" }
|
||||
|
||||
## Common library dependency
|
||||
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
|
||||
|
||||
@@ -714,6 +714,24 @@ RUN wget https://github.com/pksunkara/pgx_ulid/archive/refs/tags/v0.1.3.tar.gz -
|
||||
cargo pgrx install --release && \
|
||||
echo "trusted = true" >> /usr/local/pgsql/share/extension/ulid.control
|
||||
|
||||
#########################################################################################
|
||||
#
|
||||
# Layer "wal2json-build"
|
||||
# Compile "wal2json" extension
|
||||
#
|
||||
#########################################################################################
|
||||
|
||||
FROM build-deps AS wal2json-pg-build
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
ENV PATH "/usr/local/pgsql/bin/:$PATH"
|
||||
RUN wget https://github.com/eulerto/wal2json/archive/refs/tags/wal2json_2_5.tar.gz && \
|
||||
echo "b516653575541cf221b99cf3f8be9b6821f6dbcfc125675c85f35090f824f00e wal2json_2_5.tar.gz" | sha256sum --check && \
|
||||
mkdir wal2json-src && cd wal2json-src && tar xvzf ../wal2json_2_5.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install && \
|
||||
echo 'trusted = true' >> /usr/local/pgsql/share/extension/wal2json.control
|
||||
|
||||
#########################################################################################
|
||||
#
|
||||
# Layer "neon-pg-ext-build"
|
||||
@@ -750,6 +768,7 @@ COPY --from=rdkit-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pg-uuidv7-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pg-embedding-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=wal2json-pg-build /usr/local/pgsql /usr/local/pgsql
|
||||
COPY pgxn/ pgxn/
|
||||
|
||||
RUN make -j $(getconf _NPROCESSORS_ONLN) \
|
||||
|
||||
@@ -12,8 +12,10 @@ cfg-if.workspace = true
|
||||
clap.workspace = true
|
||||
flate2.workspace = true
|
||||
futures.workspace = true
|
||||
git-version.workspace = true
|
||||
hyper = { workspace = true, features = ["full"] }
|
||||
notify.workspace = true
|
||||
metrics.workspace = true
|
||||
num_cpus.workspace = true
|
||||
opentelemetry.workspace = true
|
||||
postgres.workspace = true
|
||||
|
||||
@@ -57,18 +57,27 @@ use compute_tools::logger::*;
|
||||
use compute_tools::monitor::launch_monitor;
|
||||
use compute_tools::params::*;
|
||||
use compute_tools::spec::*;
|
||||
use metrics::set_build_info_metric;
|
||||
use utils::{project_build_tag, project_git_version};
|
||||
|
||||
// this is an arbitrary build tag. Fine as a default / for testing purposes
|
||||
// in-case of not-set environment var
|
||||
const BUILD_TAG_DEFAULT: &str = "latest";
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
project_build_tag!(BUILD_TAG);
|
||||
|
||||
fn main() -> Result<()> {
|
||||
init_tracing_and_logging(DEFAULT_LOG_LEVEL)?;
|
||||
|
||||
let build_tag = option_env!("BUILD_TAG")
|
||||
.unwrap_or(BUILD_TAG_DEFAULT)
|
||||
.to_string();
|
||||
info!("build_tag: {build_tag}");
|
||||
|
||||
info!("Version: {GIT_VERSION}");
|
||||
info!("Build_tag: {BUILD_TAG}");
|
||||
|
||||
set_build_info_metric(GIT_VERSION, BUILD_TAG);
|
||||
|
||||
let matches = cli().get_matches();
|
||||
let pgbin_default = String::from("postgres");
|
||||
|
||||
@@ -11,7 +11,8 @@ use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIErr
|
||||
|
||||
use anyhow::Result;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Method, Request, Response, Server, StatusCode};
|
||||
use hyper::{header::CONTENT_TYPE, Body, Method, Request, Response, Server, StatusCode};
|
||||
use metrics::{Encoder, TextEncoder};
|
||||
use num_cpus;
|
||||
use serde_json;
|
||||
use tokio::task;
|
||||
@@ -51,6 +52,20 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
Response::new(Body::from(serde_json::to_string(&status_response).unwrap()))
|
||||
}
|
||||
|
||||
// prometheus metrics
|
||||
(&Method::GET, "/metrics") => {
|
||||
let mut buffer = vec![];
|
||||
let metrics = metrics::gather();
|
||||
let encoder = TextEncoder::new();
|
||||
encoder.encode(&metrics, &mut buffer).unwrap();
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(CONTENT_TYPE, encoder.format_type())
|
||||
.body(Body::from(buffer))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
// Startup metrics in JSON format. Keep /metrics reserved for a possible
|
||||
// future use for Prometheus metrics format.
|
||||
(&Method::GET, "/metrics.json") => {
|
||||
|
||||
@@ -687,6 +687,9 @@ pub fn handle_extension_neon(client: &mut Client) -> Result<()> {
|
||||
info!("create neon extension with query: {}", query);
|
||||
client.simple_query(query)?;
|
||||
|
||||
query = "UPDATE pg_extension SET extrelocatable = true WHERE extname = 'neon'";
|
||||
client.simple_query(query)?;
|
||||
|
||||
query = "ALTER EXTENSION neon SET SCHEMA neon";
|
||||
info!("alter neon extension schema with query: {}", query);
|
||||
client.simple_query(query)?;
|
||||
|
||||
@@ -21,7 +21,7 @@ use pageserver_api::models::{
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use postgres_backend::AuthType;
|
||||
use postgres_connection::{parse_host_port, PgConnectionConfig};
|
||||
use reqwest::blocking::{Client, ClientBuilder, RequestBuilder, Response};
|
||||
use reqwest::blocking::{Client, RequestBuilder, Response};
|
||||
use reqwest::{IntoUrl, Method};
|
||||
use thiserror::Error;
|
||||
use utils::auth::{Claims, Scope};
|
||||
@@ -99,7 +99,7 @@ impl PageServerNode {
|
||||
pg_connection_config: PgConnectionConfig::new_host_port(host, port),
|
||||
conf: conf.clone(),
|
||||
env: env.clone(),
|
||||
http_client: ClientBuilder::new().timeout(None).build().unwrap(),
|
||||
http_client: Client::new(),
|
||||
http_base_url: format!("http://{}/v1", conf.listen_http_addr),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
[package]
|
||||
name = "nostarve_queue"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
scopeguard.workspace = true
|
||||
tracing.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
futures.workspace = true
|
||||
rand.workspace = true
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "time"] }
|
||||
@@ -1,316 +0,0 @@
|
||||
//! Synchronization primitive to prevent starvation among concurrent tasks that do the same work.
|
||||
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
fmt,
|
||||
future::poll_fn,
|
||||
sync::Mutex,
|
||||
task::{Poll, Waker},
|
||||
};
|
||||
|
||||
pub struct Queue<T> {
|
||||
inner: Mutex<Inner<T>>,
|
||||
}
|
||||
|
||||
struct Inner<T> {
|
||||
waiters: VecDeque<usize>,
|
||||
free: VecDeque<usize>,
|
||||
slots: Vec<Option<(Option<Waker>, Option<T>)>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Position<'q, T> {
|
||||
idx: usize,
|
||||
queue: &'q Queue<T>,
|
||||
}
|
||||
|
||||
impl<T> fmt::Debug for Position<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Position").field("idx", &self.idx).finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Inner<T> {
|
||||
#[cfg(not(test))]
|
||||
#[inline]
|
||||
fn integrity_check(&self) {}
|
||||
|
||||
#[cfg(test)]
|
||||
fn integrity_check(&self) {
|
||||
use std::collections::HashSet;
|
||||
let waiters = self.waiters.iter().copied().collect::<HashSet<_>>();
|
||||
let free = self.free.iter().copied().collect::<HashSet<_>>();
|
||||
for (slot_idx, slot) in self.slots.iter().enumerate() {
|
||||
match slot {
|
||||
None => {
|
||||
assert!(!waiters.contains(&slot_idx));
|
||||
assert!(free.contains(&slot_idx));
|
||||
}
|
||||
Some((None, None)) => {
|
||||
assert!(waiters.contains(&slot_idx));
|
||||
assert!(!free.contains(&slot_idx));
|
||||
}
|
||||
Some((Some(_), Some(_))) => {
|
||||
assert!(!waiters.contains(&slot_idx));
|
||||
assert!(!free.contains(&slot_idx));
|
||||
}
|
||||
Some((Some(_), None)) => {
|
||||
assert!(waiters.contains(&slot_idx));
|
||||
assert!(!free.contains(&slot_idx));
|
||||
}
|
||||
Some((None, Some(_))) => {
|
||||
assert!(!waiters.contains(&slot_idx));
|
||||
assert!(!free.contains(&slot_idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Queue<T> {
|
||||
pub fn new(size: usize) -> Self {
|
||||
Queue {
|
||||
inner: Mutex::new(Inner {
|
||||
waiters: VecDeque::new(),
|
||||
free: (0..size).collect(),
|
||||
slots: {
|
||||
let mut v = Vec::with_capacity(size);
|
||||
v.resize_with(size, || None);
|
||||
v
|
||||
},
|
||||
}),
|
||||
}
|
||||
}
|
||||
pub fn begin(&self) -> Result<Position<T>, ()> {
|
||||
#[cfg(test)]
|
||||
tracing::trace!("get in line locking inner");
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.integrity_check();
|
||||
let my_waitslot_idx = inner
|
||||
.free
|
||||
.pop_front()
|
||||
.expect("can't happen, len(slots) = len(waiters");
|
||||
inner.waiters.push_back(my_waitslot_idx);
|
||||
let prev = inner.slots[my_waitslot_idx].replace((None, None));
|
||||
assert!(prev.is_none());
|
||||
inner.integrity_check();
|
||||
Ok(Position {
|
||||
idx: my_waitslot_idx,
|
||||
queue: &self,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'q, T> Position<'q, T> {
|
||||
pub fn complete_and_wait(self, datum: T) -> impl std::future::Future<Output = T> + 'q {
|
||||
#[cfg(test)]
|
||||
tracing::trace!("found victim locking waiters");
|
||||
let mut inner = self.queue.inner.lock().unwrap();
|
||||
inner.integrity_check();
|
||||
let winner_idx = inner.waiters.pop_front().expect("we put ourselves in");
|
||||
#[cfg(test)]
|
||||
tracing::trace!(winner_idx, "putting victim into next waiters slot");
|
||||
let winner_slot = inner.slots[winner_idx].as_mut().unwrap();
|
||||
let prev = winner_slot.1.replace(datum);
|
||||
assert!(
|
||||
prev.is_none(),
|
||||
"ensure we didn't mess up this simple ring buffer structure"
|
||||
);
|
||||
if let Some(waker) = winner_slot.0.take() {
|
||||
#[cfg(test)]
|
||||
tracing::trace!(winner_idx, "waking up winner");
|
||||
waker.wake()
|
||||
}
|
||||
inner.integrity_check();
|
||||
drop(inner); // the poll_fn locks it again
|
||||
|
||||
let mut poll_num = 0;
|
||||
let mut drop_guard = Some(scopeguard::guard((), |()| {
|
||||
panic!("must not drop this future until Ready");
|
||||
}));
|
||||
|
||||
// take the victim that was found by someone else
|
||||
poll_fn(move |cx| {
|
||||
let my_waitslot_idx = self.idx;
|
||||
poll_num += 1;
|
||||
#[cfg(test)]
|
||||
tracing::trace!(poll_num, "poll_fn locking waiters");
|
||||
let mut inner = self.queue.inner.lock().unwrap();
|
||||
inner.integrity_check();
|
||||
let my_waitslot = inner.slots[self.idx].as_mut().unwrap();
|
||||
// assert!(
|
||||
// poll_num <= 2,
|
||||
// "once we place the waker in the slot, next wakeup should have a result: {}",
|
||||
// my_waitslot.1.is_some()
|
||||
// );
|
||||
if let Some(res) = my_waitslot.1.take() {
|
||||
#[cfg(test)]
|
||||
tracing::trace!(poll_num, "have cache slot");
|
||||
// above .take() resets the waiters slot to None
|
||||
debug_assert!(my_waitslot.0.is_none());
|
||||
debug_assert!(my_waitslot.1.is_none());
|
||||
inner.slots[my_waitslot_idx] = None;
|
||||
inner.free.push_back(my_waitslot_idx);
|
||||
let _ = scopeguard::ScopeGuard::into_inner(drop_guard.take().unwrap());
|
||||
inner.integrity_check();
|
||||
return Poll::Ready(res);
|
||||
}
|
||||
// assert_eq!(poll_num, 1);
|
||||
if !my_waitslot
|
||||
.0
|
||||
.as_ref()
|
||||
.map(|existing| cx.waker().will_wake(existing))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let prev = my_waitslot.0.replace(cx.waker().clone());
|
||||
#[cfg(test)]
|
||||
tracing::trace!(poll_num, prev_is_some = prev.is_some(), "updating waker");
|
||||
}
|
||||
inner.integrity_check();
|
||||
#[cfg(test)]
|
||||
tracing::trace!(poll_num, "waiting to be woken up");
|
||||
Poll::Pending
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
task::Poll,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use rand::RngCore;
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_order_completion_and_wait() {
|
||||
let queue = super::Queue::new(2);
|
||||
|
||||
let q1 = queue.begin().unwrap();
|
||||
let q2 = queue.begin().unwrap();
|
||||
|
||||
assert_eq!(q1.complete_and_wait(23).await, 23);
|
||||
assert_eq!(q2.complete_and_wait(42).await, 42);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn out_of_order_completion_and_wait() {
|
||||
let queue = super::Queue::new(2);
|
||||
|
||||
let q1 = queue.begin().unwrap();
|
||||
let q2 = queue.begin().unwrap();
|
||||
|
||||
let mut q2compfut = q2.complete_and_wait(23);
|
||||
|
||||
match futures::poll!(&mut q2compfut) {
|
||||
Poll::Pending => {}
|
||||
Poll::Ready(_) => panic!("should not be ready yet, it's queued after q1"),
|
||||
}
|
||||
|
||||
let q1res = q1.complete_and_wait(42).await;
|
||||
assert_eq!(q1res, 23);
|
||||
|
||||
let q2res = q2compfut.await;
|
||||
assert_eq!(q2res, 42);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_order_completion_out_of_order_wait() {
|
||||
let queue = super::Queue::new(2);
|
||||
|
||||
let q1 = queue.begin().unwrap();
|
||||
let q2 = queue.begin().unwrap();
|
||||
|
||||
let mut q1compfut = q1.complete_and_wait(23);
|
||||
|
||||
let mut q2compfut = q2.complete_and_wait(42);
|
||||
|
||||
match futures::poll!(&mut q2compfut) {
|
||||
Poll::Pending => {
|
||||
unreachable!("q2 should be ready, it wasn't first but q1 is serviced already")
|
||||
}
|
||||
Poll::Ready(x) => assert_eq!(x, 42),
|
||||
}
|
||||
|
||||
assert_eq!(futures::poll!(&mut q1compfut), Poll::Ready(23));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn stress() {
|
||||
let ntasks = 8;
|
||||
let queue_size = 8;
|
||||
let queue = Arc::new(super::Queue::new(queue_size));
|
||||
|
||||
let stop = Arc::new(AtomicBool::new(false));
|
||||
|
||||
let mut tasks = vec![];
|
||||
for i in 0..ntasks {
|
||||
let jh = tokio::spawn({
|
||||
let queue = Arc::clone(&queue);
|
||||
let stop = Arc::clone(&stop);
|
||||
async move {
|
||||
while !stop.load(Ordering::Relaxed) {
|
||||
let q = queue.begin().unwrap();
|
||||
for _ in 0..(rand::thread_rng().next_u32() % 10_000) {
|
||||
std::hint::spin_loop();
|
||||
}
|
||||
q.complete_and_wait(i).await;
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
});
|
||||
tasks.push(jh);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||
|
||||
stop.store(true, Ordering::Relaxed);
|
||||
|
||||
for t in tasks {
|
||||
t.await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_two_runtimes_shared_queue() {
|
||||
std::thread::scope(|s| {
|
||||
let ntasks = 8;
|
||||
let queue_size = 8;
|
||||
let queue = Arc::new(super::Queue::new(queue_size));
|
||||
|
||||
let stop = Arc::new(AtomicBool::new(false));
|
||||
|
||||
for i in 0..ntasks {
|
||||
s.spawn({
|
||||
let queue = Arc::clone(&queue);
|
||||
let stop = Arc::clone(&stop);
|
||||
move || {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
rt.block_on(async move {
|
||||
while !stop.load(Ordering::Relaxed) {
|
||||
let q = queue.begin().unwrap();
|
||||
for _ in 0..(rand::thread_rng().next_u32() % 10_000) {
|
||||
std::hint::spin_loop();
|
||||
}
|
||||
q.complete_and_wait(i).await;
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::thread::sleep(Duration::from_secs(10));
|
||||
|
||||
stop.store(true, Ordering::Relaxed);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,6 @@ humantime-serde.workspace = true
|
||||
hyper.workspace = true
|
||||
itertools.workspace = true
|
||||
nix.workspace = true
|
||||
nostarve_queue.workspace = true
|
||||
# hack to get the number of worker threads tokio uses
|
||||
num_cpus = { version = "1.15" }
|
||||
num-traits.workspace = true
|
||||
|
||||
@@ -314,6 +314,7 @@ static PAGE_CACHE_ERRORS: Lazy<IntCounterVec> = Lazy::new(|| {
|
||||
#[strum(serialize_all = "kebab_case")]
|
||||
pub(crate) enum PageCacheErrorKind {
|
||||
AcquirePinnedSlotTimeout,
|
||||
EvictIterLimit,
|
||||
}
|
||||
|
||||
pub(crate) fn page_cache_errors_inc(error_kind: PageCacheErrorKind) {
|
||||
|
||||
@@ -83,7 +83,6 @@ use std::{
|
||||
|
||||
use anyhow::Context;
|
||||
use once_cell::sync::OnceCell;
|
||||
use tracing::instrument;
|
||||
use utils::{
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
@@ -253,9 +252,6 @@ pub struct PageCache {
|
||||
next_evict_slot: AtomicUsize,
|
||||
|
||||
size_metrics: &'static PageCacheSizeMetrics,
|
||||
|
||||
find_victim_waiters:
|
||||
nostarve_queue::Queue<(usize, tokio::sync::RwLockWriteGuard<'static, SlotInner>)>,
|
||||
}
|
||||
|
||||
struct PinnedSlotsPermit(tokio::sync::OwnedSemaphorePermit);
|
||||
@@ -434,9 +430,8 @@ impl PageCache {
|
||||
///
|
||||
/// Store an image of the given page in the cache.
|
||||
///
|
||||
#[cfg_attr(test, instrument(skip_all, level = "trace", fields(%key, %lsn)))]
|
||||
pub async fn memorize_materialized_page(
|
||||
&'static self,
|
||||
&self,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
key: Key,
|
||||
@@ -527,9 +522,8 @@ impl PageCache {
|
||||
|
||||
// Section 1.2: Public interface functions for working with immutable file pages.
|
||||
|
||||
#[cfg_attr(test, instrument(skip_all, level = "trace", fields(?file_id, ?blkno)))]
|
||||
pub async fn read_immutable_buf(
|
||||
&'static self,
|
||||
&self,
|
||||
file_id: FileId,
|
||||
blkno: u32,
|
||||
ctx: &RequestContext,
|
||||
@@ -635,7 +629,7 @@ impl PageCache {
|
||||
/// ```
|
||||
///
|
||||
async fn lock_for_read(
|
||||
&'static self,
|
||||
&self,
|
||||
cache_key: &mut CacheKey,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<ReadBufResult> {
|
||||
@@ -857,15 +851,10 @@ impl PageCache {
|
||||
///
|
||||
/// On return, the slot is empty and write-locked.
|
||||
async fn find_victim(
|
||||
&'static self,
|
||||
&self,
|
||||
_permit_witness: &PinnedSlotsPermit,
|
||||
) -> anyhow::Result<(usize, tokio::sync::RwLockWriteGuard<SlotInner>)> {
|
||||
let nostarve_position = self.find_victim_waiters.begin()
|
||||
.expect("we initialize the nostarve queue to the same size as the slots semaphore, and the caller is presenting a permit");
|
||||
|
||||
let span = tracing::info_span!("find_victim", ?nostarve_position);
|
||||
let _enter = span.enter();
|
||||
|
||||
let iter_limit = self.slots.len() * 10;
|
||||
let mut iters = 0;
|
||||
loop {
|
||||
iters += 1;
|
||||
@@ -877,8 +866,41 @@ impl PageCache {
|
||||
let mut inner = match slot.inner.try_write() {
|
||||
Ok(inner) => inner,
|
||||
Err(_err) => {
|
||||
if iters > self.slots.len() * (MAX_USAGE_COUNT as usize) {
|
||||
unreachable!("find_victim_waiters prevents starvation");
|
||||
if iters > iter_limit {
|
||||
// NB: Even with the permits, there's no hard guarantee that we will find a slot with
|
||||
// any particular number of iterations: other threads might race ahead and acquire and
|
||||
// release pins just as we're scanning the array.
|
||||
//
|
||||
// Imagine that nslots is 2, and as starting point, usage_count==1 on all
|
||||
// slots. There are two threads running concurrently, A and B. A has just
|
||||
// acquired the permit from the semaphore.
|
||||
//
|
||||
// A: Look at slot 1. Its usage_count == 1, so decrement it to zero, and continue the search
|
||||
// B: Acquire permit.
|
||||
// B: Look at slot 2, decrement its usage_count to zero and continue the search
|
||||
// B: Look at slot 1. Its usage_count is zero, so pin it and bump up its usage_count to 1.
|
||||
// B: Release pin and permit again
|
||||
// B: Acquire permit.
|
||||
// B: Look at slot 2. Its usage_count is zero, so pin it and bump up its usage_count to 1.
|
||||
// B: Release pin and permit again
|
||||
//
|
||||
// Now we're back in the starting situation that both slots have
|
||||
// usage_count 1, but A has now been through one iteration of the
|
||||
// find_victim() loop. This can repeat indefinitely and on each
|
||||
// iteration, A's iteration count increases by one.
|
||||
//
|
||||
// So, even though the semaphore for the permits is fair, the victim search
|
||||
// itself happens in parallel and is not fair.
|
||||
// Hence even with a permit, a task can theoretically be starved.
|
||||
// To avoid this, we'd need tokio to give priority to tasks that are holding
|
||||
// permits for longer.
|
||||
// Note that just yielding to tokio during iteration without such
|
||||
// priority boosting is likely counter-productive. We'd just give more opportunities
|
||||
// for B to bump usage count, further starving A.
|
||||
crate::metrics::page_cache_errors_inc(
|
||||
crate::metrics::PageCacheErrorKind::EvictIterLimit,
|
||||
);
|
||||
anyhow::bail!("exceeded evict iter limit");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -889,8 +911,7 @@ impl PageCache {
|
||||
inner.key = None;
|
||||
}
|
||||
crate::metrics::PAGE_CACHE_FIND_VICTIMS_ITERS_TOTAL.inc_by(iters as u64);
|
||||
|
||||
return Ok(nostarve_position.complete_and_wait((slot_idx, inner)).await);
|
||||
return Ok((slot_idx, inner));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -934,7 +955,6 @@ impl PageCache {
|
||||
next_evict_slot: AtomicUsize::new(0),
|
||||
size_metrics,
|
||||
pinned_slots: Arc::new(tokio::sync::Semaphore::new(num_pages)),
|
||||
find_victim_waiters: ::nostarve_queue::Queue::new(num_pages),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2908,7 +2908,7 @@ impl Tenant {
|
||||
};
|
||||
// create a `tenant/{tenant_id}/timelines/basebackup-{timeline_id}.{TEMP_FILE_SUFFIX}/`
|
||||
// temporary directory for basebackup files for the given timeline.
|
||||
let initdb_path = path_with_suffix_extension(
|
||||
let pgdata_path = path_with_suffix_extension(
|
||||
self.conf
|
||||
.timelines_path(&self.tenant_id)
|
||||
.join(format!("basebackup-{timeline_id}")),
|
||||
@@ -2917,26 +2917,25 @@ impl Tenant {
|
||||
|
||||
// an uninit mark was placed before, nothing else can access this timeline files
|
||||
// current initdb was not run yet, so remove whatever was left from the previous runs
|
||||
if initdb_path.exists() {
|
||||
fs::remove_dir_all(&initdb_path).with_context(|| {
|
||||
format!("Failed to remove already existing initdb directory: {initdb_path}")
|
||||
if pgdata_path.exists() {
|
||||
fs::remove_dir_all(&pgdata_path).with_context(|| {
|
||||
format!("Failed to remove already existing initdb directory: {pgdata_path}")
|
||||
})?;
|
||||
}
|
||||
// Init temporarily repo to get bootstrap data, this creates a directory in the `initdb_path` path
|
||||
run_initdb(self.conf, &initdb_path, pg_version)?;
|
||||
// Init temporarily repo to get bootstrap data, this creates a directory in the `pgdata_path` path
|
||||
run_initdb(self.conf, &pgdata_path, pg_version)?;
|
||||
// this new directory is very temporary, set to remove it immediately after bootstrap, we don't need it
|
||||
scopeguard::defer! {
|
||||
if let Err(e) = fs::remove_dir_all(&initdb_path) {
|
||||
if let Err(e) = fs::remove_dir_all(&pgdata_path) {
|
||||
// this is unlikely, but we will remove the directory on pageserver restart or another bootstrap call
|
||||
error!("Failed to remove temporary initdb directory '{initdb_path}': {e}");
|
||||
error!("Failed to remove temporary initdb directory '{pgdata_path}': {e}");
|
||||
}
|
||||
}
|
||||
let pgdata_path = &initdb_path;
|
||||
let pgdata_lsn = import_datadir::get_lsn_from_controlfile(pgdata_path)?.align();
|
||||
let pgdata_lsn = import_datadir::get_lsn_from_controlfile(&pgdata_path)?.align();
|
||||
|
||||
// Upload the created data dir to S3
|
||||
if let Some(storage) = &self.remote_storage {
|
||||
let pgdata_zstd = import_datadir::create_tar_zst(pgdata_path).await?;
|
||||
let pgdata_zstd = import_datadir::create_tar_zst(&pgdata_path).await?;
|
||||
let pgdata_zstd = Bytes::from(pgdata_zstd);
|
||||
backoff::retry(
|
||||
|| async {
|
||||
@@ -2986,7 +2985,7 @@ impl Tenant {
|
||||
|
||||
import_datadir::import_timeline_from_postgres_datadir(
|
||||
unfinished_timeline,
|
||||
pgdata_path,
|
||||
&pgdata_path,
|
||||
pgdata_lsn,
|
||||
ctx,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "storage/buf_internals.h"
|
||||
#include "storage/lwlock.h"
|
||||
#include "storage/ipc.h"
|
||||
#include "storage/pg_shmem.h"
|
||||
#include "c.h"
|
||||
#include "postmaster/interrupt.h"
|
||||
|
||||
@@ -87,6 +88,12 @@ bool (*old_redo_read_buffer_filter) (XLogReaderState *record, uint8 block_id) =
|
||||
static bool pageserver_flush(void);
|
||||
static void pageserver_disconnect(void);
|
||||
|
||||
static bool
|
||||
PagestoreShmemIsValid()
|
||||
{
|
||||
return pagestore_shared && UsedShmemSegAddr;
|
||||
}
|
||||
|
||||
static bool
|
||||
CheckPageserverConnstring(char **newval, void **extra, GucSource source)
|
||||
{
|
||||
@@ -96,7 +103,7 @@ CheckPageserverConnstring(char **newval, void **extra, GucSource source)
|
||||
static void
|
||||
AssignPageserverConnstring(const char *newval, void *extra)
|
||||
{
|
||||
if(!pagestore_shared)
|
||||
if(!PagestoreShmemIsValid())
|
||||
return;
|
||||
LWLockAcquire(pagestore_shared->lock, LW_EXCLUSIVE);
|
||||
strlcpy(pagestore_shared->pageserver_connstring, newval, MAX_PAGESERVER_CONNSTRING_SIZE);
|
||||
@@ -107,7 +114,7 @@ AssignPageserverConnstring(const char *newval, void *extra)
|
||||
static bool
|
||||
CheckConnstringUpdated()
|
||||
{
|
||||
if(!pagestore_shared)
|
||||
if(!PagestoreShmemIsValid())
|
||||
return false;
|
||||
return pagestore_local_counter < pg_atomic_read_u64(&pagestore_shared->update_counter);
|
||||
}
|
||||
@@ -115,7 +122,7 @@ CheckConnstringUpdated()
|
||||
static void
|
||||
ReloadConnstring()
|
||||
{
|
||||
if(!pagestore_shared)
|
||||
if(!PagestoreShmemIsValid())
|
||||
return;
|
||||
LWLockAcquire(pagestore_shared->lock, LW_SHARED);
|
||||
strlcpy(local_pageserver_connstring, pagestore_shared->pageserver_connstring, sizeof(local_pageserver_connstring));
|
||||
|
||||
@@ -2,3 +2,4 @@
|
||||
comment = 'cloud storage for PostgreSQL'
|
||||
default_version = '1.1'
|
||||
module_pathname = '$libdir/neon'
|
||||
relocatable = true
|
||||
|
||||
@@ -76,3 +76,4 @@ tokio-util.workspace = true
|
||||
rcgen.workspace = true
|
||||
rstest.workspace = true
|
||||
tokio-postgres-rustls.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
|
||||
@@ -6,6 +6,7 @@ pub use link::LinkAuthError;
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
|
||||
use crate::proxy::{handle_try_wake, retry_after, LatencyTimer};
|
||||
use crate::stream::Stream;
|
||||
use crate::{
|
||||
auth::{self, ClientCredentials},
|
||||
config::AuthenticationConfig,
|
||||
@@ -131,7 +132,7 @@ async fn auth_quirks_creds(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
@@ -165,7 +166,7 @@ async fn auth_quirks(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
@@ -241,7 +242,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
pub async fn authenticate(
|
||||
&mut self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::{
|
||||
console::{self, AuthInfo, ConsoleReqExtra},
|
||||
proxy::LatencyTimer,
|
||||
sasl, scram,
|
||||
stream::PqStream,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
@@ -15,7 +15,7 @@ pub(super) async fn authenticate(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
|
||||
@@ -2,7 +2,7 @@ use super::{AuthSuccess, ComputeCredentials};
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
proxy::LatencyTimer,
|
||||
stream,
|
||||
stream::{self, Stream},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
@@ -12,7 +12,7 @@ use tracing::{info, warn};
|
||||
/// These properties are benefical for serverless JS workers, so we
|
||||
/// use this mechanism for websocket connections.
|
||||
pub async fn cleartext_hack(
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
warn!("cleartext auth flow override is enabled, proceeding");
|
||||
@@ -37,7 +37,7 @@ pub async fn cleartext_hack(
|
||||
/// Very similar to [`cleartext_hack`], but there's a specific password format.
|
||||
pub async fn password_hack(
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
warn!("project not specified, resorting to the password hack auth flow");
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
//! Main authentication flow.
|
||||
|
||||
use super::{AuthErrorImpl, PasswordHackPayload};
|
||||
use crate::{sasl, scram, stream::PqStream};
|
||||
use crate::{
|
||||
config::TlsServerEndPoint,
|
||||
sasl, scram,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
/// Every authentication selector is supposed to implement this trait.
|
||||
pub trait AuthMethod {
|
||||
/// Any authentication selector should provide initial backend message
|
||||
/// containing auth method name and parameters, e.g. md5 salt.
|
||||
fn first_message(&self) -> BeMessage<'_>;
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
|
||||
}
|
||||
|
||||
/// Initial state of [`AuthFlow`].
|
||||
@@ -21,8 +26,14 @@ pub struct Scram<'a>(pub &'a scram::ServerSecret);
|
||||
|
||||
impl AuthMethod for Scram<'_> {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
|
||||
if channel_binding {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
} else {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
|
||||
scram::METHODS_WITHOUT_PLUS,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +43,7 @@ pub struct PasswordHack;
|
||||
|
||||
impl AuthMethod for PasswordHack {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
@@ -43,37 +54,44 @@ pub struct CleartextPassword;
|
||||
|
||||
impl AuthMethod for CleartextPassword {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub struct AuthFlow<'a, Stream, State> {
|
||||
pub struct AuthFlow<'a, S, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream>,
|
||||
stream: &'a mut PqStream<Stream<S>>,
|
||||
/// State might contain ancillary data (see [`Self::begin`]).
|
||||
state: State,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
}
|
||||
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub fn new(stream: &'a mut PqStream<S>) -> Self {
|
||||
pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
|
||||
let tls_server_end_point = stream.get_ref().tls_server_end_point();
|
||||
|
||||
Self {
|
||||
stream,
|
||||
state: Begin,
|
||||
tls_server_end_point,
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to the next step by sending auth method's name & params to client.
|
||||
pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
|
||||
self.stream.write_message(&method.first_message()).await?;
|
||||
self.stream
|
||||
.write_message(&method.first_message(self.tls_server_end_point.supported()))
|
||||
.await?;
|
||||
|
||||
Ok(AuthFlow {
|
||||
stream: self.stream,
|
||||
state: method,
|
||||
tls_server_end_point: self.tls_server_end_point,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -123,9 +141,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
return Err(super::AuthError::bad_auth_method(sasl.method));
|
||||
}
|
||||
|
||||
info!("client chooses {}", sasl.method);
|
||||
|
||||
let secret = self.state.0;
|
||||
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
|
||||
.authenticate(scram::Exchange::new(secret, rand::random, None))
|
||||
.authenticate(scram::Exchange::new(
|
||||
secret,
|
||||
rand::random,
|
||||
self.tls_server_end_point,
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(outcome)
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use futures::future::Either;
|
||||
use itertools::Itertools;
|
||||
use proxy::config::TlsServerEndPoint;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use anyhow::{anyhow, bail, ensure, Context};
|
||||
@@ -65,7 +67,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
|
||||
|
||||
// Configure TLS
|
||||
let tls_config: Arc<rustls::ServerConfig> = match (
|
||||
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
|
||||
args.get_one::<String>("tls-key"),
|
||||
args.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
@@ -89,16 +91,22 @@ async fn main() -> anyhow::Result<()> {
|
||||
))?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
.collect_vec()
|
||||
};
|
||||
|
||||
rustls::ServerConfig::builder()
|
||||
// needed for channel bindings
|
||||
let first_cert = cert_chain.first().context("missing certificate")?;
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
|
||||
let tls_config = rustls::ServerConfig::builder()
|
||||
.with_safe_default_cipher_suites()
|
||||
.with_safe_default_kx_groups()
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.into()
|
||||
.into();
|
||||
|
||||
(tls_config, tls_server_end_point)
|
||||
}
|
||||
_ => bail!("tls-key and tls-cert must be specified"),
|
||||
};
|
||||
@@ -113,6 +121,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let main = tokio::spawn(task_main(
|
||||
Arc::new(destination),
|
||||
tls_config,
|
||||
tls_server_end_point,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
@@ -134,6 +143,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
async fn task_main(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -159,7 +169,7 @@ async fn task_main(
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
info!(%peer_addr, "serving");
|
||||
handle_client(dest_suffix, tls_config, socket).await
|
||||
handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
@@ -207,6 +217,7 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
) -> anyhow::Result<Stream<S>> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
|
||||
|
||||
@@ -231,7 +242,11 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
Ok(raw.upgrade(tls_config).await?)
|
||||
|
||||
Ok(Stream::Tls {
|
||||
tls: Box::new(raw.upgrade(tls_config).await?),
|
||||
tls_server_end_point,
|
||||
})
|
||||
}
|
||||
unexpected => {
|
||||
info!(
|
||||
@@ -246,9 +261,10 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
async fn handle_client(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
let tls_stream = ssl_handshake(stream, tls_config).await?;
|
||||
let tls_stream = ssl_handshake(stream, tls_config, tls_server_end_point).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
use crate::auth;
|
||||
use anyhow::{bail, ensure, Context, Ok};
|
||||
use rustls::sign;
|
||||
use rustls::{sign, Certificate, PrivateKey};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
use x509_parser::oid_registry;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
@@ -27,6 +30,7 @@ pub struct MetricCollectionConfig {
|
||||
pub struct TlsConfig {
|
||||
pub config: Arc<rustls::ServerConfig>,
|
||||
pub common_names: Option<HashSet<String>>,
|
||||
pub cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
pub struct HttpConfig {
|
||||
@@ -52,7 +56,7 @@ pub fn configure_tls(
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
|
||||
// add default certificate
|
||||
cert_resolver.add_cert(key_path, cert_path, true)?;
|
||||
cert_resolver.add_cert_path(key_path, cert_path, true)?;
|
||||
|
||||
// add extra certificates
|
||||
if let Some(certs_dir) = certs_dir {
|
||||
@@ -64,7 +68,7 @@ pub fn configure_tls(
|
||||
let key_path = path.join("tls.key");
|
||||
let cert_path = path.join("tls.crt");
|
||||
if key_path.exists() && cert_path.exists() {
|
||||
cert_resolver.add_cert(
|
||||
cert_resolver.add_cert_path(
|
||||
&key_path.to_string_lossy(),
|
||||
&cert_path.to_string_lossy(),
|
||||
false,
|
||||
@@ -76,35 +80,97 @@ pub fn configure_tls(
|
||||
|
||||
let common_names = cert_resolver.get_common_names();
|
||||
|
||||
let cert_resolver = Arc::new(cert_resolver);
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_default_cipher_suites()
|
||||
.with_safe_default_kx_groups()
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(Arc::new(cert_resolver))
|
||||
.with_cert_resolver(cert_resolver.clone())
|
||||
.into();
|
||||
|
||||
Ok(TlsConfig {
|
||||
config,
|
||||
common_names: Some(common_names),
|
||||
cert_resolver,
|
||||
})
|
||||
}
|
||||
|
||||
struct CertResolver {
|
||||
certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
|
||||
default: Option<Arc<rustls::sign::CertifiedKey>>,
|
||||
/// Channel binding parameter
|
||||
///
|
||||
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
|
||||
/// Description: The hash of the TLS server's certificate as it
|
||||
/// appears, octet for octet, in the server's Certificate message. Note
|
||||
/// that the Certificate message contains a certificate_list, in which
|
||||
/// the first element is the server's certificate.
|
||||
///
|
||||
/// The hash function is to be selected as follows:
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function and that hash function neither MD5 nor SHA-1, then use
|
||||
/// the hash function associated with the certificate's
|
||||
/// signatureAlgorithm;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses no hash functions or
|
||||
/// uses multiple hash functions, then this channel binding type's
|
||||
/// channel bindings are undefined at this time (updates to is channel
|
||||
/// binding type may occur to address this issue if it ever arises).
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum TlsServerEndPoint {
|
||||
Sha256([u8; 32]),
|
||||
Undefined,
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
certs: HashMap::new(),
|
||||
default: None,
|
||||
impl TlsServerEndPoint {
|
||||
pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
|
||||
let sha256_oids = [
|
||||
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
|
||||
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
|
||||
oid_registry::OID_PKCS1_SHA256WITHRSA,
|
||||
];
|
||||
|
||||
let pem = x509_parser::parse_x509_certificate(&cert.0)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
info!(subject = %pem.subject, "parsing TLS certificate");
|
||||
|
||||
let reg = oid_registry::OidRegistry::default().with_all_crypto();
|
||||
let oid = pem.signature_algorithm.oid();
|
||||
let alg = reg.get(oid);
|
||||
if sha256_oids.contains(oid) {
|
||||
let tls_server_end_point: [u8; 32] =
|
||||
Sha256::new().chain_update(&cert.0).finalize().into();
|
||||
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
|
||||
Ok(Self::Sha256(tls_server_end_point))
|
||||
} else {
|
||||
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
|
||||
Ok(Self::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
fn add_cert(
|
||||
pub fn supported(&self) -> bool {
|
||||
!matches!(self, TlsServerEndPoint::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CertResolver {
|
||||
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn add_cert_path(
|
||||
&mut self,
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
@@ -120,57 +186,65 @@ impl CertResolver {
|
||||
keys.pop().map(rustls::PrivateKey).unwrap()
|
||||
};
|
||||
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.context(format!(
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
)
|
||||
})?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
};
|
||||
|
||||
let common_name = {
|
||||
let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
|
||||
.context(format!(
|
||||
"Failed to parse PEM object from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
.1;
|
||||
let common_name = pem.parse_x509()?.subject().to_string();
|
||||
self.add_cert(priv_key, cert_chain, is_default)
|
||||
}
|
||||
|
||||
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
|
||||
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
|
||||
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
|
||||
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
|
||||
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
|
||||
// of cutting off '*.' parts.
|
||||
if common_name.starts_with("CN=*.") {
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
} else {
|
||||
common_name.strip_prefix("CN=").map(|s| s.to_string())
|
||||
}
|
||||
pub fn add_cert(
|
||||
&mut self,
|
||||
priv_key: PrivateKey,
|
||||
cert_chain: Vec<Certificate>,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let first_cert = &cert_chain[0];
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
let pem = x509_parser::parse_x509_certificate(&first_cert.0)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
let common_name = pem.subject().to_string();
|
||||
|
||||
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
|
||||
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
|
||||
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
|
||||
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
|
||||
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
|
||||
// of cutting off '*.' parts.
|
||||
let common_name = if common_name.starts_with("CN=*.") {
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
} else {
|
||||
common_name.strip_prefix("CN=").map(|s| s.to_string())
|
||||
}
|
||||
.context(format!(
|
||||
"Failed to parse common name from certificate at '{cert_path}'."
|
||||
))?;
|
||||
.context("Failed to parse common name from certificate")?;
|
||||
|
||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||
|
||||
if is_default {
|
||||
self.default = Some(cert.clone());
|
||||
self.default = Some((cert.clone(), tls_server_end_point));
|
||||
}
|
||||
|
||||
self.certs.insert(common_name, cert);
|
||||
self.certs.insert(common_name, (cert, tls_server_end_point));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_common_names(&self) -> HashSet<String> {
|
||||
pub fn get_common_names(&self) -> HashSet<String> {
|
||||
self.certs.keys().map(|s| s.to_string()).collect()
|
||||
}
|
||||
}
|
||||
@@ -178,15 +252,24 @@ impl CertResolver {
|
||||
impl rustls::server::ResolvesServerCert for CertResolver {
|
||||
fn resolve(
|
||||
&self,
|
||||
_client_hello: rustls::server::ClientHello,
|
||||
client_hello: rustls::server::ClientHello,
|
||||
) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
self.resolve(client_hello.server_name()).map(|x| x.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn resolve(
|
||||
&self,
|
||||
server_name: Option<&str>,
|
||||
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
|
||||
// loop here and cut off more and more subdomains until we find
|
||||
// a match to get a proper wildcard support. OTOH, we now do not
|
||||
// use nested domains, so keep this simple for now.
|
||||
//
|
||||
// With the current coding foo.com will match *.foo.com and that
|
||||
// repeats behavior of the old code.
|
||||
if let Some(mut sni_name) = _client_hello.server_name() {
|
||||
if let Some(mut sni_name) = server_name {
|
||||
loop {
|
||||
if let Some(cert) = self.certs.get(sni_name) {
|
||||
return Some(cert.clone());
|
||||
|
||||
@@ -470,7 +470,17 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?);
|
||||
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
|
||||
|
||||
let (_, tls_server_end_point) = tls
|
||||
.cert_resolver
|
||||
.resolve(tls_stream.get_ref().1.server_name())
|
||||
.context("missing certificate")?;
|
||||
|
||||
stream = PqStream::new(Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => bail!(ERR_PROTO_VIOLATION),
|
||||
@@ -875,7 +885,7 @@ pub async fn proxy_pass(
|
||||
/// Thin connection context.
|
||||
struct Client<'a, S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
stream: PqStream<S>,
|
||||
stream: PqStream<Stream<S>>,
|
||||
/// Client credentials that we care about.
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
@@ -889,7 +899,7 @@ struct Client<'a, S> {
|
||||
impl<'a, S> Client<'a, S> {
|
||||
/// Construct a new connection context.
|
||||
fn new(
|
||||
stream: PqStream<S>,
|
||||
stream: PqStream<Stream<S>>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
params: &'a StartupMessageParams,
|
||||
session_id: uuid::Uuid,
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
//! A group of high-level tests for connection establishing logic and auth.
|
||||
//!
|
||||
|
||||
mod mitm;
|
||||
|
||||
use super::*;
|
||||
use crate::auth::backend::TestBackend;
|
||||
use crate::auth::ClientCredentials;
|
||||
use crate::config::CertResolver;
|
||||
use crate::console::{CachedNodeInfo, NodeInfo};
|
||||
use crate::{auth, http, sasl, scram};
|
||||
use async_trait::async_trait;
|
||||
use rstest::rstest;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
|
||||
|
||||
/// Generate a set of TLS certificates: CA + server.
|
||||
fn generate_certs(
|
||||
hostname: &str,
|
||||
common_name: &str,
|
||||
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
|
||||
let ca = rcgen::Certificate::from_params({
|
||||
let mut params = rcgen::CertificateParams::default();
|
||||
@@ -21,7 +25,15 @@ fn generate_certs(
|
||||
params
|
||||
})?;
|
||||
|
||||
let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?;
|
||||
let cert = rcgen::Certificate::from_params({
|
||||
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
|
||||
params.distinguished_name = rcgen::DistinguishedName::new();
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::CommonName, common_name);
|
||||
params
|
||||
})?;
|
||||
|
||||
Ok((
|
||||
rustls::Certificate(ca.serialize_der()?),
|
||||
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
|
||||
@@ -37,7 +49,14 @@ struct ClientConfig<'a> {
|
||||
impl ClientConfig<'_> {
|
||||
fn make_tls_connect<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
self,
|
||||
) -> anyhow::Result<impl tokio_postgres::tls::TlsConnect<S>> {
|
||||
) -> anyhow::Result<
|
||||
impl tokio_postgres::tls::TlsConnect<
|
||||
S,
|
||||
Error = impl std::fmt::Debug,
|
||||
Future = impl Send,
|
||||
Stream = RustlsStream<S>,
|
||||
>,
|
||||
> {
|
||||
let mut mk = MakeRustlsConnect::new(self.config);
|
||||
let tls = MakeTlsConnect::<S>::make_tls_connect(&mut mk, self.hostname)?;
|
||||
Ok(tls)
|
||||
@@ -49,20 +68,24 @@ fn generate_tls_config<'a>(
|
||||
hostname: &'a str,
|
||||
common_name: &'a str,
|
||||
) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
|
||||
let (ca, cert, key) = generate_certs(hostname)?;
|
||||
let (ca, cert, key) = generate_certs(hostname, common_name)?;
|
||||
|
||||
let tls_config = {
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![cert], key)?
|
||||
.with_single_cert(vec![cert.clone()], key.clone())?
|
||||
.into();
|
||||
|
||||
let common_names = Some([common_name.to_owned()].iter().cloned().collect());
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
cert_resolver.add_cert(key, vec![cert], true)?;
|
||||
|
||||
let common_names = Some(cert_resolver.get_common_names());
|
||||
|
||||
TlsConfig {
|
||||
config,
|
||||
common_names,
|
||||
cert_resolver: Arc::new(cert_resolver),
|
||||
}
|
||||
};
|
||||
|
||||
@@ -253,6 +276,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Require)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(password)
|
||||
@@ -263,6 +287,30 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new("password")?,
|
||||
));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn scram_auth_mock() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
257
proxy/src/proxy/tests/mitm.rs
Normal file
257
proxy/src/proxy/tests/mitm.rs
Normal file
@@ -0,0 +1,257 @@
|
||||
//! Man-in-the-middle tests
|
||||
//!
|
||||
//! Channel binding should prevent a proxy server
|
||||
//! - that has access to create valid certificates -
|
||||
//! from controlling the TLS connection.
|
||||
|
||||
use std::fmt::Debug;
|
||||
|
||||
use super::*;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use postgres_protocol::message::frontend;
|
||||
use tokio::io::{AsyncReadExt, DuplexStream};
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::TlsConnect;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
enum Intercept {
|
||||
None,
|
||||
Methods,
|
||||
SASLResponse,
|
||||
}
|
||||
|
||||
async fn proxy_mitm(
|
||||
intercept: Intercept,
|
||||
) -> (DuplexStream, DuplexStream, ClientConfig<'static>, TlsConfig) {
|
||||
let (end_server1, client1) = tokio::io::duplex(1024);
|
||||
let (server2, end_client2) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config1, server_config1) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
|
||||
let (client_config2, server_config2) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// begin handshake with end_server
|
||||
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
|
||||
// process handshake with end_client
|
||||
let (end_client, startup) =
|
||||
handshake(client1, Some(&server_config1), &CancelMap::default())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
|
||||
let (end_client, buf) = end_client.framed.into_inner();
|
||||
assert!(buf.is_empty());
|
||||
let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
|
||||
|
||||
// give the end_server the startup parameters
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(startup.iter(), &mut buf).unwrap();
|
||||
end_server.send(buf.freeze()).await.unwrap();
|
||||
|
||||
// proxy messages between end_client and end_server
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = end_server.next() => {
|
||||
match message {
|
||||
Some(Ok(message)) => {
|
||||
// intercept SASL and return only SCRAM-SHA-256 ;)
|
||||
if matches!(intercept, Intercept::Methods) && message.starts_with(b"R") && message[5..].starts_with(&[0,0,0,10]) {
|
||||
end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap();
|
||||
continue;
|
||||
}
|
||||
end_client.send(message).await.unwrap()
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
message = end_client.next() => {
|
||||
match message {
|
||||
Some(Ok(message)) => {
|
||||
// intercept SASL response and return SCRAM-SHA-256 with no channel binding ;)
|
||||
if matches!(intercept, Intercept::SASLResponse) && message.starts_with(b"p") && message[5..].starts_with(b"SCRAM-SHA-256-PLUS\0") {
|
||||
let sasl_message = &message[1+4+19+4..];
|
||||
let mut new_message = b"n,,".to_vec();
|
||||
new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
|
||||
|
||||
end_server.send(buf.freeze()).await.unwrap();
|
||||
continue;
|
||||
}
|
||||
end_server.send(message).await.unwrap()
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
else => { break }
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(end_server1, end_client2, client_config1, server_config2)
|
||||
}
|
||||
|
||||
/// taken from tokio-postgres
|
||||
pub async fn connect_tls<S, T>(mut stream: S, tls: T) -> T::Stream
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsConnect<S>,
|
||||
T::Error: Debug,
|
||||
{
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::ssl_request(&mut buf);
|
||||
stream.write_all(&buf).await.unwrap();
|
||||
|
||||
let mut buf = [0];
|
||||
stream.read_exact(&mut buf).await.unwrap();
|
||||
|
||||
if buf[0] != b'S' {
|
||||
panic!("ssl not supported by server");
|
||||
}
|
||||
|
||||
tls.connect(stream).await.unwrap()
|
||||
}
|
||||
|
||||
struct PgFrame;
|
||||
impl Decoder for PgFrame {
|
||||
type Item = Bytes;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
if src.len() < 5 {
|
||||
src.reserve(5 - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
|
||||
if src.len() < len {
|
||||
src.reserve(len - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(src.split_to(len).freeze()))
|
||||
}
|
||||
}
|
||||
impl Encoder<Bytes> for PgFrame {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
dst.extend_from_slice(&item);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// If the client doesn't support channel bindings, it can be exploited.
|
||||
#[tokio::test]
|
||||
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new("password")?,
|
||||
));
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
/// If the client chooses SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::None,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the MITM pretends like SCRAM-PLUS isn't available, but the client supports it, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::Methods,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the MITM pretends like the client doesn't support channel bindings, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::SASLResponse,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the client chooses SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::None,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::Methods,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::SASLResponse,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn connect_failure(
|
||||
intercept: Intercept,
|
||||
channel_binding: tokio_postgres::config::ChannelBinding,
|
||||
) -> anyhow::Result<()> {
|
||||
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new("password")?,
|
||||
));
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.channel_binding(channel_binding)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await
|
||||
.err()
|
||||
.context("client shouldn't be able to connect")?;
|
||||
|
||||
let _server_err = proxy
|
||||
.await?
|
||||
.err()
|
||||
.context("server shouldn't accept client")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -36,9 +36,9 @@ impl<'a> ChannelBinding<&'a str> {
|
||||
|
||||
impl<T: std::fmt::Display> ChannelBinding<T> {
|
||||
/// Encode channel binding data as base64 for subsequent checks.
|
||||
pub fn encode<E>(
|
||||
pub fn encode<'a, E>(
|
||||
&self,
|
||||
get_cbind_data: impl FnOnce(&T) -> Result<String, E>,
|
||||
get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
|
||||
) -> Result<std::borrow::Cow<'static, str>, E> {
|
||||
use ChannelBinding::*;
|
||||
Ok(match self {
|
||||
@@ -51,12 +51,11 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
|
||||
"eSws".into()
|
||||
}
|
||||
Required(mode) => {
|
||||
let msg = format!(
|
||||
"p={mode},,{data}",
|
||||
mode = mode,
|
||||
data = get_cbind_data(mode)?
|
||||
);
|
||||
base64::encode(msg).into()
|
||||
use std::io::Write;
|
||||
let mut cbind_input = vec![];
|
||||
write!(&mut cbind_input, "p={mode},,",).unwrap();
|
||||
cbind_input.extend_from_slice(get_cbind_data(mode)?);
|
||||
base64::encode(&cbind_input).into()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -77,7 +76,7 @@ mod tests {
|
||||
];
|
||||
|
||||
for (cb, input) in cases {
|
||||
assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input);
|
||||
assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -22,9 +22,12 @@ pub use secret::ServerSecret;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
// TODO: add SCRAM-SHA-256-PLUS
|
||||
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
||||
|
||||
/// A list of supported SCRAM methods.
|
||||
pub const METHODS: &[&str] = &["SCRAM-SHA-256"];
|
||||
pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
|
||||
pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
|
||||
|
||||
/// Decode base64 into array without any heap allocations
|
||||
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
|
||||
@@ -80,7 +83,11 @@ mod tests {
|
||||
const NONCE: [u8; 18] = [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
|
||||
];
|
||||
let mut exchange = Exchange::new(&secret, || NONCE, None);
|
||||
let mut exchange = Exchange::new(
|
||||
&secret,
|
||||
|| NONCE,
|
||||
crate::config::TlsServerEndPoint::Undefined,
|
||||
);
|
||||
|
||||
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
|
||||
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
|
||||
|
||||
@@ -5,9 +5,11 @@ use super::messages::{
|
||||
};
|
||||
use super::secret::ServerSecret;
|
||||
use super::signature::SignatureBuilder;
|
||||
use crate::config;
|
||||
use crate::sasl::{self, ChannelBinding, Error as SaslError};
|
||||
|
||||
/// The only channel binding mode we currently support.
|
||||
#[derive(Debug)]
|
||||
struct TlsServerEndPoint;
|
||||
|
||||
impl std::fmt::Display for TlsServerEndPoint {
|
||||
@@ -43,20 +45,20 @@ pub struct Exchange<'a> {
|
||||
state: ExchangeState,
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
cert_digest: Option<&'a [u8]>,
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
}
|
||||
|
||||
impl<'a> Exchange<'a> {
|
||||
pub fn new(
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
cert_digest: Option<&'a [u8]>,
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: ExchangeState::Initial,
|
||||
secret,
|
||||
nonce,
|
||||
cert_digest,
|
||||
tls_server_end_point,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -71,6 +73,14 @@ impl sasl::Mechanism for Exchange<'_> {
|
||||
let client_first_message = ClientFirstMessage::parse(input)
|
||||
.ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
|
||||
|
||||
// If the flag is set to "y" and the server supports channel
|
||||
// binding, the server MUST fail authentication
|
||||
if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
|
||||
&& self.tls_server_end_point.supported()
|
||||
{
|
||||
return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
|
||||
}
|
||||
|
||||
let server_first_message = client_first_message.build_server_first_message(
|
||||
&(self.nonce)(),
|
||||
&self.secret.salt_base64,
|
||||
@@ -94,10 +104,11 @@ impl sasl::Mechanism for Exchange<'_> {
|
||||
let client_final_message = ClientFinalMessage::parse(input)
|
||||
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
|
||||
|
||||
let channel_binding = cbind_flag.encode(|_| {
|
||||
self.cert_digest
|
||||
.map(base64::encode)
|
||||
.ok_or(SaslError::ChannelBindingFailed("no cert digest provided"))
|
||||
let channel_binding = cbind_flag.encode(|_| match &self.tls_server_end_point {
|
||||
config::TlsServerEndPoint::Sha256(x) => Ok(x),
|
||||
config::TlsServerEndPoint::Undefined => {
|
||||
Err(SaslError::ChannelBindingFailed("no cert digest provided"))
|
||||
}
|
||||
})?;
|
||||
|
||||
// This might've been caused by a MITM attack
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use crate::config::TlsServerEndPoint;
|
||||
use crate::error::UserFacingError;
|
||||
use anyhow::bail;
|
||||
use bytes::BytesMut;
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use pq_proto::framed::{ConnectionError, Framed};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use rustls::ServerConfig;
|
||||
@@ -17,7 +18,7 @@ use tokio_rustls::server::TlsStream;
|
||||
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
framed: Framed<S>,
|
||||
pub(crate) framed: Framed<S>,
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
@@ -118,19 +119,21 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// Wrapper for upgrading raw streams into secure streams.
|
||||
/// NOTE: it should be possible to decompose this object as necessary.
|
||||
#[project = StreamProj]
|
||||
pub enum Stream<S> {
|
||||
/// We always begin with a raw stream,
|
||||
/// which may then be upgraded into a secure stream.
|
||||
Raw { #[pin] raw: S },
|
||||
/// Wrapper for upgrading raw streams into secure streams.
|
||||
pub enum Stream<S> {
|
||||
/// We always begin with a raw stream,
|
||||
/// which may then be upgraded into a secure stream.
|
||||
Raw { raw: S },
|
||||
Tls {
|
||||
/// We box [`TlsStream`] since it can be quite large.
|
||||
Tls { #[pin] tls: Box<TlsStream<S>> },
|
||||
}
|
||||
tls: Box<TlsStream<S>>,
|
||||
/// Channel binding parameter
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
},
|
||||
}
|
||||
|
||||
impl<S: Unpin> Unpin for Stream<S> {}
|
||||
|
||||
impl<S> Stream<S> {
|
||||
/// Construct a new instance from a raw stream.
|
||||
pub fn from_raw(raw: S) -> Self {
|
||||
@@ -141,7 +144,17 @@ impl<S> Stream<S> {
|
||||
pub fn sni_hostname(&self) -> Option<&str> {
|
||||
match self {
|
||||
Stream::Raw { .. } => None,
|
||||
Stream::Tls { tls } => tls.get_ref().1.server_name(),
|
||||
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
|
||||
match self {
|
||||
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
|
||||
Stream::Tls {
|
||||
tls_server_end_point,
|
||||
..
|
||||
} => *tls_server_end_point,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,12 +171,9 @@ pub enum StreamUpgradeError {
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
|
||||
/// If possible, upgrade raw stream into a secure TLS-based stream.
|
||||
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<Self, StreamUpgradeError> {
|
||||
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
|
||||
match self {
|
||||
Stream::Raw { raw } => {
|
||||
let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
|
||||
Ok(Stream::Tls { tls })
|
||||
}
|
||||
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?),
|
||||
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
|
||||
}
|
||||
}
|
||||
@@ -171,50 +181,46 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_read(context, buf),
|
||||
Tls { tls } => tls.poll_read(context, buf),
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_write(context, buf),
|
||||
Tls { tls } => tls.poll_write(context, buf),
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_flush(context),
|
||||
Tls { tls } => tls.poll_flush(context),
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_shutdown(context),
|
||||
Tls { tls } => tls.poll_shutdown(context),
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1572,7 +1572,7 @@ class NeonAttachmentService:
|
||||
self.running = False
|
||||
return self
|
||||
|
||||
def attach_hook(self, tenant_id: TenantId, pageserver_id: int) -> int:
|
||||
def attach_hook_issue(self, tenant_id: TenantId, pageserver_id: int) -> int:
|
||||
response = requests.post(
|
||||
f"{self.env.control_plane_api}/attach-hook",
|
||||
json={"tenant_id": str(tenant_id), "node_id": pageserver_id},
|
||||
@@ -1582,6 +1582,13 @@ class NeonAttachmentService:
|
||||
assert isinstance(gen, int)
|
||||
return gen
|
||||
|
||||
def attach_hook_drop(self, tenant_id: TenantId):
|
||||
response = requests.post(
|
||||
f"{self.env.control_plane_api}/attach-hook",
|
||||
json={"tenant_id": str(tenant_id), "node_id": None},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
def __enter__(self) -> "NeonAttachmentService":
|
||||
return self
|
||||
|
||||
@@ -1781,13 +1788,20 @@ class NeonPageserver(PgProtocol):
|
||||
to call into the pageserver HTTP client.
|
||||
"""
|
||||
if self.env.attachment_service is not None:
|
||||
generation = self.env.attachment_service.attach_hook(tenant_id, self.id)
|
||||
generation = self.env.attachment_service.attach_hook_issue(tenant_id, self.id)
|
||||
else:
|
||||
generation = None
|
||||
|
||||
client = self.http_client()
|
||||
return client.tenant_attach(tenant_id, config, config_null, generation=generation)
|
||||
|
||||
def tenant_detach(self, tenant_id: TenantId):
|
||||
if self.env.attachment_service is not None:
|
||||
self.env.attachment_service.attach_hook_drop(tenant_id)
|
||||
|
||||
client = self.http_client()
|
||||
return client.tenant_detach(tenant_id)
|
||||
|
||||
|
||||
def append_pageserver_param_overrides(
|
||||
params_to_update: List[str],
|
||||
@@ -2626,6 +2640,11 @@ class Endpoint(PgProtocol):
|
||||
|
||||
return self
|
||||
|
||||
def get_metrics_str(self) -> str:
|
||||
request_result = requests.get(f"http://localhost:{self.http_port}/metrics")
|
||||
request_result.raise_for_status()
|
||||
return request_result.text
|
||||
|
||||
def __enter__(self) -> "Endpoint":
|
||||
return self
|
||||
|
||||
|
||||
@@ -6,11 +6,15 @@ def test_build_info_metric(neon_env_builder: NeonEnvBuilder, link_proxy: NeonPro
|
||||
neon_env_builder.num_safekeepers = 1
|
||||
env = neon_env_builder.init_start()
|
||||
|
||||
env.neon_cli.create_branch("test_build_info_metric")
|
||||
endpoint = env.endpoints.create_start("test_build_info_metric")
|
||||
|
||||
parsed_metrics = {}
|
||||
|
||||
parsed_metrics["pageserver"] = parse_metrics(env.pageserver.http_client().get_metrics_str())
|
||||
parsed_metrics["safekeeper"] = parse_metrics(env.safekeepers[0].http_client().get_metrics_str())
|
||||
parsed_metrics["proxy"] = parse_metrics(link_proxy.get_metrics())
|
||||
parsed_metrics["compute_ctl"] = parse_metrics(endpoint.get_metrics_str())
|
||||
|
||||
for _component, metrics in parsed_metrics.items():
|
||||
sample = metrics.query_one("libmetrics_build_info")
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import closing
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from fixtures.log_helper import log
|
||||
@@ -14,62 +15,33 @@ from werkzeug.wrappers.request import Request
|
||||
from werkzeug.wrappers.response import Response
|
||||
|
||||
|
||||
# Check that the extension is not already in the share_dir_path_ext
|
||||
# if it is, skip the test
|
||||
#
|
||||
# After the test is done, cleanup the control file and the extension directory
|
||||
# use neon_env_builder_local fixture to override the default neon_env_builder fixture
|
||||
# and use a test-specific pg_install instead of shared one
|
||||
@pytest.fixture(scope="function")
|
||||
def ext_file_cleanup(pg_bin):
|
||||
out = pg_bin.run_capture("pg_config --sharedir".split())
|
||||
share_dir_path = Path(f"{out}.stdout").read_text().strip()
|
||||
log.info(f"share_dir_path: {share_dir_path}")
|
||||
share_dir_path_ext = os.path.join(share_dir_path, "extension")
|
||||
def neon_env_builder_local(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
test_output_dir: Path,
|
||||
pg_distrib_dir: Path,
|
||||
pg_version: PgVersion,
|
||||
) -> NeonEnvBuilder:
|
||||
test_local_pginstall = test_output_dir / "pg_install"
|
||||
log.info(f"copy {pg_distrib_dir} to {test_local_pginstall}")
|
||||
shutil.copytree(
|
||||
pg_distrib_dir / pg_version.v_prefixed, test_local_pginstall / pg_version.v_prefixed
|
||||
)
|
||||
|
||||
log.info(f"share_dir_path_ext: {share_dir_path_ext}")
|
||||
neon_env_builder.pg_distrib_dir = test_local_pginstall
|
||||
log.info(f"local neon_env_builder.pg_distrib_dir: {neon_env_builder.pg_distrib_dir}")
|
||||
|
||||
# if file is already in the share_dir_path_ext, skip the test
|
||||
if os.path.isfile(os.path.join(share_dir_path_ext, "anon.control")):
|
||||
log.info("anon.control is already in the share_dir_path_ext, skipping the test")
|
||||
yield False
|
||||
return
|
||||
else:
|
||||
yield True
|
||||
|
||||
# cleanup the control file
|
||||
if os.path.isfile(os.path.join(share_dir_path_ext, "anon.control")):
|
||||
os.unlink(os.path.join(share_dir_path_ext, "anon.control"))
|
||||
log.info("anon.control was removed from the share_dir_path_ext")
|
||||
|
||||
# remove the extension directory recursively
|
||||
if os.path.isdir(os.path.join(share_dir_path_ext, "anon")):
|
||||
directories_to_clean: List[Path] = []
|
||||
for f in Path(os.path.join(share_dir_path_ext, "anon")).iterdir():
|
||||
if f.is_file():
|
||||
log.info(f"Removing file {f}")
|
||||
f.unlink()
|
||||
elif f.is_dir():
|
||||
directories_to_clean.append(f)
|
||||
|
||||
for directory_to_clean in reversed(directories_to_clean):
|
||||
if not os.listdir(directory_to_clean):
|
||||
log.info(f"Removing empty directory {directory_to_clean}")
|
||||
directory_to_clean.rmdir()
|
||||
|
||||
os.rmdir(os.path.join(share_dir_path_ext, "anon"))
|
||||
log.info("anon directory was removed from the share_dir_path_ext")
|
||||
return neon_env_builder
|
||||
|
||||
|
||||
def test_remote_extensions(
|
||||
httpserver: HTTPServer,
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
neon_env_builder_local: NeonEnvBuilder,
|
||||
httpserver_listen_address,
|
||||
pg_version,
|
||||
ext_file_cleanup,
|
||||
):
|
||||
if ext_file_cleanup is False:
|
||||
log.info("test_remote_extensions skipped")
|
||||
return
|
||||
|
||||
if pg_version == PgVersion.V16:
|
||||
pytest.skip("TODO: PG16 extension building")
|
||||
|
||||
@@ -79,7 +51,8 @@ def test_remote_extensions(
|
||||
(host, port) = httpserver_listen_address
|
||||
extensions_endpoint = f"http://{host}:{port}/pg-ext-s3-gateway"
|
||||
|
||||
archive_path = f"latest/v{pg_version}/extensions/anon.tar.zst"
|
||||
build_tag = os.environ.get("BUILD_TAG", "latest")
|
||||
archive_path = f"{build_tag}/v{pg_version}/extensions/anon.tar.zst"
|
||||
|
||||
def endpoint_handler_build_tag(request: Request) -> Response:
|
||||
log.info(f"request: {request}")
|
||||
@@ -88,6 +61,7 @@ def test_remote_extensions(
|
||||
file_path = f"test_runner/regress/data/extension_test/5670669815/v{pg_version}/extensions/anon.tar.zst"
|
||||
file_size = os.path.getsize(file_path)
|
||||
fh = open(file_path, "rb")
|
||||
|
||||
return Response(
|
||||
fh,
|
||||
mimetype="application/octet-stream",
|
||||
@@ -104,12 +78,10 @@ def test_remote_extensions(
|
||||
|
||||
# Start a compute node with remote_extension spec
|
||||
# and check that it can download the extensions and use them to CREATE EXTENSION.
|
||||
env = neon_env_builder.init_start()
|
||||
tenant_id, _ = env.neon_cli.create_tenant()
|
||||
env.neon_cli.create_timeline("test_remote_extensions", tenant_id=tenant_id)
|
||||
env = neon_env_builder_local.init_start()
|
||||
env.neon_cli.create_branch("test_remote_extensions")
|
||||
endpoint = env.endpoints.create(
|
||||
"test_remote_extensions",
|
||||
tenant_id=tenant_id,
|
||||
config_lines=["log_min_messages=debug3"],
|
||||
)
|
||||
|
||||
|
||||
@@ -282,7 +282,7 @@ def test_deferred_deletion(neon_env_builder: NeonEnvBuilder):
|
||||
|
||||
# Now advance the generation in the control plane: subsequent validations
|
||||
# from the running pageserver will fail. No more deletions should happen.
|
||||
env.attachment_service.attach_hook(env.initial_tenant, some_other_pageserver)
|
||||
env.attachment_service.attach_hook_issue(env.initial_tenant, some_other_pageserver)
|
||||
generate_uploads_and_deletions(env, init=False)
|
||||
|
||||
assert_deletion_queue(ps_http, lambda n: n > 0)
|
||||
@@ -397,7 +397,7 @@ def test_deletion_queue_recovery(
|
||||
if keep_attachment == KeepAttachment.LOSE:
|
||||
some_other_pageserver = 101010
|
||||
assert env.attachment_service is not None
|
||||
env.attachment_service.attach_hook(env.initial_tenant, some_other_pageserver)
|
||||
env.attachment_service.attach_hook_issue(env.initial_tenant, some_other_pageserver)
|
||||
|
||||
env.pageserver.start()
|
||||
|
||||
|
||||
@@ -336,10 +336,15 @@ def test_live_reconfig_get_evictions_low_residence_duration_metric_threshold(
|
||||
):
|
||||
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
|
||||
|
||||
env = neon_env_builder.init_start()
|
||||
env = neon_env_builder.init_start(
|
||||
initial_tenant_conf={
|
||||
# disable compaction so that it will not download the layer for repartitioning
|
||||
"compaction_period": "0s"
|
||||
}
|
||||
)
|
||||
assert isinstance(env.pageserver_remote_storage, LocalFsStorage)
|
||||
|
||||
(tenant_id, timeline_id) = env.neon_cli.create_tenant()
|
||||
(tenant_id, timeline_id) = env.initial_tenant, env.initial_timeline
|
||||
ps_http = env.pageserver.http_client()
|
||||
|
||||
def get_metric():
|
||||
|
||||
Reference in New Issue
Block a user