Compare commits


10 Commits

Author SHA1 Message Date
Anastasia Lubennikova
a769a43767 Add build_info_metric with git_version and build_tag for compute_ctl 2023-11-28 17:23:24 +00:00
Christian Schwarz
286f34dfce test suite: add method for generation-aware detachment of a tenant (#5939)
Part of getpage@lsn benchmark epic:
https://github.com/neondatabase/neon/issues/5771
2023-11-28 09:51:37 +00:00
Sasha Krassovsky
f290b27378 Fix check for if shmem is valid to take into account detached shmem (#5937)
## Problem
We can segfault if we update the connstr inside a process that has
detached from shmem (e.g., inside the stats collector).
## Summary of changes
Add a check to make sure we're not detached
2023-11-28 03:14:42 +00:00
Sasha Krassovsky
4cd18fcebd Compile wal2json (#5893)
Add wal2json extension
2023-11-27 18:17:26 -08:00
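For background, wal2json is a logical decoding output plugin: it turns decoded WAL changes into JSON documents that can be consumed over a replication slot. With its default format (version 1), an INSERT comes out roughly like this (table and column names are illustrative):

{"change":[{"kind":"insert","schema":"public","table":"t","columnnames":["id"],"columntypes":["integer"],"columnvalues":[1]}]}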
Anastasia Lubennikova
4c29e0594e Update neon extension relocatable for existing installations (#5943) 2023-11-27 23:29:24 +00:00
Anastasia Lubennikova
3c56a4dd18 Make neon extension relocatable to allow SET SCHEMA (#5942) 2023-11-27 21:45:41 +00:00
Conrad Ludgate
316309c85b channel binding (#5683)
## Problem

Channel binding protects SCRAM from sophisticated MITM attacks where the
attacker is able to produce 'valid' TLS certificates.

## Summary of changes

Get the tls-server-end-point channel binding data and verify it is correct
in the SCRAM-SHA-256-PLUS authentication flow.
2023-11-27 21:45:15 +00:00
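The tls-server-end-point binding ties the SCRAM exchange to the certificate the client actually saw: both ends hash the server's DER-encoded certificate and mix that hash into the SASL channel-binding data, so a MITM that terminates TLS with its own 'valid' certificate produces a mismatching proof. A minimal sketch of the hash side, assuming a SHA-256-signed certificate and the sha2 crate (the helper name is hypothetical):

use sha2::{Digest, Sha256};

// RFC 5929 tls-server-end-point for the SHA-256 case: hash the raw
// DER-encoded certificate bytes.
fn tls_server_end_point(cert_der: &[u8]) -> [u8; 32] {
    Sha256::digest(cert_der).into()
}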
Arpad Müller
e09bb9974c bootstrap_timeline: rename initdb_path to pgdata_path (#5931)
This is a rename without functional changes, in preparation for #5912.

Split off from #5912 as per review request.
2023-11-27 20:14:39 +00:00
Anastasia Lubennikova
5289f341ce Use test specific directory in test_remote_extensions (#5938) 2023-11-27 18:57:58 +00:00
Joonas Koivunen
683ec2417c deflake: test_live_reconfig_get_evictions_low_residence_... (#5926)
- disable extra tenant
- disable compaction which could try to repartition while we assert

Split from #5108.
2023-11-27 15:20:54 +02:00
36 changed files with 779 additions and 582 deletions


@@ -404,7 +404,7 @@ jobs:
uses: ./.github/actions/save-coverage-data
regress-tests:
needs: [ check-permissions, build-neon ]
needs: [ check-permissions, build-neon, tag ]
runs-on: [ self-hosted, gen3, large ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
@@ -436,6 +436,7 @@ jobs:
env:
TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
CHECK_ONDISK_DATA_COMPATIBILITY: nonempty
BUILD_TAG: ${{ needs.tag.outputs.build-tag }}
- name: Merge and upload coverage data
if: matrix.build_type == 'debug' && matrix.pg_version == 'v14'

Cargo.lock (generated)

@@ -1133,7 +1133,9 @@ dependencies = [
"compute_api",
"flate2",
"futures",
"git-version",
"hyper",
"metrics",
"notify",
"num_cpus",
"opentelemetry",
@@ -2610,17 +2612,6 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nostarve_queue"
version = "0.1.0"
dependencies = [
"futures",
"rand 0.8.5",
"scopeguard",
"tokio",
"tracing",
]
[[package]]
name = "notify"
version = "5.2.0"
@@ -2962,7 +2953,6 @@ dependencies = [
"itertools",
"metrics",
"nix 0.26.2",
"nostarve_queue",
"num-traits",
"num_cpus",
"once_cell",
@@ -3517,6 +3507,7 @@ dependencies = [
"pbkdf2",
"pin-project-lite",
"postgres-native-tls",
"postgres-protocol",
"postgres_backend",
"pq_proto",
"prometheus",


@@ -27,7 +27,6 @@ members = [
"libs/postgres_ffi/wal_craft",
"libs/vm_monitor",
"libs/walproposer",
"libs/nostarve_queue",
]
[workspace.package]
@@ -38,7 +37,6 @@ license = "Apache-2.0"
[workspace.dependencies]
anyhow = { version = "1.0", features = ["backtrace"] }
arc-swap = "1.6"
async-channel = "1.9.0"
async-compression = { version = "0.4.0", features = ["tokio", "gzip", "zstd"] }
azure_core = "0.16"
azure_identity = "0.16"
@@ -193,7 +191,6 @@ tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
utils = { version = "0.1", path = "./libs/utils/" }
vm_monitor = { version = "0.1", path = "./libs/vm_monitor/" }
walproposer = { version = "0.1", path = "./libs/walproposer/" }
nostarve_queue = { path = "./libs/nostarve_queue" }
## Common library dependency
workspace_hack = { version = "0.1", path = "./workspace_hack/" }


@@ -714,6 +714,24 @@ RUN wget https://github.com/pksunkara/pgx_ulid/archive/refs/tags/v0.1.3.tar.gz -
cargo pgrx install --release && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/ulid.control
#########################################################################################
#
# Layer "wal2json-build"
# Compile "wal2json" extension
#
#########################################################################################
FROM build-deps AS wal2json-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ENV PATH "/usr/local/pgsql/bin/:$PATH"
RUN wget https://github.com/eulerto/wal2json/archive/refs/tags/wal2json_2_5.tar.gz && \
echo "b516653575541cf221b99cf3f8be9b6821f6dbcfc125675c85f35090f824f00e wal2json_2_5.tar.gz" | sha256sum --check && \
mkdir wal2json-src && cd wal2json-src && tar xvzf ../wal2json_2_5.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/wal2json.control
#########################################################################################
#
# Layer "neon-pg-ext-build"
@@ -750,6 +768,7 @@ COPY --from=rdkit-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-uuidv7-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-embedding-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=wal2json-pg-build /usr/local/pgsql /usr/local/pgsql
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) \


@@ -12,8 +12,10 @@ cfg-if.workspace = true
clap.workspace = true
flate2.workspace = true
futures.workspace = true
git-version.workspace = true
hyper = { workspace = true, features = ["full"] }
notify.workspace = true
metrics.workspace = true
num_cpus.workspace = true
opentelemetry.workspace = true
postgres.workspace = true


@@ -57,18 +57,27 @@ use compute_tools::logger::*;
use compute_tools::monitor::launch_monitor;
use compute_tools::params::*;
use compute_tools::spec::*;
use metrics::set_build_info_metric;
use utils::{project_build_tag, project_git_version};
// this is an arbitrary build tag. Fine as a default / for testing purposes
// in-case of not-set environment var
const BUILD_TAG_DEFAULT: &str = "latest";
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
fn main() -> Result<()> {
init_tracing_and_logging(DEFAULT_LOG_LEVEL)?;
let build_tag = option_env!("BUILD_TAG")
.unwrap_or(BUILD_TAG_DEFAULT)
.to_string();
info!("build_tag: {build_tag}");
info!("Version: {GIT_VERSION}");
info!("Build_tag: {BUILD_TAG}");
set_build_info_metric(GIT_VERSION, BUILD_TAG);
let matches = cli().get_matches();
let pgbin_default = String::from("postgres");
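A build-info metric is conventionally a gauge pinned at 1 whose labels carry the version strings, so dashboards can join other series against it. A sketch of what set_build_info_metric plausibly does, using the prometheus crate (metric and label names here are assumptions, not the real ones from the metrics crate):

use once_cell::sync::Lazy;
use prometheus::{register_int_gauge_vec, IntGaugeVec};

// Assumed names, for illustration only.
static BUILD_INFO: Lazy<IntGaugeVec> = Lazy::new(|| {
    register_int_gauge_vec!(
        "build_info",
        "Build/version information",
        &["revision", "build_tag"]
    )
    .expect("failed to register build_info gauge")
});

pub fn set_build_info_metric(revision: &str, build_tag: &str) {
    // Constant 1; the information lives in the labels.
    BUILD_INFO.with_label_values(&[revision, build_tag]).set(1);
}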


@@ -11,7 +11,8 @@ use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIErr
use anyhow::Result;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use hyper::{header::CONTENT_TYPE, Body, Method, Request, Response, Server, StatusCode};
use metrics::{Encoder, TextEncoder};
use num_cpus;
use serde_json;
use tokio::task;
@@ -51,6 +52,20 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
Response::new(Body::from(serde_json::to_string(&status_response).unwrap()))
}
// prometheus metrics
(&Method::GET, "/metrics") => {
let mut buffer = vec![];
let metrics = metrics::gather();
let encoder = TextEncoder::new();
encoder.encode(&metrics, &mut buffer).unwrap();
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, encoder.format_type())
.body(Body::from(buffer))
.unwrap()
}
// Startup metrics in JSON format. Keep /metrics reserved for a possible
// future use for Prometheus metrics format.
(&Method::GET, "/metrics.json") => {


@@ -687,6 +687,9 @@ pub fn handle_extension_neon(client: &mut Client) -> Result<()> {
info!("create neon extension with query: {}", query);
client.simple_query(query)?;
query = "UPDATE pg_extension SET extrelocatable = true WHERE extname = 'neon'";
client.simple_query(query)?;
query = "ALTER EXTENSION neon SET SCHEMA neon";
info!("alter neon extension schema with query: {}", query);
client.simple_query(query)?;


@@ -21,7 +21,7 @@ use pageserver_api::models::{
use pageserver_api::shard::TenantShardId;
use postgres_backend::AuthType;
use postgres_connection::{parse_host_port, PgConnectionConfig};
use reqwest::blocking::{Client, ClientBuilder, RequestBuilder, Response};
use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::{IntoUrl, Method};
use thiserror::Error;
use utils::auth::{Claims, Scope};
@@ -99,7 +99,7 @@ impl PageServerNode {
pg_connection_config: PgConnectionConfig::new_host_port(host, port),
conf: conf.clone(),
env: env.clone(),
http_client: ClientBuilder::new().timeout(None).build().unwrap(),
http_client: Client::new(),
http_base_url: format!("http://{}/v1", conf.listen_http_addr),
}
}


@@ -1,14 +0,0 @@
[package]
name = "nostarve_queue"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
scopeguard.workspace = true
tracing.workspace = true
[dev-dependencies]
futures.workspace = true
rand.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "time"] }


@@ -1,316 +0,0 @@
//! Synchronization primitive to prevent starvation among concurrent tasks that do the same work.
use std::{
collections::VecDeque,
fmt,
future::poll_fn,
sync::Mutex,
task::{Poll, Waker},
};
pub struct Queue<T> {
inner: Mutex<Inner<T>>,
}
struct Inner<T> {
waiters: VecDeque<usize>,
free: VecDeque<usize>,
slots: Vec<Option<(Option<Waker>, Option<T>)>>,
}
#[derive(Clone, Copy)]
pub struct Position<'q, T> {
idx: usize,
queue: &'q Queue<T>,
}
impl<T> fmt::Debug for Position<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Position").field("idx", &self.idx).finish()
}
}
impl<T> Inner<T> {
#[cfg(not(test))]
#[inline]
fn integrity_check(&self) {}
#[cfg(test)]
fn integrity_check(&self) {
use std::collections::HashSet;
let waiters = self.waiters.iter().copied().collect::<HashSet<_>>();
let free = self.free.iter().copied().collect::<HashSet<_>>();
for (slot_idx, slot) in self.slots.iter().enumerate() {
match slot {
None => {
assert!(!waiters.contains(&slot_idx));
assert!(free.contains(&slot_idx));
}
Some((None, None)) => {
assert!(waiters.contains(&slot_idx));
assert!(!free.contains(&slot_idx));
}
Some((Some(_), Some(_))) => {
assert!(!waiters.contains(&slot_idx));
assert!(!free.contains(&slot_idx));
}
Some((Some(_), None)) => {
assert!(waiters.contains(&slot_idx));
assert!(!free.contains(&slot_idx));
}
Some((None, Some(_))) => {
assert!(!waiters.contains(&slot_idx));
assert!(!free.contains(&slot_idx));
}
}
}
}
}
impl<T> Queue<T> {
pub fn new(size: usize) -> Self {
Queue {
inner: Mutex::new(Inner {
waiters: VecDeque::new(),
free: (0..size).collect(),
slots: {
let mut v = Vec::with_capacity(size);
v.resize_with(size, || None);
v
},
}),
}
}
pub fn begin(&self) -> Result<Position<T>, ()> {
#[cfg(test)]
tracing::trace!("get in line locking inner");
let mut inner = self.inner.lock().unwrap();
inner.integrity_check();
let my_waitslot_idx = inner
.free
.pop_front()
.expect("can't happen, len(slots) = len(waiters");
inner.waiters.push_back(my_waitslot_idx);
let prev = inner.slots[my_waitslot_idx].replace((None, None));
assert!(prev.is_none());
inner.integrity_check();
Ok(Position {
idx: my_waitslot_idx,
queue: &self,
})
}
}
impl<'q, T> Position<'q, T> {
pub fn complete_and_wait(self, datum: T) -> impl std::future::Future<Output = T> + 'q {
#[cfg(test)]
tracing::trace!("found victim locking waiters");
let mut inner = self.queue.inner.lock().unwrap();
inner.integrity_check();
let winner_idx = inner.waiters.pop_front().expect("we put ourselves in");
#[cfg(test)]
tracing::trace!(winner_idx, "putting victim into next waiters slot");
let winner_slot = inner.slots[winner_idx].as_mut().unwrap();
let prev = winner_slot.1.replace(datum);
assert!(
prev.is_none(),
"ensure we didn't mess up this simple ring buffer structure"
);
if let Some(waker) = winner_slot.0.take() {
#[cfg(test)]
tracing::trace!(winner_idx, "waking up winner");
waker.wake()
}
inner.integrity_check();
drop(inner); // the poll_fn locks it again
let mut poll_num = 0;
let mut drop_guard = Some(scopeguard::guard((), |()| {
panic!("must not drop this future until Ready");
}));
// take the victim that was found by someone else
poll_fn(move |cx| {
let my_waitslot_idx = self.idx;
poll_num += 1;
#[cfg(test)]
tracing::trace!(poll_num, "poll_fn locking waiters");
let mut inner = self.queue.inner.lock().unwrap();
inner.integrity_check();
let my_waitslot = inner.slots[self.idx].as_mut().unwrap();
// assert!(
// poll_num <= 2,
// "once we place the waker in the slot, next wakeup should have a result: {}",
// my_waitslot.1.is_some()
// );
if let Some(res) = my_waitslot.1.take() {
#[cfg(test)]
tracing::trace!(poll_num, "have cache slot");
// above .take() resets the waiters slot to None
debug_assert!(my_waitslot.0.is_none());
debug_assert!(my_waitslot.1.is_none());
inner.slots[my_waitslot_idx] = None;
inner.free.push_back(my_waitslot_idx);
let _ = scopeguard::ScopeGuard::into_inner(drop_guard.take().unwrap());
inner.integrity_check();
return Poll::Ready(res);
}
// assert_eq!(poll_num, 1);
if !my_waitslot
.0
.as_ref()
.map(|existing| cx.waker().will_wake(existing))
.unwrap_or(false)
{
let prev = my_waitslot.0.replace(cx.waker().clone());
#[cfg(test)]
tracing::trace!(poll_num, prev_is_some = prev.is_some(), "updating waker");
}
inner.integrity_check();
#[cfg(test)]
tracing::trace!(poll_num, "waiting to be woken up");
Poll::Pending
})
}
}
#[cfg(test)]
mod test {
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::Poll,
time::Duration,
};
use rand::RngCore;
#[tokio::test]
async fn in_order_completion_and_wait() {
let queue = super::Queue::new(2);
let q1 = queue.begin().unwrap();
let q2 = queue.begin().unwrap();
assert_eq!(q1.complete_and_wait(23).await, 23);
assert_eq!(q2.complete_and_wait(42).await, 42);
}
#[tokio::test]
async fn out_of_order_completion_and_wait() {
let queue = super::Queue::new(2);
let q1 = queue.begin().unwrap();
let q2 = queue.begin().unwrap();
let mut q2compfut = q2.complete_and_wait(23);
match futures::poll!(&mut q2compfut) {
Poll::Pending => {}
Poll::Ready(_) => panic!("should not be ready yet, it's queued after q1"),
}
let q1res = q1.complete_and_wait(42).await;
assert_eq!(q1res, 23);
let q2res = q2compfut.await;
assert_eq!(q2res, 42);
}
#[tokio::test]
async fn in_order_completion_out_of_order_wait() {
let queue = super::Queue::new(2);
let q1 = queue.begin().unwrap();
let q2 = queue.begin().unwrap();
let mut q1compfut = q1.complete_and_wait(23);
let mut q2compfut = q2.complete_and_wait(42);
match futures::poll!(&mut q2compfut) {
Poll::Pending => {
unreachable!("q2 should be ready, it wasn't first but q1 is serviced already")
}
Poll::Ready(x) => assert_eq!(x, 42),
}
assert_eq!(futures::poll!(&mut q1compfut), Poll::Ready(23));
}
#[tokio::test(flavor = "multi_thread")]
async fn stress() {
let ntasks = 8;
let queue_size = 8;
let queue = Arc::new(super::Queue::new(queue_size));
let stop = Arc::new(AtomicBool::new(false));
let mut tasks = vec![];
for i in 0..ntasks {
let jh = tokio::spawn({
let queue = Arc::clone(&queue);
let stop = Arc::clone(&stop);
async move {
while !stop.load(Ordering::Relaxed) {
let q = queue.begin().unwrap();
for _ in 0..(rand::thread_rng().next_u32() % 10_000) {
std::hint::spin_loop();
}
q.complete_and_wait(i).await;
tokio::task::yield_now().await;
}
}
});
tasks.push(jh);
}
tokio::time::sleep(Duration::from_secs(10)).await;
stop.store(true, Ordering::Relaxed);
for t in tasks {
t.await.unwrap();
}
}
#[test]
fn stress_two_runtimes_shared_queue() {
std::thread::scope(|s| {
let ntasks = 8;
let queue_size = 8;
let queue = Arc::new(super::Queue::new(queue_size));
let stop = Arc::new(AtomicBool::new(false));
for i in 0..ntasks {
s.spawn({
let queue = Arc::clone(&queue);
let stop = Arc::clone(&stop);
move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async move {
while !stop.load(Ordering::Relaxed) {
let q = queue.begin().unwrap();
for _ in 0..(rand::thread_rng().next_u32() % 10_000) {
std::hint::spin_loop();
}
q.complete_and_wait(i).await;
tokio::task::yield_now().await;
}
});
}
});
}
std::thread::sleep(Duration::from_secs(10));
stop.store(true, Ordering::Relaxed);
});
}
}


@@ -37,7 +37,6 @@ humantime-serde.workspace = true
hyper.workspace = true
itertools.workspace = true
nix.workspace = true
nostarve_queue.workspace = true
# hack to get the number of worker threads tokio uses
num_cpus = { version = "1.15" }
num-traits.workspace = true


@@ -314,6 +314,7 @@ static PAGE_CACHE_ERRORS: Lazy<IntCounterVec> = Lazy::new(|| {
#[strum(serialize_all = "kebab_case")]
pub(crate) enum PageCacheErrorKind {
AcquirePinnedSlotTimeout,
EvictIterLimit,
}
pub(crate) fn page_cache_errors_inc(error_kind: PageCacheErrorKind) {


@@ -83,7 +83,6 @@ use std::{
use anyhow::Context;
use once_cell::sync::OnceCell;
use tracing::instrument;
use utils::{
id::{TenantId, TimelineId},
lsn::Lsn,
@@ -253,9 +252,6 @@ pub struct PageCache {
next_evict_slot: AtomicUsize,
size_metrics: &'static PageCacheSizeMetrics,
find_victim_waiters:
nostarve_queue::Queue<(usize, tokio::sync::RwLockWriteGuard<'static, SlotInner>)>,
}
struct PinnedSlotsPermit(tokio::sync::OwnedSemaphorePermit);
@@ -434,9 +430,8 @@ impl PageCache {
///
/// Store an image of the given page in the cache.
///
#[cfg_attr(test, instrument(skip_all, level = "trace", fields(%key, %lsn)))]
pub async fn memorize_materialized_page(
&'static self,
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
key: Key,
@@ -527,9 +522,8 @@ impl PageCache {
// Section 1.2: Public interface functions for working with immutable file pages.
#[cfg_attr(test, instrument(skip_all, level = "trace", fields(?file_id, ?blkno)))]
pub async fn read_immutable_buf(
&'static self,
&self,
file_id: FileId,
blkno: u32,
ctx: &RequestContext,
@@ -635,7 +629,7 @@ impl PageCache {
/// ```
///
async fn lock_for_read(
&'static self,
&self,
cache_key: &mut CacheKey,
ctx: &RequestContext,
) -> anyhow::Result<ReadBufResult> {
@@ -857,15 +851,10 @@ impl PageCache {
///
/// On return, the slot is empty and write-locked.
async fn find_victim(
&'static self,
&self,
_permit_witness: &PinnedSlotsPermit,
) -> anyhow::Result<(usize, tokio::sync::RwLockWriteGuard<SlotInner>)> {
let nostarve_position = self.find_victim_waiters.begin()
.expect("we initialize the nostarve queue to the same size as the slots semaphore, and the caller is presenting a permit");
let span = tracing::info_span!("find_victim", ?nostarve_position);
let _enter = span.enter();
let iter_limit = self.slots.len() * 10;
let mut iters = 0;
loop {
iters += 1;
@@ -877,8 +866,41 @@ impl PageCache {
let mut inner = match slot.inner.try_write() {
Ok(inner) => inner,
Err(_err) => {
if iters > self.slots.len() * (MAX_USAGE_COUNT as usize) {
unreachable!("find_victim_waiters prevents starvation");
if iters > iter_limit {
// NB: Even with the permits, there's no hard guarantee that we will find a slot with
// any particular number of iterations: other threads might race ahead and acquire and
// release pins just as we're scanning the array.
//
// Imagine that nslots is 2, and as starting point, usage_count==1 on all
// slots. There are two threads running concurrently, A and B. A has just
// acquired the permit from the semaphore.
//
// A: Look at slot 1. Its usage_count == 1, so decrement it to zero, and continue the search
// B: Acquire permit.
// B: Look at slot 2, decrement its usage_count to zero and continue the search
// B: Look at slot 1. Its usage_count is zero, so pin it and bump up its usage_count to 1.
// B: Release pin and permit again
// B: Acquire permit.
// B: Look at slot 2. Its usage_count is zero, so pin it and bump up its usage_count to 1.
// B: Release pin and permit again
//
// Now we're back in the starting situation that both slots have
// usage_count 1, but A has now been through one iteration of the
// find_victim() loop. This can repeat indefinitely and on each
// iteration, A's iteration count increases by one.
//
// So, even though the semaphore for the permits is fair, the victim search
// itself happens in parallel and is not fair.
// Hence even with a permit, a task can theoretically be starved.
// To avoid this, we'd need tokio to give priority to tasks that are holding
// permits for longer.
// Note that just yielding to tokio during iteration without such
// priority boosting is likely counter-productive. We'd just give more opportunities
// for B to bump usage count, further starving A.
crate::metrics::page_cache_errors_inc(
crate::metrics::PageCacheErrorKind::EvictIterLimit,
);
anyhow::bail!("exceeded evict iter limit");
}
continue;
}
@@ -889,8 +911,7 @@ impl PageCache {
inner.key = None;
}
crate::metrics::PAGE_CACHE_FIND_VICTIMS_ITERS_TOTAL.inc_by(iters as u64);
return Ok(nostarve_position.complete_and_wait((slot_idx, inner)).await);
return Ok((slot_idx, inner));
}
}
}
@@ -934,7 +955,6 @@ impl PageCache {
next_evict_slot: AtomicUsize::new(0),
size_metrics,
pinned_slots: Arc::new(tokio::sync::Semaphore::new(num_pages)),
find_victim_waiters: ::nostarve_queue::Queue::new(num_pages),
}
}
}
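For intuition, the victim search is a clock sweep over per-slot usage counters, and with the nostarve queue gone the iteration cap is the only backstop against the theoretical starvation described in the comment above. A simplified, self-contained sketch of that loop shape (not the real pageserver code; pinning and locking are omitted):

// Clock-sweep victim search with an iteration cap.
fn find_victim_sketch(usage: &mut [u8], next_evict: &mut usize) -> anyhow::Result<usize> {
    let iter_limit = usage.len() * 10;
    for _ in 0..iter_limit {
        let slot = *next_evict % usage.len();
        *next_evict += 1;
        if usage[slot] == 0 {
            return Ok(slot); // cold slot: evict this one
        }
        usage[slot] -= 1; // age the slot and keep sweeping
    }
    anyhow::bail!("exceeded evict iter limit")
}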


@@ -2908,7 +2908,7 @@ impl Tenant {
};
// create a `tenant/{tenant_id}/timelines/basebackup-{timeline_id}.{TEMP_FILE_SUFFIX}/`
// temporary directory for basebackup files for the given timeline.
let initdb_path = path_with_suffix_extension(
let pgdata_path = path_with_suffix_extension(
self.conf
.timelines_path(&self.tenant_id)
.join(format!("basebackup-{timeline_id}")),
@@ -2917,26 +2917,25 @@ impl Tenant {
// an uninit mark was placed before, nothing else can access this timeline files
// current initdb was not run yet, so remove whatever was left from the previous runs
if initdb_path.exists() {
fs::remove_dir_all(&initdb_path).with_context(|| {
format!("Failed to remove already existing initdb directory: {initdb_path}")
if pgdata_path.exists() {
fs::remove_dir_all(&pgdata_path).with_context(|| {
format!("Failed to remove already existing initdb directory: {pgdata_path}")
})?;
}
// Init a temporary repo to get bootstrap data; this creates a directory at `initdb_path`
run_initdb(self.conf, &initdb_path, pg_version)?;
// Init a temporary repo to get bootstrap data; this creates a directory at `pgdata_path`
run_initdb(self.conf, &pgdata_path, pg_version)?;
// this new directory is very temporary, set to remove it immediately after bootstrap, we don't need it
scopeguard::defer! {
if let Err(e) = fs::remove_dir_all(&initdb_path) {
if let Err(e) = fs::remove_dir_all(&pgdata_path) {
// this is unlikely, but we will remove the directory on pageserver restart or another bootstrap call
error!("Failed to remove temporary initdb directory '{initdb_path}': {e}");
error!("Failed to remove temporary initdb directory '{pgdata_path}': {e}");
}
}
let pgdata_path = &initdb_path;
let pgdata_lsn = import_datadir::get_lsn_from_controlfile(pgdata_path)?.align();
let pgdata_lsn = import_datadir::get_lsn_from_controlfile(&pgdata_path)?.align();
// Upload the created data dir to S3
if let Some(storage) = &self.remote_storage {
let pgdata_zstd = import_datadir::create_tar_zst(pgdata_path).await?;
let pgdata_zstd = import_datadir::create_tar_zst(&pgdata_path).await?;
let pgdata_zstd = Bytes::from(pgdata_zstd);
backoff::retry(
|| async {
@@ -2986,7 +2985,7 @@ impl Tenant {
import_datadir::import_timeline_from_postgres_datadir(
unfinished_timeline,
pgdata_path,
&pgdata_path,
pgdata_lsn,
ctx,
)

View File

@@ -21,6 +21,7 @@
#include "storage/buf_internals.h"
#include "storage/lwlock.h"
#include "storage/ipc.h"
#include "storage/pg_shmem.h"
#include "c.h"
#include "postmaster/interrupt.h"
@@ -87,6 +88,12 @@ bool (*old_redo_read_buffer_filter) (XLogReaderState *record, uint8 block_id) =
static bool pageserver_flush(void);
static void pageserver_disconnect(void);
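/*
 * UsedShmemSegAddr is cleared when a process detaches from the main shared
 * memory segment (as e.g. the stats collector does), while the inherited
 * pagestore_shared pointer would still point into now-unmapped memory, so
 * both must be checked before dereferencing.
 */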
static bool
PagestoreShmemIsValid()
{
return pagestore_shared && UsedShmemSegAddr;
}
static bool
CheckPageserverConnstring(char **newval, void **extra, GucSource source)
{
@@ -96,7 +103,7 @@ CheckPageserverConnstring(char **newval, void **extra, GucSource source)
static void
AssignPageserverConnstring(const char *newval, void *extra)
{
if(!pagestore_shared)
if(!PagestoreShmemIsValid())
return;
LWLockAcquire(pagestore_shared->lock, LW_EXCLUSIVE);
strlcpy(pagestore_shared->pageserver_connstring, newval, MAX_PAGESERVER_CONNSTRING_SIZE);
@@ -107,7 +114,7 @@ AssignPageserverConnstring(const char *newval, void *extra)
static bool
CheckConnstringUpdated()
{
if(!pagestore_shared)
if(!PagestoreShmemIsValid())
return false;
return pagestore_local_counter < pg_atomic_read_u64(&pagestore_shared->update_counter);
}
@@ -115,7 +122,7 @@ CheckConnstringUpdated()
static void
ReloadConnstring()
{
if(!pagestore_shared)
if(!PagestoreShmemIsValid())
return;
LWLockAcquire(pagestore_shared->lock, LW_SHARED);
strlcpy(local_pageserver_connstring, pagestore_shared->pageserver_connstring, sizeof(local_pageserver_connstring));


@@ -2,3 +2,4 @@
comment = 'cloud storage for PostgreSQL'
default_version = '1.1'
module_pathname = '$libdir/neon'
relocatable = true


@@ -76,3 +76,4 @@ tokio-util.workspace = true
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true
postgres-protocol.workspace = true


@@ -6,6 +6,7 @@ pub use link::LinkAuthError;
use tokio_postgres::config::AuthKeys;
use crate::proxy::{handle_try_wake, retry_after, LatencyTimer};
use crate::stream::Stream;
use crate::{
auth::{self, ClientCredentials},
config::AuthenticationConfig,
@@ -131,7 +132,7 @@ async fn auth_quirks_creds(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
@@ -165,7 +166,7 @@ async fn auth_quirks(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
@@ -241,7 +242,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
pub async fn authenticate(
&mut self,
extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,


@@ -6,7 +6,7 @@ use crate::{
console::{self, AuthInfo, ConsoleReqExtra},
proxy::LatencyTimer,
sasl, scram,
stream::PqStream,
stream::{PqStream, Stream},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
@@ -15,7 +15,7 @@ pub(super) async fn authenticate(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {


@@ -2,7 +2,7 @@ use super::{AuthSuccess, ComputeCredentials};
use crate::{
auth::{self, AuthFlow, ClientCredentials},
proxy::LatencyTimer,
stream,
stream::{self, Stream},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
@@ -12,7 +12,7 @@ use tracing::{info, warn};
/// These properties are beneficial for serverless JS workers, so we
/// use this mechanism for websocket connections.
pub async fn cleartext_hack(
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
latency_timer: &mut LatencyTimer,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
warn!("cleartext auth flow override is enabled, proceeding");
@@ -37,7 +37,7 @@ pub async fn cleartext_hack(
/// Very similar to [`cleartext_hack`], but there's a specific password format.
pub async fn password_hack(
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
latency_timer: &mut LatencyTimer,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
warn!("project not specified, resorting to the password hack auth flow");


@@ -1,16 +1,21 @@
//! Main authentication flow.
use super::{AuthErrorImpl, PasswordHackPayload};
use crate::{sasl, scram, stream::PqStream};
use crate::{
config::TlsServerEndPoint,
sasl, scram,
stream::{PqStream, Stream},
};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
/// Every authentication selector is supposed to implement this trait.
pub trait AuthMethod {
/// Any authentication selector should provide initial backend message
/// containing auth method name and parameters, e.g. md5 salt.
fn first_message(&self) -> BeMessage<'_>;
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
/// Initial state of [`AuthFlow`].
@@ -21,8 +26,14 @@ pub struct Scram<'a>(pub &'a scram::ServerSecret);
impl AuthMethod for Scram<'_> {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
if channel_binding {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
} else {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
scram::METHODS_WITHOUT_PLUS,
))
}
}
}
@@ -32,7 +43,7 @@ pub struct PasswordHack;
impl AuthMethod for PasswordHack {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
@@ -43,37 +54,44 @@ pub struct CleartextPassword;
impl AuthMethod for CleartextPassword {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub struct AuthFlow<'a, Stream, State> {
pub struct AuthFlow<'a, S, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream>,
stream: &'a mut PqStream<Stream<S>>,
/// State might contain ancillary data (see [`Self::begin`]).
state: State,
tls_server_end_point: TlsServerEndPoint,
}
/// Initial state of the stream wrapper.
impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
/// Create a new wrapper for client authentication.
pub fn new(stream: &'a mut PqStream<S>) -> Self {
pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
let tls_server_end_point = stream.get_ref().tls_server_end_point();
Self {
stream,
state: Begin,
tls_server_end_point,
}
}
/// Move to the next step by sending auth method's name & params to client.
pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
self.stream.write_message(&method.first_message()).await?;
self.stream
.write_message(&method.first_message(self.tls_server_end_point.supported()))
.await?;
Ok(AuthFlow {
stream: self.stream,
state: method,
tls_server_end_point: self.tls_server_end_point,
})
}
}
@@ -123,9 +141,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
return Err(super::AuthError::bad_auth_method(sasl.method));
}
info!("client chooses {}", sasl.method);
let secret = self.state.0;
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(secret, rand::random, None))
.authenticate(scram::Exchange::new(
secret,
rand::random,
self.tls_server_end_point,
))
.await?;
Ok(outcome)


@@ -6,6 +6,8 @@
use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use tokio::net::TcpListener;
use anyhow::{anyhow, bail, ensure, Context};
@@ -65,7 +67,7 @@ async fn main() -> anyhow::Result<()> {
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
// Configure TLS
let tls_config: Arc<rustls::ServerConfig> = match (
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
@@ -89,16 +91,22 @@ async fn main() -> anyhow::Result<()> {
))?
.into_iter()
.map(rustls::Certificate)
.collect()
.collect_vec()
};
rustls::ServerConfig::builder()
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config = rustls::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into()
.into();
(tls_config, tls_server_end_point)
}
_ => bail!("tls-key and tls-cert must be specified"),
};
@@ -113,6 +121,7 @@ async fn main() -> anyhow::Result<()> {
let main = tokio::spawn(task_main(
Arc::new(destination),
tls_config,
tls_server_end_point,
proxy_listener,
cancellation_token.clone(),
));
@@ -134,6 +143,7 @@ async fn main() -> anyhow::Result<()> {
async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -159,7 +169,7 @@ async fn task_main(
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
handle_client(dest_suffix, tls_config, socket).await
handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -207,6 +217,7 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<Stream<S>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
@@ -231,7 +242,11 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
Ok(raw.upgrade(tls_config).await?)
Ok(Stream::Tls {
tls: Box::new(raw.upgrade(tls_config).await?),
tls_server_end_point,
})
}
unexpected => {
info!(
@@ -246,9 +261,10 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
async fn handle_client(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
let tls_stream = ssl_handshake(stream, tls_config).await?;
let tls_stream = ssl_handshake(stream, tls_config, tls_server_end_point).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of


@@ -1,12 +1,15 @@
use crate::auth;
use anyhow::{bail, ensure, Context, Ok};
use rustls::sign;
use rustls::{sign, Certificate, PrivateKey};
use sha2::{Digest, Sha256};
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
time::Duration,
};
use tracing::{error, info};
use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
@@ -27,6 +30,7 @@ pub struct MetricCollectionConfig {
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: Option<HashSet<String>>,
pub cert_resolver: Arc<CertResolver>,
}
pub struct HttpConfig {
@@ -52,7 +56,7 @@ pub fn configure_tls(
let mut cert_resolver = CertResolver::new();
// add default certificate
cert_resolver.add_cert(key_path, cert_path, true)?;
cert_resolver.add_cert_path(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
@@ -64,7 +68,7 @@ pub fn configure_tls(
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver.add_cert(
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
@@ -76,35 +80,97 @@ pub fn configure_tls(
let common_names = cert_resolver.get_common_names();
let cert_resolver = Arc::new(cert_resolver);
let config = rustls::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
// allow TLS 1.2 to be compatible with older client libraries
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_cert_resolver(Arc::new(cert_resolver))
.with_cert_resolver(cert_resolver.clone())
.into();
Ok(TlsConfig {
config,
common_names: Some(common_names),
cert_resolver,
})
}
struct CertResolver {
certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
default: Option<Arc<rustls::sign::CertifiedKey>>,
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to this channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl CertResolver {
fn new() -> Self {
Self {
certs: HashMap::new(),
default: None,
impl TlsServerEndPoint {
pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(&cert.0)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] =
Sha256::new().chain_update(&cert.0).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
fn add_cert(
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}
#[derive(Default)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
pub fn new() -> Self {
Self::default()
}
fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
@@ -120,57 +186,65 @@ impl CertResolver {
keys.pop().map(rustls::PrivateKey).unwrap()
};
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.context(format!(
.with_context(|| {
format!(
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
))?
)
})?
.into_iter()
.map(rustls::Certificate)
.collect()
};
let common_name = {
let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
.context(format!(
"Failed to parse PEM object from bytes from file at '{cert_path}'."
))?
.1;
let common_name = pem.parse_x509()?.subject().to_string();
self.add_cert(priv_key, cert_chain, is_default)
}
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
// of cutting off '*.' parts.
if common_name.starts_with("CN=*.") {
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
} else {
common_name.strip_prefix("CN=").map(|s| s.to_string())
}
pub fn add_cert(
&mut self,
priv_key: PrivateKey,
cert_chain: Vec<Certificate>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(&first_cert.0)
.context("Failed to parse PEM object from cerficiate")?
.1;
let common_name = pem.subject().to_string();
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
// of cutting off '*.' parts.
let common_name = if common_name.starts_with("CN=*.") {
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
} else {
common_name.strip_prefix("CN=").map(|s| s.to_string())
}
.context(format!(
"Failed to parse common name from certificate at '{cert_path}'."
))?;
.context("Failed to parse common name from certificate")?;
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some(cert.clone());
self.default = Some((cert.clone(), tls_server_end_point));
}
self.certs.insert(common_name, cert);
self.certs.insert(common_name, (cert, tls_server_end_point));
Ok(())
}
fn get_common_names(&self) -> HashSet<String> {
pub fn get_common_names(&self) -> HashSet<String> {
self.certs.keys().map(|s| s.to_string()).collect()
}
}
@@ -178,15 +252,24 @@ impl CertResolver {
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
_client_hello: rustls::server::ClientHello,
client_hello: rustls::server::ClientHello,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
//
// With the current coding foo.com will match *.foo.com and that
// repeats behavior of the old code.
if let Some(mut sni_name) = _client_hello.server_name() {
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return Some(cert.clone());

View File

@@ -470,7 +470,17 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?);
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
}
}
_ => bail!(ERR_PROTO_VIOLATION),
@@ -875,7 +885,7 @@ pub async fn proxy_pass(
/// Thin connection context.
struct Client<'a, S> {
/// The underlying libpq protocol stream.
stream: PqStream<S>,
stream: PqStream<Stream<S>>,
/// Client credentials that we care about.
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
/// KV-dictionary with PostgreSQL connection params.
@@ -889,7 +899,7 @@ struct Client<'a, S> {
impl<'a, S> Client<'a, S> {
/// Construct a new connection context.
fn new(
stream: PqStream<S>,
stream: PqStream<Stream<S>>,
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,

View File

@@ -1,19 +1,23 @@
//! A group of high-level tests for connection establishing logic and auth.
//!
mod mitm;
use super::*;
use crate::auth::backend::TestBackend;
use crate::auth::ClientCredentials;
use crate::config::CertResolver;
use crate::console::{CachedNodeInfo, NodeInfo};
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
use rstest::rstest;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::MakeRustlsConnect;
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
hostname: &str,
common_name: &str,
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
let ca = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::default();
@@ -21,7 +25,15 @@ fn generate_certs(
params
})?;
let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?;
let cert = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
params.distinguished_name = rcgen::DistinguishedName::new();
params
.distinguished_name
.push(rcgen::DnType::CommonName, common_name);
params
})?;
Ok((
rustls::Certificate(ca.serialize_der()?),
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
@@ -37,7 +49,14 @@ struct ClientConfig<'a> {
impl ClientConfig<'_> {
fn make_tls_connect<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
self,
) -> anyhow::Result<impl tokio_postgres::tls::TlsConnect<S>> {
) -> anyhow::Result<
impl tokio_postgres::tls::TlsConnect<
S,
Error = impl std::fmt::Debug,
Future = impl Send,
Stream = RustlsStream<S>,
>,
> {
let mut mk = MakeRustlsConnect::new(self.config);
let tls = MakeTlsConnect::<S>::make_tls_connect(&mut mk, self.hostname)?;
Ok(tls)
@@ -49,20 +68,24 @@ fn generate_tls_config<'a>(
hostname: &'a str,
common_name: &'a str,
) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
let (ca, cert, key) = generate_certs(hostname)?;
let (ca, cert, key) = generate_certs(hostname, common_name)?;
let tls_config = {
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert], key)?
.with_single_cert(vec![cert.clone()], key.clone())?
.into();
let common_names = Some([common_name.to_owned()].iter().cloned().collect());
let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert(key, vec![cert], true)?;
let common_names = Some(cert_resolver.get_common_names());
TlsConfig {
config,
common_names,
cert_resolver: Arc::new(cert_resolver),
}
};
@@ -253,6 +276,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
));
let (_client, _conn) = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Require)
.user("user")
.dbname("db")
.password(password)
@@ -263,6 +287,30 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
proxy.await?
}
#[tokio::test]
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));
let (_client, _conn) = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await?;
proxy.await?
}
#[tokio::test]
async fn scram_auth_mock() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);


@@ -0,0 +1,257 @@
//! Man-in-the-middle tests
//!
//! Channel binding should prevent a proxy server
//! - that has access to create valid certificates -
//! from controlling the TLS connection.
use std::fmt::Debug;
use super::*;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_protocol::message::frontend;
use tokio::io::{AsyncReadExt, DuplexStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::TlsConnect;
use tokio_util::codec::{Decoder, Encoder};
enum Intercept {
None,
Methods,
SASLResponse,
}
async fn proxy_mitm(
intercept: Intercept,
) -> (DuplexStream, DuplexStream, ClientConfig<'static>, TlsConfig) {
let (end_server1, client1) = tokio::io::duplex(1024);
let (server2, end_client2) = tokio::io::duplex(1024);
let (client_config1, server_config1) =
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
let (client_config2, server_config2) =
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
tokio::spawn(async move {
// begin handshake with end_server
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
// process handshake with end_client
let (end_client, startup) =
handshake(client1, Some(&server_config1), &CancelMap::default())
.await
.unwrap()
.unwrap();
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
let (end_client, buf) = end_client.framed.into_inner();
assert!(buf.is_empty());
let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
// give the end_server the startup parameters
let mut buf = BytesMut::new();
frontend::startup_message(startup.iter(), &mut buf).unwrap();
end_server.send(buf.freeze()).await.unwrap();
// proxy messages between end_client and end_server
loop {
tokio::select! {
message = end_server.next() => {
match message {
Some(Ok(message)) => {
// intercept SASL and return only SCRAM-SHA-256 ;)
if matches!(intercept, Intercept::Methods) && message.starts_with(b"R") && message[5..].starts_with(&[0,0,0,10]) {
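// Spoofed frame layout: tag 'R', big-endian length 0x17 = 23 (the length
// field plus body), auth code 10 = AuthenticationSASL, then the
// NUL-terminated mechanism "SCRAM-SHA-256\0" and the list-terminating NUL;
// the server now appears to offer only the non-PLUS mechanism.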
end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap();
continue;
}
end_client.send(message).await.unwrap()
}
_ => break,
}
}
message = end_client.next() => {
match message {
Some(Ok(message)) => {
// intercept SASL response and return SCRAM-SHA-256 with no channel binding ;)
if matches!(intercept, Intercept::SASLResponse) && message.starts_with(b"p") && message[5..].starts_with(b"SCRAM-SHA-256-PLUS\0") {
let sasl_message = &message[1+4+19+4..];
let mut new_message = b"n,,".to_vec();
new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
let mut buf = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
end_server.send(buf.freeze()).await.unwrap();
continue;
}
end_server.send(message).await.unwrap()
}
_ => break,
}
}
else => { break }
}
}
});
(end_server1, end_client2, client_config1, server_config2)
}
/// taken from tokio-postgres
pub async fn connect_tls<S, T>(mut stream: S, tls: T) -> T::Stream
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
T::Error: Debug,
{
let mut buf = BytesMut::new();
frontend::ssl_request(&mut buf);
stream.write_all(&buf).await.unwrap();
let mut buf = [0];
stream.read_exact(&mut buf).await.unwrap();
if buf[0] != b'S' {
panic!("ssl not supported by server");
}
tls.connect(stream).await.unwrap()
}
struct PgFrame;
impl Decoder for PgFrame {
type Item = Bytes;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < 5 {
src.reserve(5 - src.len());
return Ok(None);
}
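// The 4-byte length counts itself and the body but not the 1-byte tag,
// hence the +1 below to get the full frame size.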
let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
if src.len() < len {
src.reserve(len - src.len());
return Ok(None);
}
Ok(Some(src.split_to(len).freeze()))
}
}
impl Encoder<Bytes> for PgFrame {
type Error = io::Error;
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
dst.extend_from_slice(&item);
Ok(())
}
}
/// If the client doesn't support channel bindings, it can be exploited.
#[tokio::test]
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));
let _client_err = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await?;
proxy.await?
}
/// If the client chooses SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
connect_failure(
Intercept::None,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}
/// If the MITM pretends like SCRAM-PLUS isn't available, but the client supports it, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
connect_failure(
Intercept::Methods,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}
/// If the MITM pretends like the client doesn't support channel bindings, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> {
connect_failure(
Intercept::SASLResponse,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}
/// If the client chooses SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
connect_failure(
Intercept::None,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
connect_failure(
Intercept::Methods,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> {
connect_failure(
Intercept::SASLResponse,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}
async fn connect_failure(
intercept: Intercept,
channel_binding: tokio_postgres::config::ChannelBinding,
) -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));
let _client_err = tokio_postgres::Config::new()
.channel_binding(channel_binding)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await
.err()
.context("client shouldn't be able to connect")?;
let _server_err = proxy
.await?
.err()
.context("server shouldn't accept client")?;
Ok(())
}


@@ -36,9 +36,9 @@ impl<'a> ChannelBinding<&'a str> {
impl<T: std::fmt::Display> ChannelBinding<T> {
/// Encode channel binding data as base64 for subsequent checks.
pub fn encode<E>(
pub fn encode<'a, E>(
&self,
get_cbind_data: impl FnOnce(&T) -> Result<String, E>,
get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
) -> Result<std::borrow::Cow<'static, str>, E> {
use ChannelBinding::*;
Ok(match self {
@@ -51,12 +51,11 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
"eSws".into()
}
Required(mode) => {
let msg = format!(
"p={mode},,{data}",
mode = mode,
data = get_cbind_data(mode)?
);
base64::encode(msg).into()
use std::io::Write;
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,").unwrap();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
base64::encode(&cbind_input).into()
}
})
}
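A worked example of the encoding above (added commentary; assumes the same base64 crate): the gs2 header is what ends up base64-encoded in the client's `c=` attribute.
assert_eq!(base64::encode("n,,"), "biws"); // client doesn't support channel binding
assert_eq!(base64::encode("y,,"), "eSws"); // client supports it, server didn't advertise PLUS
let mut cbind_input = b"p=tls-server-end-point,,".to_vec();
cbind_input.extend_from_slice(&[0u8; 32]); // hypothetical 32-byte certificate digest
let c_attr = base64::encode(&cbind_input); // becomes `c=<...>` in the client-final-message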
@@ -77,7 +76,7 @@ mod tests {
];
for (cb, input) in cases {
assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input);
assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
}
Ok(())

View File

@@ -22,9 +22,12 @@ pub use secret::ServerSecret;
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
// TODO: add SCRAM-SHA-256-PLUS
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
/// A list of supported SCRAM methods.
pub const METHODS: &[&str] = &["SCRAM-SHA-256"];
pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
/// Decode base64 into array without any heap allocations
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
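A plausible selection sketch (added commentary — the call site that picks between the two lists lives elsewhere in this PR): advertise the PLUS variant only when a tls-server-end-point digest is actually available, i.e. the connection arrived over TLS.
fn sasl_mechanisms(tls_server_end_point_available: bool) -> &'static [&'static str] {
    // Without a certificate digest there is nothing to bind to,
    // so only plain SCRAM-SHA-256 can be offered.
    if tls_server_end_point_available { METHODS } else { METHODS_WITHOUT_PLUS }
}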
@@ -80,7 +83,11 @@ mod tests {
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
let mut exchange = Exchange::new(&secret, || NONCE, None);
let mut exchange = Exchange::new(
&secret,
|| NONCE,
crate::config::TlsServerEndPoint::Undefined,
);
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";

View File

@@ -5,9 +5,11 @@ use super::messages::{
};
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use crate::config;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
/// The only channel binding mode we currently support.
#[derive(Debug)]
struct TlsServerEndPoint;
impl std::fmt::Display for TlsServerEndPoint {
@@ -43,20 +45,20 @@ pub struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
cert_digest: Option<&'a [u8]>,
tls_server_end_point: config::TlsServerEndPoint,
}
impl<'a> Exchange<'a> {
pub fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
cert_digest: Option<&'a [u8]>,
tls_server_end_point: config::TlsServerEndPoint,
) -> Self {
Self {
state: ExchangeState::Initial,
secret,
nonce,
cert_digest,
tls_server_end_point,
}
}
}
@@ -71,6 +73,14 @@ impl sasl::Mechanism for Exchange<'_> {
let client_first_message = ClientFirstMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
// If the flag is set to "y" and the server supports channel
// binding, the server MUST fail authentication
if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
&& self.tls_server_end_point.supported()
{
return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
}
let server_first_message = client_first_message.build_server_first_message(
&(self.nonce)(),
&self.secret.salt_base64,
@@ -94,10 +104,11 @@ impl sasl::Mechanism for Exchange<'_> {
let client_final_message = ClientFinalMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
let channel_binding = cbind_flag.encode(|_| {
self.cert_digest
.map(base64::encode)
.ok_or(SaslError::ChannelBindingFailed("no cert digest provided"))
let channel_binding = cbind_flag.encode(|_| match &self.tls_server_end_point {
config::TlsServerEndPoint::Sha256(x) => Ok(x),
config::TlsServerEndPoint::Undefined => {
Err(SaslError::ChannelBindingFailed("no cert digest provided"))
}
})?;
// This might've been caused by a MITM attack
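// Recap of the gs2-cbind-flag rules the two hunks above implement
// (added commentary, per RFC 5802):
//   "n"   — client can't do channel binding: allowed; its `c=` must decode to "n,,"
//   "y"   — client could, but believes the server can't: rejected above whenever
//           we did advertise SCRAM-SHA-256-PLUS, since only a MITM stripping the
//           PLUS mechanism produces that combination
//   "p=…" — channel binding in use: the transmitted digest is checked against
//           our own tls-server-end-point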

View File

@@ -1,7 +1,8 @@
use crate::config::TlsServerEndPoint;
use crate::error::UserFacingError;
use anyhow::bail;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
@@ -17,7 +18,7 @@ use tokio_rustls::server::TlsStream;
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
framed: Framed<S>,
pub(crate) framed: Framed<S>,
}
impl<S> PqStream<S> {
@@ -118,19 +119,21 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
}
}
pin_project! {
/// Wrapper for upgrading raw streams into secure streams.
/// NOTE: it should be possible to decompose this object as necessary.
#[project = StreamProj]
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { #[pin] raw: S },
/// Wrapper for upgrading raw streams into secure streams.
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { raw: S },
Tls {
/// We box [`TlsStream`] since it can be quite large.
Tls { #[pin] tls: Box<TlsStream<S>> },
}
tls: Box<TlsStream<S>>,
/// Channel binding parameter
tls_server_end_point: TlsServerEndPoint,
},
}
impl<S: Unpin> Unpin for Stream<S> {}
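// Added commentary: dropping the pin_project machinery is sound because both
// variants hold only Unpin data — `S: Unpin` is bounded where it matters and
// `Box<TlsStream<S>>` is always Unpin — so the blanket impl just above lets
// the poll methods below take `&mut *self` and re-pin each field with `Pin::new`.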
impl<S> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
@@ -141,7 +144,17 @@ impl<S> Stream<S> {
pub fn sni_hostname(&self) -> Option<&str> {
match self {
Stream::Raw { .. } => None,
Stream::Tls { tls } => tls.get_ref().1.server_name(),
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
}
}
pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
match self {
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
Stream::Tls {
tls_server_end_point,
..
} => *tls_server_end_point,
}
}
}
@@ -158,12 +171,9 @@ pub enum StreamUpgradeError {
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<Self, StreamUpgradeError> {
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
match self {
Stream::Raw { raw } => {
let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
Ok(Stream::Tls { tls })
}
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?),
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
}
}
@@ -171,50 +181,46 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
fn poll_read(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_read(context, buf),
Tls { tls } => tls.poll_read(context, buf),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
fn poll_write(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_write(context, buf),
Tls { tls } => tls.poll_write(context, buf),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
}
}
fn poll_flush(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_flush(context),
Tls { tls } => tls.poll_flush(context),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_shutdown(context),
Tls { tls } => tls.poll_shutdown(context),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
}
}
}

View File

@@ -1572,7 +1572,7 @@ class NeonAttachmentService:
self.running = False
return self
def attach_hook(self, tenant_id: TenantId, pageserver_id: int) -> int:
def attach_hook_issue(self, tenant_id: TenantId, pageserver_id: int) -> int:
response = requests.post(
f"{self.env.control_plane_api}/attach-hook",
json={"tenant_id": str(tenant_id), "node_id": pageserver_id},
@@ -1582,6 +1582,13 @@ class NeonAttachmentService:
assert isinstance(gen, int)
return gen
def attach_hook_drop(self, tenant_id: TenantId):
response = requests.post(
f"{self.env.control_plane_api}/attach-hook",
json={"tenant_id": str(tenant_id), "node_id": None},
)
response.raise_for_status()
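# Added commentary: posting node_id=None presumably releases the tenant's
# attachment in the control plane, so a pageserver still holding the old
# generation can no longer validate deletions for it.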
def __enter__(self) -> "NeonAttachmentService":
return self
@@ -1781,13 +1788,20 @@ class NeonPageserver(PgProtocol):
to call into the pageserver HTTP client.
"""
if self.env.attachment_service is not None:
generation = self.env.attachment_service.attach_hook(tenant_id, self.id)
generation = self.env.attachment_service.attach_hook_issue(tenant_id, self.id)
else:
generation = None
client = self.http_client()
return client.tenant_attach(tenant_id, config, config_null, generation=generation)
def tenant_detach(self, tenant_id: TenantId):
if self.env.attachment_service is not None:
self.env.attachment_service.attach_hook_drop(tenant_id)
client = self.http_client()
return client.tenant_detach(tenant_id)
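# Added commentary: the attach hook is dropped before the pageserver-side
# detach, so the control plane stops tracking the attachment no later than
# the pageserver releases it (generation-aware detachment).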
def append_pageserver_param_overrides(
params_to_update: List[str],
@@ -2626,6 +2640,11 @@ class Endpoint(PgProtocol):
return self
def get_metrics_str(self) -> str:
request_result = requests.get(f"http://localhost:{self.http_port}/metrics")
request_result.raise_for_status()
return request_result.text
def __enter__(self) -> "Endpoint":
return self

View File

@@ -6,11 +6,15 @@ def test_build_info_metric(neon_env_builder: NeonEnvBuilder, link_proxy: NeonPro
neon_env_builder.num_safekeepers = 1
env = neon_env_builder.init_start()
env.neon_cli.create_branch("test_build_info_metric")
endpoint = env.endpoints.create_start("test_build_info_metric")
parsed_metrics = {}
parsed_metrics["pageserver"] = parse_metrics(env.pageserver.http_client().get_metrics_str())
parsed_metrics["safekeeper"] = parse_metrics(env.safekeepers[0].http_client().get_metrics_str())
parsed_metrics["proxy"] = parse_metrics(link_proxy.get_metrics())
parsed_metrics["compute_ctl"] = parse_metrics(endpoint.get_metrics_str())
for _component, metrics in parsed_metrics.items():
sample = metrics.query_one("libmetrics_build_info")
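# Added commentary: query_one asserts that exactly one libmetrics_build_info
# sample exists per component; the exact label set is defined by the
# libmetrics crate rather than by this test.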

View File

@@ -1,7 +1,8 @@
import os
import shutil
from contextlib import closing
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict
import pytest
from fixtures.log_helper import log
@@ -14,62 +15,33 @@ from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
# Check that the extension is not already in the share_dir_path_ext
# if it is, skip the test
#
# After the test is done, cleanup the control file and the extension directory
# Use the neon_env_builder_local fixture to override the default neon_env_builder fixture
# and use a test-specific pg_install instead of the shared one
@pytest.fixture(scope="function")
def ext_file_cleanup(pg_bin):
out = pg_bin.run_capture("pg_config --sharedir".split())
share_dir_path = Path(f"{out}.stdout").read_text().strip()
log.info(f"share_dir_path: {share_dir_path}")
share_dir_path_ext = os.path.join(share_dir_path, "extension")
def neon_env_builder_local(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
pg_distrib_dir: Path,
pg_version: PgVersion,
) -> NeonEnvBuilder:
test_local_pginstall = test_output_dir / "pg_install"
log.info(f"copy {pg_distrib_dir} to {test_local_pginstall}")
shutil.copytree(
pg_distrib_dir / pg_version.v_prefixed, test_local_pginstall / pg_version.v_prefixed
)
log.info(f"share_dir_path_ext: {share_dir_path_ext}")
neon_env_builder.pg_distrib_dir = test_local_pginstall
log.info(f"local neon_env_builder.pg_distrib_dir: {neon_env_builder.pg_distrib_dir}")
# if file is already in the share_dir_path_ext, skip the test
if os.path.isfile(os.path.join(share_dir_path_ext, "anon.control")):
log.info("anon.control is already in the share_dir_path_ext, skipping the test")
yield False
return
else:
yield True
# cleanup the control file
if os.path.isfile(os.path.join(share_dir_path_ext, "anon.control")):
os.unlink(os.path.join(share_dir_path_ext, "anon.control"))
log.info("anon.control was removed from the share_dir_path_ext")
# remove the extension directory recursively
if os.path.isdir(os.path.join(share_dir_path_ext, "anon")):
directories_to_clean: List[Path] = []
for f in Path(os.path.join(share_dir_path_ext, "anon")).iterdir():
if f.is_file():
log.info(f"Removing file {f}")
f.unlink()
elif f.is_dir():
directories_to_clean.append(f)
for directory_to_clean in reversed(directories_to_clean):
if not os.listdir(directory_to_clean):
log.info(f"Removing empty directory {directory_to_clean}")
directory_to_clean.rmdir()
os.rmdir(os.path.join(share_dir_path_ext, "anon"))
log.info("anon directory was removed from the share_dir_path_ext")
return neon_env_builder
def test_remote_extensions(
httpserver: HTTPServer,
neon_env_builder: NeonEnvBuilder,
neon_env_builder_local: NeonEnvBuilder,
httpserver_listen_address,
pg_version,
ext_file_cleanup,
):
if ext_file_cleanup is False:
log.info("test_remote_extensions skipped")
return
if pg_version == PgVersion.V16:
pytest.skip("TODO: PG16 extension building")
@@ -79,7 +51,8 @@ def test_remote_extensions(
(host, port) = httpserver_listen_address
extensions_endpoint = f"http://{host}:{port}/pg-ext-s3-gateway"
archive_path = f"latest/v{pg_version}/extensions/anon.tar.zst"
build_tag = os.environ.get("BUILD_TAG", "latest")
archive_path = f"{build_tag}/v{pg_version}/extensions/anon.tar.zst"
def endpoint_handler_build_tag(request: Request) -> Response:
log.info(f"request: {request}")
@@ -88,6 +61,7 @@ def test_remote_extensions(
file_path = f"test_runner/regress/data/extension_test/5670669815/v{pg_version}/extensions/anon.tar.zst"
file_size = os.path.getsize(file_path)
fh = open(file_path, "rb")
return Response(
fh,
mimetype="application/octet-stream",
@@ -104,12 +78,10 @@ def test_remote_extensions(
# Start a compute node with remote_extension spec
# and check that it can download the extensions and use them to CREATE EXTENSION.
env = neon_env_builder.init_start()
tenant_id, _ = env.neon_cli.create_tenant()
env.neon_cli.create_timeline("test_remote_extensions", tenant_id=tenant_id)
env = neon_env_builder_local.init_start()
env.neon_cli.create_branch("test_remote_extensions")
endpoint = env.endpoints.create(
"test_remote_extensions",
tenant_id=tenant_id,
config_lines=["log_min_messages=debug3"],
)

View File

@@ -282,7 +282,7 @@ def test_deferred_deletion(neon_env_builder: NeonEnvBuilder):
# Now advance the generation in the control plane: subsequent validations
# from the running pageserver will fail. No more deletions should happen.
env.attachment_service.attach_hook(env.initial_tenant, some_other_pageserver)
env.attachment_service.attach_hook_issue(env.initial_tenant, some_other_pageserver)
generate_uploads_and_deletions(env, init=False)
assert_deletion_queue(ps_http, lambda n: n > 0)
@@ -397,7 +397,7 @@ def test_deletion_queue_recovery(
if keep_attachment == KeepAttachment.LOSE:
some_other_pageserver = 101010
assert env.attachment_service is not None
env.attachment_service.attach_hook(env.initial_tenant, some_other_pageserver)
env.attachment_service.attach_hook_issue(env.initial_tenant, some_other_pageserver)
env.pageserver.start()

View File

@@ -336,10 +336,15 @@ def test_live_reconfig_get_evictions_low_residence_duration_metric_threshold(
):
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
env = neon_env_builder.init_start()
env = neon_env_builder.init_start(
initial_tenant_conf={
# disable compaction so that it will not download the layer for repartitioning
"compaction_period": "0s"
}
)
assert isinstance(env.pageserver_remote_storage, LocalFsStorage)
(tenant_id, timeline_id) = env.neon_cli.create_tenant()
(tenant_id, timeline_id) = env.initial_tenant, env.initial_timeline
ps_http = env.pageserver.http_client()
def get_metric():