Compare commits

...

6 Commits

Author SHA1 Message Date
Folke Behrens
431a12acba proxy/conntrack: Global connection tracking table and debug logging 2025-04-21 20:02:10 +02:00
Folke Behrens
fd07ecf58f proxy/conntrack: Add mechanics to track connection state 2025-04-21 19:54:01 +02:00
Heikki Linnakangas
4d0c1e8b78 refactor: Extract some code in pagebench getpage command to function (#11563)
This makes it easier to add a different client implementation alongside
the current one. I started working on a new gRPC-based protocol to
replace the libpq protocol, which will introduce a new function like
`client_libpq`, but for the new protocol.

It's a little more readable with less indentation anyway.
2025-04-19 08:38:03 +00:00
JC Grünhage
3158442a59 fix(ci): set token for fast-forward failure comments and allow merging with state unstable (#11647)
## Problem

https://github.com/neondatabase/neon/actions/runs/14538136318/job/40790985693?pr=11645
failed, even though the relevant parts of the CI had passed and
auto-merge determined the PR is ready to merge. After that, commenting
failed.

## Summary of changes
- set GH_TOKEN for commenting after fast-forward failure
- allow merging with mergeable_state unstable
2025-04-18 17:49:34 +00:00
JC Grünhage
f006879fb7 fix(ci): make regex to find rc branches less strict (#11646)
## Problem

https://github.com/neondatabase/neon/actions/runs/14537161022/job/40787763965
failed to find the correct RC PR run, preventing artifact re-use. This
broke in https://github.com/neondatabase/neon/pull/11547.

There's a hotfix release containing this in
https://github.com/neondatabase/neon/pull/11645.

## Summary of changes
Make the regex for finding the RC PR run less strict, it was needlessly
precise.
2025-04-18 16:39:18 +00:00
Dmitrii Kovalkov
a0d844dfed pageserver + safekeeper: pass ssl ca certs to broker client (#11635)
## Problem
Pageservers and safakeepers do not pass CA certificates to broker
client, so the client do not trust locally issued certificates.
- Part of https://github.com/neondatabase/cloud/issues/27492

## Summary of changes
- Change `ssl_ca_certs` type in PS/SK's config to `Pem` which may be
converted to both `reqwest` and `tonic` certificates.
- Pass CA certificates to storage broker client in PS and SK
2025-04-18 06:27:23 +00:00
24 changed files with 942 additions and 132 deletions

View File

@@ -165,5 +165,5 @@ jobs:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
CURRENT_SHA: ${{ github.sha }}
run: |
RELEASE_PR_RUN_ID=$(gh api "/repos/${GITHUB_REPOSITORY}/actions/runs?head_sha=$CURRENT_SHA" | jq '[.workflow_runs[] | select(.name == "Build and Test") | select(.head_branch | test("^rc/release(-(proxy|compute))?/[0-9]{4}-[0-9]{2}-[0-9]{2}$"; "s"))] | first | .id // ("Failed to find Build and Test run from RC PR!" | halt_error(1))')
RELEASE_PR_RUN_ID=$(gh api "/repos/${GITHUB_REPOSITORY}/actions/runs?head_sha=$CURRENT_SHA" | jq '[.workflow_runs[] | select(.name == "Build and Test") | select(.head_branch | test("^rc/release.*$"; "s"))] | first | .id // ("Failed to find Build and Test run from RC PR!" | halt_error(1))')
echo "release-pr-run-id=$RELEASE_PR_RUN_ID" | tee -a $GITHUB_OUTPUT

View File

@@ -27,15 +27,17 @@ jobs:
- name: Fast forwarding
uses: sequoia-pgp/fast-forward@ea7628bedcb0b0b96e94383ada458d812fca4979
# See https://docs.github.com/en/graphql/reference/enums#mergestatestatus
if: ${{ github.event.pull_request.mergeable_state == 'clean' }}
if: ${{ contains(fromJSON('["clean", "unstable"]'), github.event.pull_request.mergeable_state) }}
with:
merge: true
comment: on-error
github_token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Comment if mergeable_state is not clean
if: ${{ github.event.pull_request.mergeable_state != 'clean' }}
if: ${{ !contains(fromJSON('["clean", "unstable"]'), github.event.pull_request.mergeable_state) }}
env:
GH_TOKEN: ${{ secrets.CI_ACCESS_TOKEN }}
run: |
gh pr comment ${{ github.event.pull_request.number }} \
--repo "${GITHUB_REPOSITORY}" \
--body "Not trying to forward pull-request, because \`mergeable_state\` is \`${{ github.event.pull_request.mergeable_state }}\`, not \`clean\`."
--body "Not trying to forward pull-request, because \`mergeable_state\` is \`${{ github.event.pull_request.mergeable_state }}\`, not \`clean\` or \`unstable\`."

2
Cargo.lock generated
View File

@@ -4285,6 +4285,7 @@ dependencies = [
"pageserver_api",
"pageserver_client",
"pageserver_compaction",
"pem",
"pin-project-lite",
"postgres-protocol",
"postgres-types",
@@ -6001,6 +6002,7 @@ dependencies = [
"once_cell",
"pageserver_api",
"parking_lot 0.12.1",
"pem",
"postgres-protocol",
"postgres_backend",
"postgres_ffi",

View File

@@ -78,6 +78,7 @@ metrics.workspace = true
pageserver_api.workspace = true
pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that
pageserver_compaction.workspace = true
pem.workspace = true
postgres_connection.workspace = true
postgres_ffi.workspace = true
pq_proto.workspace = true

View File

@@ -68,6 +68,13 @@ pub(crate) struct Args {
targets: Option<Vec<TenantTimelineId>>,
}
/// State shared by all clients
#[derive(Debug)]
struct SharedState {
start_work_barrier: tokio::sync::Barrier,
live_stats: LiveStats,
}
#[derive(Debug, Default)]
struct LiveStats {
completed_requests: AtomicU64,
@@ -240,24 +247,26 @@ async fn main_impl(
all_ranges
};
let live_stats = Arc::new(LiveStats::default());
let num_live_stats_dump = 1;
let num_work_sender_tasks = args.num_clients.get() * timelines.len();
let num_main_impl = 1;
let start_work_barrier = Arc::new(tokio::sync::Barrier::new(
num_live_stats_dump + num_work_sender_tasks + num_main_impl,
));
let shared_state = Arc::new(SharedState {
start_work_barrier: tokio::sync::Barrier::new(
num_live_stats_dump + num_work_sender_tasks + num_main_impl,
),
live_stats: LiveStats::default(),
});
let cancel = CancellationToken::new();
let ss = shared_state.clone();
tokio::spawn({
let stats = Arc::clone(&live_stats);
let start_work_barrier = Arc::clone(&start_work_barrier);
async move {
start_work_barrier.wait().await;
ss.start_work_barrier.wait().await;
loop {
let start = std::time::Instant::now();
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let stats = &ss.live_stats;
let completed_requests = stats.completed_requests.swap(0, Ordering::Relaxed);
let missed = stats.missed.swap(0, Ordering::Relaxed);
let elapsed = start.elapsed();
@@ -270,14 +279,12 @@ async fn main_impl(
}
});
let cancel = CancellationToken::new();
let rps_period = args
.per_client_rate
.map(|rps_limit| Duration::from_secs_f64(1.0 / (rps_limit as f64)));
let make_worker: &dyn Fn(WorkerId) -> Pin<Box<dyn Send + Future<Output = ()>>> = &|worker_id| {
let live_stats = live_stats.clone();
let start_work_barrier = start_work_barrier.clone();
let ss = shared_state.clone();
let cancel = cancel.clone();
let ranges: Vec<KeyRange> = all_ranges
.iter()
.filter(|r| r.timeline == worker_id.timeline)
@@ -287,85 +294,8 @@ async fn main_impl(
rand::distributions::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len()))
.unwrap();
let cancel = cancel.clone();
Box::pin(async move {
let client =
pageserver_client::page_service::Client::new(args.page_service_connstring.clone())
.await
.unwrap();
let mut client = client
.pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id)
.await
.unwrap();
start_work_barrier.wait().await;
let client_start = Instant::now();
let mut ticks_processed = 0;
let mut inflight = VecDeque::new();
while !cancel.is_cancelled() {
// Detect if a request took longer than the RPS rate
if let Some(period) = &rps_period {
let periods_passed_until_now =
usize::try_from(client_start.elapsed().as_micros() / period.as_micros())
.unwrap();
if periods_passed_until_now > ticks_processed {
live_stats.missed((periods_passed_until_now - ticks_processed) as u64);
}
ticks_processed = periods_passed_until_now;
}
while inflight.len() < args.queue_depth.get() {
let start = Instant::now();
let req = {
let mut rng = rand::thread_rng();
let r = &ranges[weights.sample(&mut rng)];
let key: i128 = rng.gen_range(r.start..r.end);
let key = Key::from_i128(key);
assert!(key.is_rel_block_key());
let (rel_tag, block_no) = key
.to_rel_block()
.expect("we filter non-rel-block keys out above");
PagestreamGetPageRequest {
hdr: PagestreamRequest {
reqid: 0,
request_lsn: if rng.gen_bool(args.req_latest_probability) {
Lsn::MAX
} else {
r.timeline_lsn
},
not_modified_since: r.timeline_lsn,
},
rel: rel_tag,
blkno: block_no,
}
};
client.getpage_send(req).await.unwrap();
inflight.push_back(start);
}
let start = inflight.pop_front().unwrap();
client.getpage_recv().await.unwrap();
let end = Instant::now();
live_stats.request_done();
ticks_processed += 1;
STATS.with(|stats| {
stats
.borrow()
.lock()
.unwrap()
.observe(end.duration_since(start))
.unwrap();
});
if let Some(period) = &rps_period {
let next_at = client_start
+ Duration::from_micros(
(ticks_processed) as u64 * u64::try_from(period.as_micros()).unwrap(),
);
tokio::time::sleep_until(next_at.into()).await;
}
}
client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await
})
};
@@ -387,7 +317,7 @@ async fn main_impl(
};
info!("waiting for everything to become ready");
start_work_barrier.wait().await;
shared_state.start_work_barrier.wait().await;
info!("work started");
if let Some(runtime) = args.runtime {
tokio::time::sleep(runtime.into()).await;
@@ -416,3 +346,91 @@ async fn main_impl(
anyhow::Ok(())
}
async fn client_libpq(
args: &Args,
worker_id: WorkerId,
shared_state: Arc<SharedState>,
cancel: CancellationToken,
rps_period: Option<Duration>,
ranges: Vec<KeyRange>,
weights: rand::distributions::weighted::WeightedIndex<i128>,
) {
let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone())
.await
.unwrap();
let mut client = client
.pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id)
.await
.unwrap();
shared_state.start_work_barrier.wait().await;
let client_start = Instant::now();
let mut ticks_processed = 0;
let mut inflight = VecDeque::new();
while !cancel.is_cancelled() {
// Detect if a request took longer than the RPS rate
if let Some(period) = &rps_period {
let periods_passed_until_now =
usize::try_from(client_start.elapsed().as_micros() / period.as_micros()).unwrap();
if periods_passed_until_now > ticks_processed {
shared_state
.live_stats
.missed((periods_passed_until_now - ticks_processed) as u64);
}
ticks_processed = periods_passed_until_now;
}
while inflight.len() < args.queue_depth.get() {
let start = Instant::now();
let req = {
let mut rng = rand::thread_rng();
let r = &ranges[weights.sample(&mut rng)];
let key: i128 = rng.gen_range(r.start..r.end);
let key = Key::from_i128(key);
assert!(key.is_rel_block_key());
let (rel_tag, block_no) = key
.to_rel_block()
.expect("we filter non-rel-block keys out above");
PagestreamGetPageRequest {
hdr: PagestreamRequest {
reqid: 0,
request_lsn: if rng.gen_bool(args.req_latest_probability) {
Lsn::MAX
} else {
r.timeline_lsn
},
not_modified_since: r.timeline_lsn,
},
rel: rel_tag,
blkno: block_no,
}
};
client.getpage_send(req).await.unwrap();
inflight.push_back(start);
}
let start = inflight.pop_front().unwrap();
client.getpage_recv().await.unwrap();
let end = Instant::now();
shared_state.live_stats.request_done();
ticks_processed += 1;
STATS.with(|stats| {
stats
.borrow()
.lock()
.unwrap()
.observe(end.duration_since(start))
.unwrap();
});
if let Some(period) = &rps_period {
let next_at = client_start
+ Duration::from_micros(
(ticks_processed) as u64 * u64::try_from(period.as_micros()).unwrap(),
);
tokio::time::sleep_until(next_at.into()).await;
}
}
}

View File

@@ -416,8 +416,18 @@ fn start_pageserver(
// The storage_broker::connect call needs to happen inside a tokio runtime thread.
let broker_client = WALRECEIVER_RUNTIME
.block_on(async {
let tls_config = storage_broker::ClientTlsConfig::new().ca_certificates(
conf.ssl_ca_certs
.iter()
.map(pem::encode)
.map(storage_broker::Certificate::from_pem),
);
// Note: we do not attempt connecting here (but validate endpoints sanity).
storage_broker::connect(conf.broker_endpoint.clone(), conf.broker_keepalive_interval)
storage_broker::connect(
conf.broker_endpoint.clone(),
conf.broker_keepalive_interval,
tls_config,
)
})
.with_context(|| {
format!(

View File

@@ -17,9 +17,10 @@ use once_cell::sync::OnceCell;
use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes};
use pageserver_api::models::ImageCompressionAlgorithm;
use pageserver_api::shard::TenantShardId;
use pem::Pem;
use postgres_backend::AuthType;
use remote_storage::{RemotePath, RemoteStorageConfig};
use reqwest::{Certificate, Url};
use reqwest::Url;
use storage_broker::Uri;
use utils::id::{NodeId, TimelineId};
use utils::logging::{LogFormat, SecretString};
@@ -67,8 +68,8 @@ pub struct PageServerConf {
/// Period to reload certificate and private key from files.
/// Default: 60s.
pub ssl_cert_reload_period: Duration,
/// Trusted root CA certificates to use in https APIs.
pub ssl_ca_certs: Vec<Certificate>,
/// Trusted root CA certificates to use in https APIs in PEM format.
pub ssl_ca_certs: Vec<Pem>,
/// Current availability zone. Used for traffic metrics.
pub availability_zone: Option<String>,
@@ -497,7 +498,10 @@ impl PageServerConf {
ssl_ca_certs: match ssl_ca_file {
Some(ssl_ca_file) => {
let buf = std::fs::read(ssl_ca_file)?;
Certificate::from_pem_bundle(&buf)?
pem::parse_many(&buf)?
.into_iter()
.filter(|pem| pem.tag() == "CERTIFICATE")
.collect()
}
None => Vec::new(),
},

View File

@@ -8,6 +8,7 @@ use pageserver_api::upcall_api::{
ReAttachRequest, ReAttachResponse, ReAttachResponseTenant, ValidateRequest,
ValidateRequestTenant, ValidateResponse,
};
use reqwest::Certificate;
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio_util::sync::CancellationToken;
@@ -76,8 +77,8 @@ impl StorageControllerUpcallClient {
client = client.default_headers(headers);
}
for ssl_ca_cert in &conf.ssl_ca_certs {
client = client.add_root_certificate(ssl_ca_cert.clone());
for cert in &conf.ssl_ca_certs {
client = client.add_root_certificate(Certificate::from_der(cert.contents())?);
}
Ok(Some(Self {

View File

@@ -1,6 +1,6 @@
//! FIXME: most of this is copy-paste from mgmt_api.rs ; dedupe into a `reqwest_utils::Client` crate.
use pageserver_client::mgmt_api::{Error, ResponseErrorMessageExt};
use reqwest::Method;
use reqwest::{Certificate, Method};
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use tracing::error;
@@ -34,7 +34,7 @@ impl Client {
};
let mut http_client = reqwest::Client::builder();
for cert in &conf.ssl_ca_certs {
http_client = http_client.add_root_certificate(cert.clone());
http_client = http_client.add_root_certificate(Certificate::from_der(cert.contents())?);
}
let http_client = http_client.build()?;

View File

@@ -24,6 +24,7 @@ use crate::config::{
use crate::context::parquet::ParquetUploadArgs;
use crate::http::health_server::AppMetrics;
use crate::metrics::Metrics;
use crate::proxy::conntrack::ConnectionTracking;
use crate::rate_limiter::{
EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter,
};
@@ -418,6 +419,8 @@ pub async fn run() -> anyhow::Result<()> {
64,
));
let conntracking = Arc::new(ConnectionTracking::default());
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
@@ -431,6 +434,7 @@ pub async fn run() -> anyhow::Result<()> {
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
conntracking.clone(),
));
}
@@ -453,6 +457,7 @@ pub async fn run() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
conntracking.clone(),
));
}
}

View File

@@ -13,6 +13,7 @@ use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::conntrack::ConnectionTracking;
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::proxy::passthrough::ProxyPassthrough;
use crate::proxy::{
@@ -25,6 +26,7 @@ pub async fn task_main(
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
conntracking: Arc<ConnectionTracking>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -50,6 +52,7 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
let conntracking = Arc::clone(&conntracking);
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
@@ -111,6 +114,7 @@ pub async fn task_main(
socket,
conn_gauge,
cancellations,
conntracking,
)
.instrument(ctx.span())
.boxed()
@@ -167,6 +171,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
conntracking: Arc<ConnectionTracking>,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
@@ -264,6 +269,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
compute: node,
session_id: ctx.session_id(),
cancel: session,
conntracking,
_req: request_gauge,
_conn: conn_gauge,
}))

View File

@@ -0,0 +1,680 @@
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
use std::task::{Context, Poll};
use std::time::SystemTime;
use std::{fmt, io};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnId(usize);
#[derive(Default)]
pub struct ConnectionTracking {
conns: clashmap::ClashMap<ConnId, (ConnectionState, SystemTime)>,
}
impl ConnectionTracking {
pub fn new_tracker(self: &Arc<Self>) -> ConnectionTracker<Arc<Self>> {
let conn_id = self.new_conn_id();
ConnectionTracker::new(conn_id, Arc::clone(self))
}
fn new_conn_id(&self) -> ConnId {
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
let id = ConnId(NEXT_ID.fetch_add(1, Ordering::Relaxed));
self.conns
.insert(id, (ConnectionState::Idle, SystemTime::now()));
id
}
fn update(&self, conn_id: ConnId, new_state: ConnectionState) {
let new_timestamp = SystemTime::now();
let old_state = self.conns.insert(conn_id, (new_state, new_timestamp));
if let Some((old_state, _old_timestamp)) = old_state {
tracing::debug!(?conn_id, %old_state, %new_state, "conntrack: update");
} else {
tracing::debug!(?conn_id, %new_state, "conntrack: update");
}
}
fn remove(&self, conn_id: ConnId) {
if let Some((_, (old_state, _old_timestamp))) = self.conns.remove(&conn_id) {
tracing::debug!(?conn_id, %old_state, "conntrack: remove");
}
}
}
impl StateChangeObserver for Arc<ConnectionTracking> {
type ConnId = ConnId;
fn change(
&self,
conn_id: Self::ConnId,
_old_state: ConnectionState,
new_state: ConnectionState,
) {
match new_state {
ConnectionState::Init
| ConnectionState::Idle
| ConnectionState::Transaction
| ConnectionState::Busy
| ConnectionState::Unknown => self.update(conn_id, new_state),
ConnectionState::Closed => self.remove(conn_id),
}
}
}
/// Called by `ConnectionTracker` whenever the `ConnectionState` changed.
pub trait StateChangeObserver {
/// Identifier of the connection passed back on state change.
type ConnId: Copy;
/// Called iff the connection's state changed.
fn change(&self, conn_id: Self::ConnId, old_state: ConnectionState, new_state: ConnectionState);
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum ConnectionState {
#[default]
Init = 0,
Idle = 1,
Transaction = 2,
Busy = 3,
Closed = 4,
Unknown = 5,
}
impl ConnectionState {
const fn into_repr(self) -> u8 {
self as u8
}
const fn from_repr(value: u8) -> Option<Self> {
Some(match value {
0 => Self::Init,
1 => Self::Idle,
2 => Self::Transaction,
3 => Self::Busy,
4 => Self::Closed,
5 => Self::Unknown,
_ => return None,
})
}
}
impl fmt::Display for ConnectionState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
ConnectionState::Init => f.write_str("init"),
ConnectionState::Idle => f.write_str("idle"),
ConnectionState::Transaction => f.write_str("transaction"),
ConnectionState::Busy => f.write_str("busy"),
ConnectionState::Closed => f.write_str("closed"),
ConnectionState::Unknown => f.write_str("unknown"),
}
}
}
/// Stores the `ConnectionState`. Used by ConnectionTracker to avoid needing
/// mutable references.
#[derive(Debug, Default)]
struct AtomicConnectionState(AtomicU8);
impl AtomicConnectionState {
fn set(&self, state: ConnectionState) {
self.0.store(state.into_repr(), Ordering::Relaxed);
}
fn get(&self) -> ConnectionState {
ConnectionState::from_repr(self.0.load(Ordering::Relaxed)).expect("only valid variants")
}
}
/// Tracks the `ConnectionState` of a connection by inspecting the frontend and
/// backend stream and reacting to specific messages. Used in combination with
/// two `TrackedStream`s.
pub struct ConnectionTracker<SCO: StateChangeObserver> {
state: AtomicConnectionState,
observer: SCO,
conn_id: SCO::ConnId,
}
impl<SCO: StateChangeObserver> Drop for ConnectionTracker<SCO> {
fn drop(&mut self) {
self.observer
.change(self.conn_id, self.state.get(), ConnectionState::Closed);
}
}
impl<SCO: StateChangeObserver> ConnectionTracker<SCO> {
pub fn new(conn_id: SCO::ConnId, observer: SCO) -> Self {
ConnectionTracker {
conn_id,
state: AtomicConnectionState::default(),
observer,
}
}
pub fn frontend_message_tag(&self, tag: Tag) {
self.update_state(|old_state| Self::state_from_frontend_tag(old_state, tag));
}
pub fn backend_message_tag(&self, tag: Tag) {
self.update_state(|old_state| Self::state_from_backend_tag(old_state, tag));
}
fn update_state(&self, new_state_fn: impl FnOnce(ConnectionState) -> ConnectionState) {
let old_state = self.state.get();
let new_state = new_state_fn(old_state);
if old_state != new_state {
self.observer.change(self.conn_id, old_state, new_state);
self.state.set(new_state);
}
}
fn state_from_frontend_tag(_old_state: ConnectionState, fe_tag: Tag) -> ConnectionState {
// Most activity from the client puts connection into busy state.
// Only the server can put a connection back into idle state.
match fe_tag {
Tag::Start | Tag::ReadyForQuery(_) | Tag::Message(_) => ConnectionState::Busy,
Tag::End => ConnectionState::Closed,
Tag::Lost => ConnectionState::Unknown,
}
}
fn state_from_backend_tag(old_state: ConnectionState, be_tag: Tag) -> ConnectionState {
match be_tag {
// Check for RFQ and put connection into idle or idle in transaction state.
Tag::ReadyForQuery(b'I') => ConnectionState::Idle,
Tag::ReadyForQuery(b'T') => ConnectionState::Transaction,
Tag::ReadyForQuery(b'E') => ConnectionState::Transaction,
// We can't put a connection into idle state for unknown RFQ status.
Tag::ReadyForQuery(_) => ConnectionState::Unknown,
// Ignore out-fo message from the server.
Tag::NOTICE | Tag::NOTIFICATION_RESPONSE | Tag::PARAMETER_STATUS => old_state,
// All other activity from server puts connection into busy state.
Tag::Start | Tag::Message(_) => ConnectionState::Busy,
Tag::End => ConnectionState::Closed,
Tag::Lost => ConnectionState::Unknown,
}
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Tag {
Message(u8),
ReadyForQuery(u8),
Start,
End,
Lost,
}
impl Tag {
const READY_FOR_QUERY: Tag = Tag::Message(b'Z');
const NOTICE: Tag = Tag::Message(b'N');
const NOTIFICATION_RESPONSE: Tag = Tag::Message(b'A');
const PARAMETER_STATUS: Tag = Tag::Message(b'S');
}
impl fmt::Display for Tag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Tag::Start => f.write_str("start"),
Tag::End => f.write_str("end"),
Tag::Lost => f.write_str("lost"),
Tag::Message(tag) => write!(f, "'{}'", tag as char),
Tag::ReadyForQuery(status) => write!(f, "ReadyForQuery:'{}'", status as char),
}
}
}
pub trait TagObserver {
fn observe(&mut self, tag: Tag);
}
impl<F: FnMut(Tag)> TagObserver for F {
fn observe(&mut self, tag: Tag) {
(self)(tag);
}
}
pin_project! {
pub struct TrackedStream<S, TO> {
#[pin]
stream: S,
scanner: StreamScanner<TO>,
}
}
impl<S: AsyncRead + AsyncWrite + Unpin, TO: TagObserver> TrackedStream<S, TO> {
pub const fn new(stream: S, midstream: bool, observer: TO) -> Self {
TrackedStream {
stream,
scanner: StreamScanner::new(midstream, observer),
}
}
}
impl<S: AsyncRead + Unpin, TO: TagObserver> AsyncRead for TrackedStream<S, TO> {
#[inline]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.project();
let old_len = buf.filled().len();
match this.stream.poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let new_len = buf.filled().len();
this.scanner.scan_bytes(&buf.filled()[old_len..new_len]);
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
impl<S: AsyncWrite + Unpin, TO> AsyncWrite for TrackedStream<S, TO> {
#[inline(always)]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.project().stream.poll_write(cx, buf)
}
#[inline(always)]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().stream.poll_flush(cx)
}
#[inline(always)]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().stream.poll_shutdown(cx)
}
}
#[derive(Debug)]
struct StreamScanner<TO> {
observer: TO,
state: StreamScannerState,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum StreamScannerState {
/// Initial state when no message has been read and we are looling for a
/// message without a tag.
Start,
/// Read a message tag.
Tag,
/// Read the length bytes and calculate the total length.
Length {
tag: Tag,
/// Number of bytes missing to know the full length of the message: 0..=4
length_bytes_missing: usize,
/// Total length of the message (without tag) that is calculated as we
/// read the bytes for the length.
calculated_length: usize,
},
/// Read (= skip) the payload.
Payload {
tag: Tag,
/// If this is the first time payload bytes are read. Important when
/// inspecting specific messages, like ReadyForQuery.
first: bool,
/// Number of payload bytes left to read before looking for a new tag.
bytes_to_skip: usize,
},
/// Stream was terminated.
End,
/// Stream ended up in a lost state. We only stop tracking the stream, not
/// interrupt it.
Lost,
}
impl<TO: TagObserver> StreamScanner<TO> {
const fn new(midstream: bool, observer: TO) -> Self {
StreamScanner {
observer,
state: if midstream {
StreamScannerState::Tag
} else {
StreamScannerState::Start
},
}
}
}
impl<TO: TagObserver> StreamScanner<TO> {
fn scan_bytes(&mut self, mut buf: &[u8]) {
use StreamScannerState as S;
if matches!(self.state, S::End | S::Lost) {
return;
}
if buf.is_empty() {
match self.state {
S::Start | S::Tag => {
self.observer.observe(Tag::End);
self.state = S::End;
return;
}
S::Length { .. } | S::Payload { .. } => {
self.observer.observe(Tag::Lost);
self.state = S::Lost;
return;
}
S::End | S::Lost => unreachable!(),
}
}
while !buf.is_empty() {
match self.state {
S::Start => {
self.state = S::Length {
tag: Tag::Start,
length_bytes_missing: 4,
calculated_length: 0,
};
}
S::Tag => {
let tag = buf.first().copied().expect("buf not empty");
buf = &buf[1..];
self.state = S::Length {
tag: Tag::Message(tag),
length_bytes_missing: 4,
calculated_length: 0,
};
}
S::Length {
tag,
mut length_bytes_missing,
mut calculated_length,
} => {
let consume = length_bytes_missing.min(buf.len());
let (length_bytes, remainder) = buf.split_at(consume);
for b in length_bytes {
calculated_length <<= 8;
calculated_length |= *b as usize;
}
buf = remainder;
length_bytes_missing -= consume;
if length_bytes_missing == 0 {
let Some(bytes_to_skip) = calculated_length.checked_sub(4) else {
self.observer.observe(Tag::Lost);
self.state = S::Lost;
return;
};
if bytes_to_skip == 0 {
self.observer.observe(tag);
self.state = S::Tag;
} else {
self.state = S::Payload {
tag,
first: true,
bytes_to_skip,
};
}
} else {
self.state = S::Length {
tag,
length_bytes_missing,
calculated_length,
};
}
}
S::Payload {
tag,
first,
mut bytes_to_skip,
} => {
let consume = bytes_to_skip.min(buf.len());
bytes_to_skip -= consume;
if bytes_to_skip == 0 {
if tag == Tag::READY_FOR_QUERY && first && consume == 1 {
let status = buf.first().copied().expect("buf not empty");
self.observer.observe(Tag::ReadyForQuery(status));
} else {
self.observer.observe(tag);
}
self.state = S::Tag;
} else {
self.state = S::Payload {
tag,
first: false,
bytes_to_skip,
};
}
buf = &buf[consume..];
}
S::End | S::Lost => unreachable!(),
}
}
}
}
#[cfg(test)]
mod tests {
use std::cell::RefCell;
use std::pin::pin;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncReadExt, BufReader};
use super::*;
#[test]
fn test_stream_scanner() {
let tags = Rc::new(RefCell::new(Vec::new()));
let observer_tags = tags.clone();
let observer = move |tag| {
observer_tags.borrow_mut().push(tag);
};
let mut scanner = StreamScanner::new(false, observer);
scanner.scan_bytes(&[0, 0]);
assert_eq!(tags.borrow().as_slice(), &[]);
assert_eq!(
scanner.state,
StreamScannerState::Length {
tag: Tag::Start,
length_bytes_missing: 2,
calculated_length: 0,
}
);
scanner.scan_bytes(&[0x01, 0x01, 0x00]);
assert_eq!(tags.borrow().as_slice(), &[]);
assert_eq!(
scanner.state,
StreamScannerState::Payload {
tag: Tag::Start,
first: false,
bytes_to_skip: 0x00000101 - 4 - 1,
}
);
scanner.scan_bytes(vec![0; 0x00000101 - 4 - 1 - 1].as_slice());
assert_eq!(tags.borrow().as_slice(), &[]);
assert_eq!(
scanner.state,
StreamScannerState::Payload {
tag: Tag::Start,
first: false,
bytes_to_skip: 1,
}
);
scanner.scan_bytes(&[0x00, b'A', 0x00, 0x00, 0x00, 0x08]);
assert_eq!(tags.borrow().as_slice(), &[Tag::Start]);
assert_eq!(
scanner.state,
StreamScannerState::Payload {
tag: Tag::Message(b'A'),
first: true,
bytes_to_skip: 4,
}
);
scanner.scan_bytes(&[0, 0, 0, 0]);
assert_eq!(tags.borrow().as_slice(), &[Tag::Start, Tag::Message(b'A')]);
assert_eq!(scanner.state, StreamScannerState::Tag);
scanner.scan_bytes(&[b'Z', 0x00, 0x00, 0x00, 0x05, b'T']);
assert_eq!(
tags.borrow().as_slice(),
&[Tag::Start, Tag::Message(b'A'), Tag::ReadyForQuery(b'T')]
);
assert_eq!(scanner.state, StreamScannerState::Tag);
scanner.scan_bytes(&[]);
assert_eq!(
tags.borrow().as_slice(),
&[
Tag::Start,
Tag::Message(b'A'),
Tag::ReadyForQuery(b'T'),
Tag::End
]
);
assert_eq!(scanner.state, StreamScannerState::End);
}
#[tokio::test]
async fn test_connection_tracker() {
let transitions: Arc<Mutex<Vec<(ConnectionState, ConnectionState)>>> = Arc::default();
struct Observer(Arc<Mutex<Vec<(ConnectionState, ConnectionState)>>>);
impl StateChangeObserver for Observer {
type ConnId = usize;
fn change(
&self,
conn_id: Self::ConnId,
old_state: ConnectionState,
new_state: ConnectionState,
) {
assert_eq!(conn_id, 42);
self.0.lock().unwrap().push((old_state, new_state));
}
}
let tracker = ConnectionTracker::new(42, Observer(transitions.clone()));
let stream = TestStream::new(
&[
0, 0, 0, 4, // Init
b'Z', 0, 0, 0, 5, b'I', // Init -> Idle
b'x', 0, 0, 0, 4, // Idle -> Busy
b'Z', 0, 0, 0, 5, b'I', // Busy -> Idle
][..],
);
// AsyncRead
let mut stream = TrackedStream::new(stream, false, |tag| tracker.backend_message_tag(tag));
let mut readbuf = [0; 2];
let n = stream.read_exact(&mut readbuf).await.unwrap();
assert_eq!(n, 2);
assert_eq!(&readbuf, &[0, 0,]);
assert!(transitions.lock().unwrap().is_empty());
let mut readbuf = [0; 2];
let n = stream.read_exact(&mut readbuf).await.unwrap();
assert_eq!(n, 2);
assert_eq!(&readbuf, &[0, 4]);
assert_eq!(
transitions.lock().unwrap().as_slice(),
&[(ConnectionState::Init, ConnectionState::Busy)]
);
let mut readbuf = [0; 6];
let n = stream.read_exact(&mut readbuf).await.unwrap();
assert_eq!(n, 6);
assert_eq!(&readbuf, &[b'Z', 0, 0, 0, 5, b'I']);
assert_eq!(
transitions.lock().unwrap().as_slice(),
&[
(ConnectionState::Init, ConnectionState::Busy),
(ConnectionState::Busy, ConnectionState::Idle),
]
);
let mut readbuf = [0; 5];
let n = stream.read_exact(&mut readbuf).await.unwrap();
assert_eq!(n, 5);
assert_eq!(&readbuf, &[b'x', 0, 0, 0, 4]);
assert_eq!(
transitions.lock().unwrap().as_slice(),
&[
(ConnectionState::Init, ConnectionState::Busy),
(ConnectionState::Busy, ConnectionState::Idle),
(ConnectionState::Idle, ConnectionState::Busy),
]
);
let mut readbuf = [0; 6];
let n = stream.read_exact(&mut readbuf).await.unwrap();
assert_eq!(n, 6);
assert_eq!(&readbuf, &[b'Z', 0, 0, 0, 5, b'I']);
assert_eq!(
transitions.lock().unwrap().as_slice(),
&[
(ConnectionState::Init, ConnectionState::Busy),
(ConnectionState::Busy, ConnectionState::Idle),
(ConnectionState::Idle, ConnectionState::Busy),
(ConnectionState::Busy, ConnectionState::Idle),
]
);
}
struct TestStream {
stream: BufReader<&'static [u8]>,
}
impl TestStream {
fn new(data: &'static [u8]) -> Self {
TestStream {
stream: BufReader::new(data),
}
}
}
impl AsyncRead for TestStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
pin!(&mut self.stream).poll_read(cx, buf)
}
}
impl AsyncWrite for TestStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
}
}

View File

@@ -2,6 +2,7 @@
mod tests;
pub(crate) mod connect_compute;
pub mod conntrack;
mod copy_bidirectional;
pub(crate) mod handshake;
pub(crate) mod passthrough;
@@ -30,6 +31,7 @@ use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::conntrack::ConnectionTracking;
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
@@ -60,6 +62,7 @@ pub async fn task_main(
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conntracking: Arc<ConnectionTracking>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -85,6 +88,7 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
let conntracking = Arc::clone(&conntracking);
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
@@ -149,6 +153,7 @@ pub async fn task_main(
endpoint_rate_limiter2,
conn_gauge,
cancellations,
conntracking,
)
.instrument(ctx.span())
.boxed()
@@ -268,6 +273,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
conntracking: Arc<ConnectionTracking>,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
@@ -409,6 +415,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
compute: node,
session_id: ctx.session_id(),
cancel: session,
conntracking,
_req: request_gauge,
_conn: conn_gauge,
}))

View File

@@ -1,3 +1,5 @@
use std::sync::Arc;
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::debug;
@@ -9,6 +11,7 @@ use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::proxy::conntrack::{ConnectionTracking, TrackedStream};
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
@@ -19,6 +22,7 @@ pub(crate) async fn proxy_pass(
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
private_link_id: Option<SmolStr>,
conntracking: &Arc<ConnectionTracking>,
) -> Result<(), ErrorSource> {
// we will report ingress at a later date
let usage_tx = USAGE_METRICS.register(Ids {
@@ -27,9 +31,11 @@ pub(crate) async fn proxy_pass(
private_link_id,
});
let conn_tracker = conntracking.new_tracker();
let metrics = &Metrics::get().proxy.io_bytes;
let m_sent = metrics.with_labels(Direction::Tx);
let mut client = MeasuredStream::new(
let client = MeasuredStream::new(
client,
|_| {},
|cnt| {
@@ -38,9 +44,10 @@ pub(crate) async fn proxy_pass(
usage_tx.record_egress(cnt as u64);
},
);
let mut client = TrackedStream::new(client, true, |tag| conn_tracker.frontend_message_tag(tag));
let m_recv = metrics.with_labels(Direction::Rx);
let mut compute = MeasuredStream::new(
let compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
@@ -49,6 +56,8 @@ pub(crate) async fn proxy_pass(
usage_tx.record_ingress(cnt as u64);
},
);
let mut compute =
TrackedStream::new(compute, true, |tag| conn_tracker.backend_message_tag(tag));
// Starting from here we only proxy the client's traffic.
debug!("performing the proxy pass...");
@@ -68,6 +77,7 @@ pub(crate) struct ProxyPassthrough<S> {
pub(crate) session_id: uuid::Uuid,
pub(crate) private_link_id: Option<SmolStr>,
pub(crate) cancel: cancellation::Session,
pub(crate) conntracking: Arc<ConnectionTracking>,
pub(crate) _req: NumConnectionRequestsGuard<'static>,
pub(crate) _conn: NumClientConnectionsGuard<'static>,
@@ -83,6 +93,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
self.compute.stream,
self.aux,
self.private_link_id,
&self.conntracking,
)
.await;
if let Err(err) = self

View File

@@ -50,6 +50,7 @@ use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
use crate::protocol2::{ChainRW, ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::conntrack::ConnectionTracking;
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
@@ -124,6 +125,9 @@ pub async fn task_main(
connections.close(); // allows `connections.wait to complete`
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
let conntracking = Arc::new(ConnectionTracking::default());
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
@@ -153,6 +157,8 @@ pub async fn task_main(
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellations = cancellations.clone();
let conntracking = Arc::clone(&conntracking);
connections.spawn(
async move {
let conn_token2 = conn_token.clone();
@@ -185,6 +191,7 @@ pub async fn task_main(
cancellation_handler,
endpoint_rate_limiter,
conn_token,
conntracking,
conn,
conn_info,
session_id,
@@ -309,6 +316,7 @@ async fn connection_handler(
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
conntracking: Arc<ConnectionTracking>,
conn: AsyncRW,
conn_info: ConnectionInfo,
session_id: uuid::Uuid,
@@ -347,6 +355,7 @@ async fn connection_handler(
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let cancellations = cancellations.clone();
let conntracking = Arc::clone(&conntracking);
let handler = connections.spawn(
request_handler(
req,
@@ -359,6 +368,7 @@ async fn connection_handler(
http_request_token,
endpoint_rate_limiter.clone(),
cancellations,
conntracking,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -407,6 +417,7 @@ async fn request_handler(
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellations: TaskTracker,
conntracking: Arc<ConnectionTracking>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
@@ -452,6 +463,7 @@ async fn request_handler(
endpoint_rate_limiter,
host,
cancellations,
conntracking,
)
.await
{

View File

@@ -17,6 +17,7 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::conntrack::ConnectionTracking;
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::rate_limiter::EndpointRateLimiter;
@@ -133,6 +134,7 @@ pub(crate) async fn serve_websocket(
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
conntracking: Arc<ConnectionTracking>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
@@ -152,6 +154,7 @@ pub(crate) async fn serve_websocket(
endpoint_rate_limiter,
conn_gauge,
cancellations,
conntracking,
))
.await;

View File

@@ -55,6 +55,7 @@ tokio-util = { workspace = true }
tracing.workspace = true
url.workspace = true
metrics.workspace = true
pem.workspace = true
postgres_backend.workspace = true
postgres_ffi.workspace = true
pq_proto.workspace = true

View File

@@ -16,7 +16,6 @@ use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use metrics::set_build_info_metric;
use remote_storage::RemoteStorageConfig;
use reqwest::Certificate;
use safekeeper::defaults::{
DEFAULT_CONTROL_FILE_SAVE_INTERVAL, DEFAULT_EVICTION_MIN_RESIDENT, DEFAULT_HEARTBEAT_TIMEOUT,
DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES, DEFAULT_PARTIAL_BACKUP_CONCURRENCY,
@@ -373,7 +372,10 @@ async fn main() -> anyhow::Result<()> {
Some(ssl_ca_file) => {
tracing::info!("Using ssl root CA file: {ssl_ca_file:?}");
let buf = tokio::fs::read(ssl_ca_file).await?;
Certificate::from_pem_bundle(&buf)?
pem::parse_many(&buf)?
.into_iter()
.filter(|pem| pem.tag() == "CERTIFICATE")
.collect()
}
None => Vec::new(),
};

View File

@@ -24,6 +24,15 @@ use crate::{GlobalTimelines, SafeKeeperConf};
const RETRY_INTERVAL_MSEC: u64 = 1000;
const PUSH_INTERVAL_MSEC: u64 = 1000;
fn make_tls_config(conf: &SafeKeeperConf) -> storage_broker::ClientTlsConfig {
storage_broker::ClientTlsConfig::new().ca_certificates(
conf.ssl_ca_certs
.iter()
.map(pem::encode)
.map(storage_broker::Certificate::from_pem),
)
}
/// Push once in a while data about all active timelines to the broker.
async fn push_loop(
conf: Arc<SafeKeeperConf>,
@@ -37,8 +46,11 @@ async fn push_loop(
let active_timelines_set = global_timelines.get_global_broker_active_set();
let mut client =
storage_broker::connect(conf.broker_endpoint.clone(), conf.broker_keepalive_interval)?;
let mut client = storage_broker::connect(
conf.broker_endpoint.clone(),
conf.broker_keepalive_interval,
make_tls_config(&conf),
)?;
let push_interval = Duration::from_millis(PUSH_INTERVAL_MSEC);
let outbound = async_stream::stream! {
@@ -81,8 +93,11 @@ async fn pull_loop(
global_timelines: Arc<GlobalTimelines>,
stats: Arc<BrokerStats>,
) -> Result<()> {
let mut client =
storage_broker::connect(conf.broker_endpoint.clone(), conf.broker_keepalive_interval)?;
let mut client = storage_broker::connect(
conf.broker_endpoint.clone(),
conf.broker_keepalive_interval,
make_tls_config(&conf),
)?;
// TODO: subscribe only to local timelines instead of all
let request = SubscribeSafekeeperInfoRequest {
@@ -134,8 +149,11 @@ async fn discover_loop(
global_timelines: Arc<GlobalTimelines>,
stats: Arc<BrokerStats>,
) -> Result<()> {
let mut client =
storage_broker::connect(conf.broker_endpoint.clone(), conf.broker_keepalive_interval)?;
let mut client = storage_broker::connect(
conf.broker_endpoint.clone(),
conf.broker_keepalive_interval,
make_tls_config(&conf),
)?;
let request = SubscribeByFilterRequest {
types: vec![TypeSubscription {

View File

@@ -14,6 +14,7 @@ use http_utils::json::{json_request, json_response};
use http_utils::request::{ensure_no_body, parse_query_param, parse_request_param};
use http_utils::{RequestExt, RouterBuilder};
use hyper::{Body, Request, Response, StatusCode};
use pem::Pem;
use postgres_ffi::WAL_SEGMENT_SIZE;
use safekeeper_api::models::{
AcceptorStateStatus, PullTimelineRequest, SafekeeperStatus, SkTimelineInfo, TenantDeleteResult,
@@ -230,14 +231,20 @@ async fn timeline_pull_handler(mut request: Request<Body>) -> Result<Response<Bo
let conf = get_conf(&request);
let global_timelines = get_global_timelines(&request);
let resp = pull_timeline::handle_request(
data,
conf.sk_auth_token.clone(),
conf.ssl_ca_certs.clone(),
global_timelines,
)
.await
.map_err(ApiError::InternalServerError)?;
let ca_certs = conf
.ssl_ca_certs
.iter()
.map(Pem::contents)
.map(reqwest::Certificate::from_der)
.collect::<Result<Vec<_>, _>>()
.map_err(|e| {
ApiError::InternalServerError(anyhow::anyhow!("failed to parse CA certs: {e}"))
})?;
let resp =
pull_timeline::handle_request(data, conf.sk_auth_token.clone(), ca_certs, global_timelines)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, resp)
}

View File

@@ -6,8 +6,8 @@ use std::time::Duration;
use camino::Utf8PathBuf;
use once_cell::sync::Lazy;
use pem::Pem;
use remote_storage::RemoteStorageConfig;
use reqwest::Certificate;
use storage_broker::Uri;
use tokio::runtime::Runtime;
use utils::auth::SwappableJwtAuth;
@@ -120,7 +120,7 @@ pub struct SafeKeeperConf {
pub ssl_key_file: Utf8PathBuf,
pub ssl_cert_file: Utf8PathBuf,
pub ssl_cert_reload_period: Duration,
pub ssl_ca_certs: Vec<Certificate>,
pub ssl_ca_certs: Vec<Pem>,
pub use_https_safekeeper_api: bool,
}

View File

@@ -8,6 +8,7 @@ use std::time::SystemTime;
use anyhow::{Context, bail};
use futures::StreamExt;
use postgres_protocol::message::backend::ReplicationMessage;
use reqwest::Certificate;
use safekeeper_api::Term;
use safekeeper_api::membership::INVALID_GENERATION;
use safekeeper_api::models::{PeerInfo, TimelineStatus};
@@ -241,7 +242,7 @@ async fn recover(
let mut client = reqwest::Client::builder();
for cert in &conf.ssl_ca_certs {
client = client.add_root_certificate(cert.clone());
client = client.add_root_certificate(Certificate::from_der(cert.contents())?);
}
let client = client
.build()

View File

@@ -87,7 +87,12 @@ fn tli_from_u64(i: u64) -> Vec<u8> {
async fn subscribe(client: Option<BrokerClientChannel>, counter: Arc<AtomicU64>, i: u64) {
let mut client = match client {
Some(c) => c,
None => storage_broker::connect(DEFAULT_ENDPOINT, Duration::from_secs(5)).unwrap(),
None => storage_broker::connect(
DEFAULT_ENDPOINT,
Duration::from_secs(5),
storage_broker::ClientTlsConfig::new(),
)
.unwrap(),
};
let ttid = ProtoTenantTimelineId {
@@ -119,7 +124,12 @@ async fn subscribe(client: Option<BrokerClientChannel>, counter: Arc<AtomicU64>,
async fn publish(client: Option<BrokerClientChannel>, n_keys: u64) {
let mut client = match client {
Some(c) => c,
None => storage_broker::connect(DEFAULT_ENDPOINT, Duration::from_secs(5)).unwrap(),
None => storage_broker::connect(
DEFAULT_ENDPOINT,
Duration::from_secs(5),
storage_broker::ClientTlsConfig::new(),
)
.unwrap(),
};
let mut counter: u64 = 0;
@@ -164,7 +174,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
let h = tokio::spawn(progress_reporter(counters.clone()));
let c = storage_broker::connect(DEFAULT_ENDPOINT, Duration::from_secs(5)).unwrap();
let c = storage_broker::connect(
DEFAULT_ENDPOINT,
Duration::from_secs(5),
storage_broker::ClientTlsConfig::new(),
)
.unwrap();
for i in 0..args.num_subs {
let c = Some(c.clone());

View File

@@ -4,7 +4,7 @@ use proto::TenantTimelineId as ProtoTenantTimelineId;
use proto::broker_service_client::BrokerServiceClient;
use tonic::Status;
use tonic::codegen::StdError;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
use tonic::transport::{Channel, Endpoint};
use utils::id::{TenantId, TenantTimelineId, TimelineId};
// Code generated by protobuf.
@@ -20,6 +20,7 @@ pub mod metrics;
// Re-exports to avoid direct tonic dependency in user crates.
pub use hyper::Uri;
pub use tonic::transport::{Certificate, ClientTlsConfig};
pub use tonic::{Code, Request, Streaming};
pub const DEFAULT_LISTEN_ADDR: &str = "127.0.0.1:50051";
@@ -38,7 +39,11 @@ pub type BrokerClientChannel = BrokerServiceClient<Channel>;
//
// NB: this function is not async, but still must be run on a tokio runtime thread
// because that's a requirement of tonic_endpoint.connect_lazy()'s Channel::new call.
pub fn connect<U>(endpoint: U, keepalive_interval: Duration) -> anyhow::Result<BrokerClientChannel>
pub fn connect<U>(
endpoint: U,
keepalive_interval: Duration,
tls_config: ClientTlsConfig,
) -> anyhow::Result<BrokerClientChannel>
where
U: std::convert::TryInto<Uri>,
U::Error: std::error::Error + Send + Sync + 'static,
@@ -54,8 +59,7 @@ where
rustls::crypto::ring::default_provider()
.install_default()
.ok();
let tls = ClientTlsConfig::new();
tonic_endpoint = tonic_endpoint.tls_config(tls)?;
tonic_endpoint = tonic_endpoint.tls_config(tls_config)?;
}
tonic_endpoint = tonic_endpoint
.http2_keep_alive_interval(keepalive_interval)