Merge commit '84a2556c9' into problame/standby-horizon-leases

Christian Schwarz
2025-08-06 17:58:54 +02:00
29 changed files with 736 additions and 305 deletions

View File

@@ -133,7 +133,7 @@ RUN case $DEBIAN_VERSION in \
# Install newer version (3.25) from backports.
# libstdc++-10-dev is required for plv8
bullseye) \
echo "deb http://deb.debian.org/debian bullseye-backports main" > /etc/apt/sources.list.d/bullseye-backports.list; \
echo "deb http://archive.debian.org/debian bullseye-backports main" > /etc/apt/sources.list.d/bullseye-backports.list; \
VERSION_INSTALLS="cmake/bullseye-backports cmake-data/bullseye-backports libstdc++-10-dev"; \
;; \
# Version-specific installs for Bookworm (PG17):

View File

@@ -0,0 +1,60 @@
use metrics::{
IntCounter, IntGaugeVec, core::Collector, proto::MetricFamily, register_int_counter,
register_int_gauge_vec,
};
use once_cell::sync::Lazy;
// Counter keeping track of the number of PageStream request errors reported by Postgres.
// An error is registered every time Postgres calls compute_ctl's /refresh_configuration API.
// Postgres invokes this API when it detects trouble with the PageStream requests (get_page@lsn,
// get_base_backup, etc.) that it sends to any pageserver. An increase in this counter value typically
// indicates Postgres downtime, as PageStream requests are critical for Postgres to function.
pub static POSTGRES_PAGESTREAM_REQUEST_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pg_cctl_pagestream_request_errors_total",
"Number of PageStream request errors reported by the postgres process"
)
.expect("failed to define a metric")
});
// Counter keeping track of the number of compute configuration errors due to Postgres statement
// timeouts. An error is registered every time `ComputeNode::reconfigure()` fails due to Postgres
// error code 57014 (query cancelled). This statement timeout typically occurs when Postgres is
// stuck in a problematic retry loop because the PS is rejecting its connection requests (usually due
// to PG pointing at the wrong PS). We should investigate the root cause when this counter value
// increases by checking PG and PS logs.
pub static COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pg_cctl_configure_statement_timeout_errors_total",
"Number of compute configuration errors due to Postgres statement timeouts."
)
.expect("failed to define a metric")
});
pub static COMPUTE_ATTACHED: Lazy<IntGaugeVec> = Lazy::new(|| {
register_int_gauge_vec!(
"pg_cctl_attached",
"Compute node attached status (1 if attached)",
&[
"pg_compute_id",
"pg_instance_id",
"tenant_id",
"timeline_id"
]
)
.expect("failed to define a metric")
});
pub fn collect() -> Vec<MetricFamily> {
let mut metrics = Vec::new();
metrics.extend(POSTGRES_PAGESTREAM_REQUEST_ERRORS.collect());
metrics.extend(COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS.collect());
metrics.extend(COMPUTE_ATTACHED.collect());
metrics
}
pub fn initialize_metrics() {
Lazy::force(&POSTGRES_PAGESTREAM_REQUEST_ERRORS);
Lazy::force(&COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS);
Lazy::force(&COMPUTE_ATTACHED);
}
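A minimal usage sketch for these metrics, not part of the commit: the call sites and label values below are hypothetical, and the counter/gauge methods (`inc`, `with_label_values`, `set`) are assumed to be the standard prometheus ones re-exported by the `metrics` crate.
// Hypothetical usage sketch (assumes the prometheus-style API of the `metrics` crate).
fn example_usage() {
    // Force registration at startup so all series exist before the first scrape.
    initialize_metrics();
    // Hypothetical call site: the /refresh_configuration handler records a PageStream error.
    POSTGRES_PAGESTREAM_REQUEST_ERRORS.inc();
    // Hypothetical call site: mark the compute as attached for a made-up label set.
    COMPUTE_ATTACHED
        .with_label_values(&["compute-1", "instance-1", "tenant-id", "timeline-id"])
        .set(1);
    // A metrics endpoint would then gather and encode these families.
    let families = collect();
    assert!(!families.is_empty());
}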

View File

@@ -16,6 +16,7 @@ pub mod compute_prewarm;
pub mod compute_promote;
pub mod disk_quota;
pub mod extension_server;
pub mod hadron_metrics;
pub mod installed_extensions;
pub mod local_proxy;
pub mod lsn_lease;

View File

@@ -11,9 +11,8 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use crate::connect::connect;
use crate::connect_raw::{RawConnection, connect_raw};
use crate::connect_raw::{self, StartupStream};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{MakeTlsConnect, TlsConnect, TlsStream};
use crate::{Client, Connection, Error};
@@ -244,24 +243,26 @@ impl Config {
&self,
stream: S,
tls: T,
) -> Result<RawConnection<S, T::Stream>, Error>
) -> Result<StartupStream<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let stream = connect_tls(stream, self.ssl_mode, tls).await?;
connect_raw(stream, self).await
let mut stream = StartupStream::new(stream);
connect_raw::startup(&mut stream, self).await?;
connect_raw::authenticate(&mut stream, self).await?;
Ok(stream)
}
pub async fn authenticate<S, T>(
&self,
stream: MaybeTlsStream<S, T>,
) -> Result<RawConnection<S, T>, Error>
pub async fn authenticate<S, T>(&self, stream: &mut StartupStream<S, T>) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
{
connect_raw(stream, self).await
connect_raw::startup(stream, self).await?;
connect_raw::authenticate(stream, self).await
}
}

View File

@@ -1,15 +1,17 @@
use std::net::IpAddr;
use futures_util::TryStreamExt;
use postgres_protocol2::message::backend::Message;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::client::SocketConfig;
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::connect_raw::StartupStream;
use crate::connect_socket::connect_socket;
use crate::connect_tls::connect_tls;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, RawConnection};
use crate::{Client, Config, Connection, Error};
pub async fn connect<T>(
tls: &T,
@@ -43,14 +45,8 @@ where
T: TlsConnect<TcpStream>,
{
let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
let stream = connect_tls(socket, config.ssl_mode, tls).await?;
let RawConnection {
stream,
parameters: _,
delayed_notice: _,
process_id,
secret_key,
} = connect_raw(stream, config).await?;
let mut stream = config.tls_and_authenticate(socket, tls).await?;
let (process_id, secret_key) = wait_until_ready(&mut stream).await?;
let socket_config = SocketConfig {
host_addr,
@@ -70,7 +66,32 @@ where
secret_key,
);
let stream = stream.into_framed();
let connection = Connection::new(stream, conn_tx, conn_rx);
Ok((client, connection))
}
async fn wait_until_ready<S, T>(stream: &mut StartupStream<S, T>) -> Result<(i32, i32), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
let mut process_id = 0;
let mut secret_key = 0;
loop {
match stream.try_next().await.map_err(Error::io)? {
Some(Message::BackendKeyData(body)) => {
process_id = body.process_id();
secret_key = body.secret_key();
}
// These values are currently not used by `Client`/`Connection`. Ignore them.
Some(Message::ParameterStatus(_)) | Some(Message::NoticeResponse(_)) => {}
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key)),
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
}
}
}
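For context, the loop above follows the tail of the PostgreSQL startup sequence: after authentication the backend sends zero or more ParameterStatus and NoticeResponse messages plus one BackendKeyData, and finishes with ReadyForQuery. A minimal sketch of the same state machine over simplified, made-up message types:
enum StartupMsg {
    ParameterStatus,
    NoticeResponse,
    BackendKeyData { process_id: i32, secret_key: i32 },
    ReadyForQuery,
    ErrorResponse,
}
// Drain startup messages until ReadyForQuery, keeping only the cancellation key.
fn drain_startup(msgs: impl IntoIterator<Item = StartupMsg>) -> Result<(i32, i32), &'static str> {
    let (mut process_id, mut secret_key) = (0, 0);
    for msg in msgs {
        match msg {
            StartupMsg::BackendKeyData { process_id: p, secret_key: k } => {
                process_id = p;
                secret_key = k;
            }
            // Ignored here, exactly like in wait_until_ready above.
            StartupMsg::ParameterStatus | StartupMsg::NoticeResponse => {}
            StartupMsg::ReadyForQuery => return Ok((process_id, secret_key)),
            StartupMsg::ErrorResponse => return Err("backend reported an error"),
        }
    }
    Err("connection closed before ReadyForQuery")
}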

View File

@@ -1,28 +1,26 @@
use std::collections::HashMap;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::task::{Context, Poll, ready};
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
use futures_util::{Sink, SinkExt, Stream, TryStreamExt};
use postgres_protocol2::authentication::sasl;
use postgres_protocol2::authentication::sasl::ScramSha256;
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::codec::{Framed, FramedParts, FramedWrite};
use crate::Error;
use crate::codec::{BackendMessage, BackendMessages, PostgresCodec};
use crate::codec::PostgresCodec;
use crate::config::{self, AuthKeys, Config};
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::TlsStream;
pub struct StartupStream<S, T> {
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
buf: BackendMessages,
delayed_notice: Vec<NoticeResponseBody>,
inner: FramedWrite<MaybeTlsStream<S, T>, PostgresCodec>,
read_buf: BytesMut,
}
impl<S, T> Sink<Bytes> for StartupStream<S, T>
@@ -56,63 +54,93 @@ where
{
type Item = io::Result<Message>;
fn poll_next(
mut self: Pin<&mut Self>,
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// read 1 byte tag, 4 bytes length.
let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?);
let len = u32::from_be_bytes(header[1..5].try_into().unwrap());
if len < 4 {
return Poll::Ready(Some(Err(std::io::Error::other(
"postgres message too small",
))));
}
if len >= 65536 {
return Poll::Ready(Some(Err(std::io::Error::other(
"postgres message too large",
))));
}
// the tag is an additional byte.
let _message = ready!(self.as_mut().poll_fill_buf_exact(cx, len as usize + 1)?);
// Message::parse will remove all the message bytes from the buffer.
Poll::Ready(Message::parse(&mut self.read_buf).transpose())
}
}
impl<S, T> StartupStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
/// Fill the buffer until it's the exact length provided. No additional data will be read from the socket.
///
/// If the current buffer length is greater, nothing happens.
fn poll_fill_buf_exact(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<io::Result<Message>>> {
loop {
match self.buf.next() {
Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
Ok(None) => {}
Err(e) => return Poll::Ready(Some(Err(e))),
len: usize,
) -> Poll<Result<&[u8], std::io::Error>> {
let this = self.get_mut();
let mut stream = Pin::new(this.inner.get_mut());
let mut n = this.read_buf.len();
while n < len {
this.read_buf.resize(len, 0);
let mut buf = ReadBuf::new(&mut this.read_buf[..]);
buf.set_filled(n);
if stream.as_mut().poll_read(cx, &mut buf)?.is_pending() {
this.read_buf.truncate(n);
return Poll::Pending;
}
match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
None => return Poll::Ready(None),
if buf.filled().len() == n {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"early eof",
)));
}
n = buf.filled().len();
this.read_buf.truncate(n);
}
Poll::Ready(Ok(&this.read_buf[..len]))
}
pub fn into_framed(mut self) -> Framed<MaybeTlsStream<S, T>, PostgresCodec> {
let write_buf = std::mem::take(self.inner.write_buffer_mut());
let io = self.inner.into_inner();
let mut parts = FramedParts::new(io, PostgresCodec);
parts.read_buf = self.read_buf;
parts.write_buf = write_buf;
Framed::from_parts(parts)
}
pub fn new(io: MaybeTlsStream<S, T>) -> Self {
Self {
inner: FramedWrite::new(io, PostgresCodec),
read_buf: BytesMut::new(),
}
}
}
pub struct RawConnection<S, T> {
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pub parameters: HashMap<String, String>,
pub delayed_notice: Vec<NoticeResponseBody>,
pub process_id: i32,
pub secret_key: i32,
}
pub async fn connect_raw<S, T>(
stream: MaybeTlsStream<S, T>,
pub(crate) async fn startup<S, T>(
stream: &mut StartupStream<S, T>,
config: &Config,
) -> Result<RawConnection<S, T>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
{
let mut stream = StartupStream {
inner: Framed::new(stream, PostgresCodec),
buf: BackendMessages::empty(),
delayed_notice: Vec::new(),
};
startup(&mut stream, config).await?;
authenticate(&mut stream, config).await?;
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
Ok(RawConnection {
stream: stream.inner,
parameters,
delayed_notice: stream.delayed_notice,
process_id,
secret_key,
})
}
async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
@@ -123,7 +151,10 @@ where
stream.send(buf.freeze()).await.map_err(Error::io)
}
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
pub(crate) async fn authenticate<S, T>(
stream: &mut StartupStream<S, T>,
config: &Config,
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
@@ -278,35 +309,3 @@ where
Ok(())
}
async fn read_info<S, T>(
stream: &mut StartupStream<S, T>,
) -> Result<(i32, i32, HashMap<String, String>), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
let mut process_id = 0;
let mut secret_key = 0;
let mut parameters = HashMap::new();
loop {
match stream.try_next().await.map_err(Error::io)? {
Some(Message::BackendKeyData(body)) => {
process_id = body.process_id();
secret_key = body.secret_key();
}
Some(Message::ParameterStatus(body)) => {
parameters.insert(
body.name().map_err(Error::parse)?.to_string(),
body.value().map_err(Error::parse)?.to_string(),
);
}
Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
}
}
}
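The poll_next/poll_fill_buf_exact pair above relies on the PostgreSQL wire framing: each backend message is a 1-byte tag followed by a 4-byte big-endian length that counts the length field itself and the payload, but not the tag. A standalone sketch of that arithmetic with the same sanity bounds as the diff (the helper name is made up):
// Hypothetical helper: total number of bytes occupied by a message on the wire,
// given its 5-byte header (1-byte tag + 4-byte big-endian length).
fn frame_total_len(header: [u8; 5]) -> std::io::Result<usize> {
    let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]);
    if len < 4 {
        return Err(std::io::Error::other("postgres message too small"));
    }
    if len >= 65536 {
        return Err(std::io::Error::other("postgres message too large"));
    }
    // tag byte + length field + payload = len + 1 bytes in total
    Ok(len as usize + 1)
}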

View File

@@ -452,16 +452,16 @@ impl Error {
Error(Box::new(ErrorInner { kind, cause }))
}
pub(crate) fn closed() -> Error {
pub fn closed() -> Error {
Error::new(Kind::Closed, None)
}
pub(crate) fn unexpected_message() -> Error {
pub fn unexpected_message() -> Error {
Error::new(Kind::UnexpectedMessage, None)
}
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn db(error: ErrorResponseBody) -> Error {
pub fn db(error: ErrorResponseBody) -> Error {
match DbError::parse(&mut error.fields()) {
Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
@@ -493,7 +493,7 @@ impl Error {
Error::new(Kind::Tls, Some(e))
}
pub(crate) fn io(e: io::Error) -> Error {
pub fn io(e: io::Error) -> Error {
Error::new(Kind::Io, Some(Box::new(e)))
}

View File

@@ -6,7 +6,6 @@ use postgres_protocol2::message::backend::ReadyForQueryBody;
pub use crate::cancel_token::{CancelToken, RawCancelToken};
pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
@@ -50,7 +49,7 @@ mod client;
mod codec;
pub mod config;
mod connect;
mod connect_raw;
pub mod connect_raw;
mod connect_socket;
mod connect_tls;
mod connection;

View File

@@ -301,7 +301,12 @@ pub struct PullTimelineRequest {
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub http_hosts: Vec<String>,
pub ignore_tombstone: Option<bool>,
/// Membership configuration to switch to after pull.
/// It guarantees that if pull_timeline returns successfully, the timeline will
/// not be deleted by a request with an older generation.
/// The storage controller always sets this field.
/// None is only allowed for manual pull_timeline requests.
pub mconf: Option<Configuration>,
}
#[derive(Debug, Serialize, Deserialize)]

View File

@@ -178,6 +178,8 @@ static PageServer page_servers[MAX_SHARDS];
static bool pageserver_flush(shardno_t shard_no);
static void pageserver_disconnect(shardno_t shard_no);
static void pageserver_disconnect_shard(shardno_t shard_no);
// HADRON
shardno_t get_num_shards(void);
static bool
PagestoreShmemIsValid(void)
@@ -286,6 +288,22 @@ AssignPageserverConnstring(const char *newval, void *extra)
}
}
/* BEGIN_HADRON */
/**
* Return the total number of shards seen in the shard map.
*/
shardno_t get_num_shards(void)
{
const ShardMap *shard_map;
Assert(pagestore_shared);
shard_map = &pagestore_shared->shard_map;
Assert(shard_map != NULL);
return shard_map->num_shards;
}
/* END_HADRON */
/*
* Get the current number of shards, and/or the connection string for a
* particular shard from the shard map in shared memory.

View File

@@ -110,6 +110,9 @@ static void rm_safekeeper_event_set(Safekeeper *to_remove, bool is_sk);
static void CheckGracefulShutdown(WalProposer *wp);
// HADRON
shardno_t get_num_shards(void);
static void
init_walprop_config(bool syncSafekeepers)
{
@@ -646,18 +649,19 @@ walprop_pg_get_shmem_state(WalProposer *wp)
* Record new ps_feedback in the array with shards and update min_feedback.
*/
static PageserverFeedback
record_pageserver_feedback(PageserverFeedback *ps_feedback)
record_pageserver_feedback(PageserverFeedback *ps_feedback, shardno_t num_shards)
{
PageserverFeedback min_feedback;
Assert(ps_feedback->present);
Assert(ps_feedback->shard_number < MAX_SHARDS);
Assert(ps_feedback->shard_number < num_shards);
SpinLockAcquire(&walprop_shared->mutex);
/* Update the number of shards */
if (ps_feedback->shard_number + 1 > walprop_shared->num_shards)
walprop_shared->num_shards = ps_feedback->shard_number + 1;
// Hadron: lazily update num_shards from the source of truth (the shard map) when we receive
// new pageserver feedback.
walprop_shared->num_shards = Max(walprop_shared->num_shards, num_shards);
/* Update the feedback */
memcpy(&walprop_shared->shard_ps_feedback[ps_feedback->shard_number], ps_feedback, sizeof(PageserverFeedback));
@@ -2023,19 +2027,43 @@ walprop_pg_process_safekeeper_feedback(WalProposer *wp, Safekeeper *sk)
if (wp->config->syncSafekeepers)
return;
/* handle fresh ps_feedback */
if (sk->appendResponse.ps_feedback.present)
{
PageserverFeedback min_feedback = record_pageserver_feedback(&sk->appendResponse.ps_feedback);
shardno_t num_shards = get_num_shards();
/* Only one main shard sends non-zero currentClusterSize */
if (sk->appendResponse.ps_feedback.currentClusterSize > 0)
SetNeonCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize);
if (min_feedback.disk_consistent_lsn != standby_apply_lsn)
// During shard split, we receive ps_feedback from child shards before
// the split commits and our shard map GUC has been updated. We must
// filter out such feedback here because record_pageserver_feedback()
// doesn't do it.
//
// NB: what we would actually want to happen is that we only receive
// ps_feedback from the parent shards until the split is committed, then
// apply the split to our set of tracked feedback and from here on only
// receive ps_feedback from child shards. This filter condition doesn't
// do that: if we split from N parent to 2N child shards, the first N
// child shards' feedback messages will pass this condition, even before
// the split is committed. That's a bit sloppy, but OK for now.
if (sk->appendResponse.ps_feedback.shard_number < num_shards)
{
standby_apply_lsn = min_feedback.disk_consistent_lsn;
needToAdvanceSlot = true;
PageserverFeedback min_feedback = record_pageserver_feedback(&sk->appendResponse.ps_feedback, num_shards);
/* Only one main shard sends non-zero currentClusterSize */
if (sk->appendResponse.ps_feedback.currentClusterSize > 0)
SetNeonCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize);
if (min_feedback.disk_consistent_lsn != standby_apply_lsn)
{
standby_apply_lsn = min_feedback.disk_consistent_lsn;
needToAdvanceSlot = true;
}
}
else
{
// HADRON
elog(DEBUG2, "Ignoring pageserver feedback for unknown shard %d (current shard number %d)",
sk->appendResponse.ps_feedback.shard_number, num_shards);
}
}
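A worked example (with made-up numbers, written in Rust rather than the C above) of the filter condition described in the comment: with num_shards still at the pre-split value of 2, feedback from child shards 0 and 1 of a 2-to-4 split still passes `shard_number < num_shards`, while shards 2 and 3 are ignored until the shard map is updated.
// Illustration only; mirrors the `shard_number < num_shards` check above.
fn accepts_feedback(shard_number: u32, num_shards: u32) -> bool {
    shard_number < num_shards
}
fn main() {
    let num_shards = 2; // shard map not yet updated for a 2 -> 4 split
    assert!(accepts_feedback(0, num_shards));
    assert!(accepts_feedback(1, num_shards));
    assert!(!accepts_feedback(2, num_shards)); // new child shard, ignored for now
    assert!(!accepts_feedback(3, num_shards));
}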

View File

@@ -429,26 +429,13 @@ impl CancellationHandler {
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: RawCancelToken,
hostname: String, // for pg_sni router
user_info: ComputeUserInfo,
pub socket_addr: SocketAddr,
pub cancel_token: RawCancelToken,
pub hostname: String, // for pg_sni router
pub user_info: ComputeUserInfo,
}
impl CancelClosure {
pub(crate) fn new(
socket_addr: SocketAddr,
cancel_token: RawCancelToken,
hostname: String,
user_info: ComputeUserInfo,
) -> Self {
Self {
socket_addr,
cancel_token,
hostname,
user_info,
}
}
/// Cancels the query running on user's compute node.
pub(crate) async fn try_cancel_query(
&self,

View File

@@ -7,17 +7,15 @@ use std::net::{IpAddr, SocketAddr};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
use postgres_client::connect_raw::StartupStream;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{NoTls, RawCancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
use tracing::{debug, error, info, warn};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::compute::tls::TlsError;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
@@ -236,8 +234,7 @@ impl AuthInfo {
&self,
ctx: &RequestContext,
compute: &mut ComputeConnection,
user_info: &ComputeUserInfo,
) -> Result<PostgresSettings, PostgresError> {
) -> Result<(), PostgresError> {
// client config with stubbed connect info.
// TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
// utilising pqproto.rs.
@@ -247,39 +244,10 @@ impl AuthInfo {
let tmp_config = self.enrich(tmp_config);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let connection = tmp_config
.tls_and_authenticate(&mut compute.stream, NoTls)
.await?;
tmp_config.authenticate(&mut compute.stream).await?;
drop(pause);
let RawConnection {
stream: _,
parameters,
delayed_notice,
process_id,
secret_key,
} = connection;
tracing::Span::current().record("pid", tracing::field::display(process_id));
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(
compute.socket_addr,
RawCancelToken {
ssl_mode: compute.ssl_mode,
process_id,
secret_key,
},
compute.hostname.to_string(),
user_info.clone(),
);
Ok(PostgresSettings {
params: parameters,
cancel_closure,
delayed_notice,
})
Ok(())
}
}
@@ -343,21 +311,9 @@ impl ConnectInfo {
pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
// TODO(conrad): we don't need to parse these.
// These are just immediately forwarded back to the client.
// We could instead stream them out instead of reading them into memory.
pub struct PostgresSettings {
/// PostgreSQL connection parameters.
pub params: std::collections::HashMap<String, String>,
/// Query cancellation token.
pub cancel_closure: CancelClosure,
/// Notices received from compute after authenticating
pub delayed_notice: Vec<NoticeResponseBody>,
}
pub struct ComputeConnection {
/// Socket connected to a compute node.
pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
pub stream: StartupStream<tokio::net::TcpStream, RustlsStream>,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
pub hostname: Host,
@@ -390,6 +346,7 @@ impl ConnectInfo {
ctx.get_testodrome_id().unwrap_or_default(),
);
let stream = StartupStream::new(stream);
let connection = ComputeConnection {
stream,
socket_addr,

View File

@@ -1,12 +1,13 @@
use std::sync::Arc;
use futures::{FutureExt, TryFutureExt};
use postgres_client::RawCancelToken;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info};
use crate::auth::backend::ConsoleRedirectBackend;
use crate::cancellation::CancellationHandler;
use crate::cancellation::{CancelClosure, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
@@ -16,7 +17,7 @@ use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::{ErrorSource, finish_client_init};
use crate::proxy::{ErrorSource, forward_compute_params_to_client, send_client_greeting};
use crate::util::run_until_cancelled;
pub async fn task_main(
@@ -226,21 +227,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
.await?;
let pg_settings = auth_info
.authenticate(ctx, &mut node, &user_info)
auth_info
.authenticate(ctx, &mut node)
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
.await?;
send_client_greeting(ctx, &config.greetings, &mut stream);
let session = cancellation_handler.get_key();
finish_client_init(
ctx,
&pg_settings,
*session.key(),
&mut stream,
&config.greetings,
);
let (process_id, secret_key) =
forward_compute_params_to_client(ctx, *session.key(), &mut stream, &mut node.stream)
.await?;
let stream = stream.flush_and_into_inner().await?;
let hostname = node.hostname.to_string();
let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
@@ -249,7 +248,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.maintain_cancel_key(
session_id,
cancel,
&pg_settings.cancel_closure,
&CancelClosure {
socket_addr: node.socket_addr,
cancel_token: RawCancelToken {
ssl_mode: node.ssl_mode,
process_id,
secret_key,
},
hostname,
user_info,
},
&config.connect_to_compute,
)
.await;
@@ -257,7 +265,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
Ok(Some(ProxyPassthrough {
client: stream,
compute: node.stream,
compute: node.stream.into_framed().into_inner(),
aux: node.aux,
private_link_id: None,

View File

@@ -319,7 +319,7 @@ pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
Ok(Some(ProxyPassthrough {
client,
compute: node.stream,
compute: node.stream.into_framed().into_inner(),
aux: node.aux,
private_link_id,

View File

@@ -313,6 +313,14 @@ impl WriteBuf {
self.0.set_position(0);
}
/// Shrinks the buffer if efficient to do so, and returns the number of bytes still buffered.
pub fn occupied_len(&mut self) -> usize {
if self.should_shrink() {
self.shrink();
}
self.0.get_mut().len()
}
/// Write a raw message to the internal buffer.
///
/// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since

View File

@@ -9,18 +9,23 @@ use std::collections::HashSet;
use std::convert::Infallible;
use std::sync::Arc;
use futures::TryStreamExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use postgres_client::RawCancelToken;
use postgres_client::connect_raw::StartupStream;
use postgres_protocol::message::backend::Message;
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, format_smolstr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tracing::Instrument;
use crate::cache::Cache;
use crate::cancellation::CancellationHandler;
use crate::compute::ComputeConnection;
use crate::cancellation::{CancelClosure, CancellationHandler};
use crate::compute::{ComputeConnection, PostgresError, RustlsStream};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
@@ -105,7 +110,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
// the compute was cached, and we connected, but the compute cache was actually stale
// and is associated with the wrong endpoint. We detect this when the **authentication** fails.
// As such, we retry once here if the `authenticate` function fails and the error is valid to retry.
let pg_settings = loop {
loop {
attempt += 1;
// TODO: callback to pglb
@@ -127,9 +132,12 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
unreachable!("ensured above");
};
let res = auth_info.authenticate(ctx, &mut node, user_info).await;
let res = auth_info.authenticate(ctx, &mut node).await;
match res {
Ok(pg_settings) => break pg_settings,
Ok(()) => {
send_client_greeting(ctx, &config.greetings, client);
break;
}
Err(e) if attempt < 2 && e.should_retry_wake_compute() => {
tracing::warn!(error = ?e, "retrying wake compute");
@@ -141,11 +149,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
}
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
}
}
let auth::Backend::ControlPlane(_, user_info) = backend else {
unreachable!("ensured above");
};
let session = cancellation_handler.get_key();
finish_client_init(ctx, &pg_settings, *session.key(), client, &config.greetings);
let (process_id, secret_key) =
forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?;
let hostname = node.hostname.to_string();
let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = oneshot::channel();
@@ -154,7 +168,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.maintain_cancel_key(
session_id,
cancel,
&pg_settings.cancel_closure,
&CancelClosure {
socket_addr: node.socket_addr,
cancel_token: RawCancelToken {
ssl_mode: node.ssl_mode,
process_id,
secret_key,
},
hostname,
user_info,
},
&config.connect_to_compute,
)
.await;
@@ -163,35 +186,18 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
Ok((node, cancel_on_shutdown))
}
/// Finish client connection initialization: confirm auth success, send params, etc.
pub(crate) fn finish_client_init(
/// Greet the client with any useful information.
pub(crate) fn send_client_greeting(
ctx: &RequestContext,
settings: &compute::PostgresSettings,
cancel_key_data: CancelKeyData,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
greetings: &String,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) {
// Forward all deferred notices to the client.
for notice in &settings.delayed_notice {
client.write_raw(notice.as_bytes().len(), b'N', |buf| {
buf.extend_from_slice(notice.as_bytes());
});
}
// Expose session_id to clients if we have a greeting message.
if !greetings.is_empty() {
let session_msg = format!("{}, session_id: {}", greetings, ctx.session_id());
client.write_message(BeMessage::NoticeResponse(session_msg.as_str()));
}
// Forward all postgres connection params to the client.
for (name, value) in &settings.params {
client.write_message(BeMessage::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
});
}
// Forward recorded latencies for probing requests
if let Some(testodrome_id) = ctx.get_testodrome_id() {
client.write_message(BeMessage::ParameterStatus {
@@ -221,9 +227,63 @@ pub(crate) fn finish_client_init(
value: latency_measured.retry.as_micros().to_string().as_bytes(),
});
}
}
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
client.write_message(BeMessage::ReadyForQuery);
pub(crate) async fn forward_compute_params_to_client(
ctx: &RequestContext,
cancel_key_data: CancelKeyData,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
compute: &mut StartupStream<TcpStream, RustlsStream>,
) -> Result<(i32, i32), ClientRequestError> {
let mut process_id = 0;
let mut secret_key = 0;
let err = loop {
// if the client buffer is too large, let's write out some bytes now to save some space
client.write_if_full().await?;
let msg = match compute.try_next().await {
Ok(msg) => msg,
Err(e) => break postgres_client::Error::io(e),
};
match msg {
// Send our cancellation key data instead.
Some(Message::BackendKeyData(body)) => {
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
process_id = body.process_id();
secret_key = body.secret_key();
}
// Forward all postgres connection params to the client.
Some(Message::ParameterStatus(body)) => {
if let Ok(name) = body.name()
&& let Ok(value) = body.value()
{
client.write_message(BeMessage::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
});
}
}
// Forward all notices to the client.
Some(Message::NoticeResponse(notice)) => {
client.write_raw(notice.as_bytes().len(), b'N', |buf| {
buf.extend_from_slice(notice.as_bytes());
});
}
Some(Message::ReadyForQuery(_)) => {
client.write_message(BeMessage::ReadyForQuery);
return Ok((process_id, secret_key));
}
Some(Message::ErrorResponse(body)) => break postgres_client::Error::db(body),
Some(_) => break postgres_client::Error::unexpected_message(),
None => break postgres_client::Error::closed(),
}
};
Err(client
.throw_error(PostgresError::Postgres(err), Some(ctx))
.await)?
}
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]

View File

@@ -154,6 +154,15 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
message.write_message(&mut self.write);
}
/// Write the buffer to the socket until we have some more space again.
pub async fn write_if_full(&mut self) -> io::Result<()> {
while self.write.occupied_len() > 2048 {
self.stream.write_buf(&mut self.write).await?;
}
Ok(())
}
/// Flush the output buffer into the underlying stream.
///
/// This is cancel safe.
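A minimal sketch of the backpressure idea behind occupied_len and write_if_full, assuming simplified buffer and socket types (only the 2048-byte threshold comes from the diff):
use tokio::io::{AsyncWrite, AsyncWriteExt};
const SOFT_CAP: usize = 2048; // same threshold as write_if_full above
// Illustration only: drain the pending write buffer whenever it grows past a soft cap,
// so relaying many small backend messages cannot grow memory without bound.
async fn drain_if_full<S: AsyncWrite + Unpin>(buf: &mut Vec<u8>, sock: &mut S) -> std::io::Result<()> {
    while buf.len() > SOFT_CAP {
        let n = sock.write(&buf[..]).await?;
        if n == 0 {
            return Err(std::io::Error::new(std::io::ErrorKind::WriteZero, "socket closed"));
        }
        buf.drain(..n);
    }
    Ok(())
}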

View File

@@ -161,9 +161,9 @@ pub async fn handle_request(
FileStorage::create_new(&tli_dir_path, new_state.clone(), conf.no_sync).await?;
// now we have a ready timeline in a temp directory
validate_temp_timeline(conf, request.destination_ttid, &tli_dir_path).await?;
validate_temp_timeline(conf, request.destination_ttid, &tli_dir_path, None).await?;
global_timelines
.load_temp_timeline(request.destination_ttid, &tli_dir_path, true)
.load_temp_timeline(request.destination_ttid, &tli_dir_path, None)
.await?;
Ok(())

View File

@@ -193,7 +193,7 @@ pub async fn hcc_pull_timeline(
tenant_id: timeline.tenant_id,
timeline_id: timeline.timeline_id,
http_hosts: Vec::new(),
ignore_tombstone: None,
mconf: None,
};
for host in timeline.peers {
if host.0 == conf.my_id.0 {

View File

@@ -352,7 +352,7 @@ async fn timeline_exclude_handler(mut request: Request<Body>) -> Result<Response
// instead.
if data.mconf.contains(my_id) {
return Err(ApiError::Forbidden(format!(
"refused to switch into {}, node {} is member of it",
"refused to exclude timeline with {}, node {} is member of it",
data.mconf, my_id
)));
}

View File

@@ -13,8 +13,8 @@ use http_utils::error::ApiError;
use postgres_ffi::{PG_TLI, XLogFileName, XLogSegNo};
use remote_storage::GenericRemoteStorage;
use reqwest::Certificate;
use safekeeper_api::Term;
use safekeeper_api::models::{PullTimelineRequest, PullTimelineResponse, TimelineStatus};
use safekeeper_api::{Term, membership};
use safekeeper_client::mgmt_api;
use safekeeper_client::mgmt_api::Client;
use serde::Deserialize;
@@ -453,12 +453,40 @@ pub async fn handle_request(
global_timelines: Arc<GlobalTimelines>,
wait_for_peer_timeline_status: bool,
) -> Result<PullTimelineResponse, ApiError> {
if let Some(mconf) = &request.mconf {
let sk_id = global_timelines.get_sk_id();
if !mconf.contains(sk_id) {
return Err(ApiError::BadRequest(anyhow!(
"refused to pull timeline with {mconf}, node {sk_id} is not member of it",
)));
}
}
let existing_tli = global_timelines.get(TenantTimelineId::new(
request.tenant_id,
request.timeline_id,
));
if existing_tli.is_ok() {
info!("Timeline {} already exists", request.timeline_id);
if let Ok(timeline) = existing_tli {
let cur_generation = timeline
.read_shared_state()
.await
.sk
.state()
.mconf
.generation;
info!(
"Timeline {} already exists with generation {cur_generation}",
request.timeline_id,
);
if let Some(mconf) = request.mconf {
timeline
.membership_switch(mconf)
.await
.map_err(|e| ApiError::InternalServerError(anyhow::anyhow!(e)))?;
}
return Ok(PullTimelineResponse {
safekeeper_host: None,
});
@@ -495,6 +523,19 @@ pub async fn handle_request(
for (i, response) in responses.into_iter().enumerate() {
match response {
Ok(status) => {
if let Some(mconf) = &request.mconf {
if status.mconf.generation > mconf.generation {
// We probably raced with another timeline membership change with higher generation.
// Ignore this request.
return Err(ApiError::Conflict(format!(
"cannot pull timeline with generation {}: timeline {} already exists with generation {} on {}",
mconf.generation,
request.timeline_id,
status.mconf.generation,
http_hosts[i],
)));
}
}
statuses.push((status, i));
}
Err(e) => {
@@ -593,15 +634,13 @@ pub async fn handle_request(
assert!(status.tenant_id == request.tenant_id);
assert!(status.timeline_id == request.timeline_id);
let check_tombstone = !request.ignore_tombstone.unwrap_or_default();
match pull_timeline(
status,
safekeeper_host,
sk_auth_token,
http_client,
global_timelines,
check_tombstone,
request.mconf,
)
.await
{
@@ -611,6 +650,10 @@ pub async fn handle_request(
Some(TimelineError::AlreadyExists(_)) => Ok(PullTimelineResponse {
safekeeper_host: None,
}),
Some(TimelineError::Deleted(_)) => Err(ApiError::Conflict(format!(
"Timeline {}/{} deleted",
request.tenant_id, request.timeline_id
))),
Some(TimelineError::CreationInProgress(_)) => {
// We don't return success here because creation might still fail.
Err(ApiError::Conflict("Creation in progress".to_owned()))
@@ -627,7 +670,7 @@ async fn pull_timeline(
sk_auth_token: Option<SecretString>,
http_client: reqwest::Client,
global_timelines: Arc<GlobalTimelines>,
check_tombstone: bool,
mconf: Option<membership::Configuration>,
) -> Result<PullTimelineResponse> {
let ttid = TenantTimelineId::new(status.tenant_id, status.timeline_id);
info!(
@@ -689,8 +732,11 @@ async fn pull_timeline(
// fsync temp timeline directory to remember its contents.
fsync_async_opt(&tli_dir_path, !conf.no_sync).await?;
let generation = mconf.as_ref().map(|c| c.generation);
// Let's create timeline from temp directory and verify that it's correct
let (commit_lsn, flush_lsn) = validate_temp_timeline(conf, ttid, &tli_dir_path).await?;
let (commit_lsn, flush_lsn) =
validate_temp_timeline(conf, ttid, &tli_dir_path, generation).await?;
info!(
"finished downloading timeline {}, commit_lsn={}, flush_lsn={}",
ttid, commit_lsn, flush_lsn
@@ -698,10 +744,20 @@ async fn pull_timeline(
assert!(status.commit_lsn <= status.flush_lsn);
// Finally, load the timeline.
let _tli = global_timelines
.load_temp_timeline(ttid, &tli_dir_path, check_tombstone)
let timeline = global_timelines
.load_temp_timeline(ttid, &tli_dir_path, generation)
.await?;
if let Some(mconf) = mconf {
// Switch to the provided mconf to guarantee that the timeline will not
// be deleted by a request with an older generation.
// The generation might already be higher than the one in mconf, e.g.
// if another membership_switch request was executed between `load_temp_timeline`
// and `membership_switch`, but that's totally fine: `membership_switch` will
// ignore a switch to an older generation.
timeline.membership_switch(mconf).await?;
}
Ok(PullTimelineResponse {
safekeeper_host: Some(host),
})
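To summarize the generation checks added along the pull path above (a hedged, simplified decision function; in the real code these checks live in handle_request and validate_temp_timeline, and SafekeeperGeneration is an ordered newtype rather than a plain integer):
// Illustration only.
fn pull_generation_ok(
    request_generation: u32,
    peer_status_generation: u32,
    temp_control_file_generation: u32,
) -> Result<(), String> {
    // A peer that already moved past the requested generation means our request is stale.
    if peer_status_generation > request_generation {
        return Err(format!(
            "conflict: peer is at generation {peer_status_generation}, request carries {request_generation}"
        ));
    }
    // The downloaded control file must not be newer than the requested configuration either.
    if temp_control_file_generation > request_generation {
        return Err(format!(
            "tmp timeline generation {temp_control_file_generation} is higher than expected {request_generation}"
        ));
    }
    Ok(())
}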

View File

@@ -1026,6 +1026,13 @@ where
self.state.finish_change(&state).await?;
}
if msg.mconf.generation > self.state.mconf.generation && !msg.mconf.contains(self.node_id) {
bail!(
"refused to switch into {}, node {} is not a member of it",
msg.mconf,
self.node_id,
);
}
// Switch into conf given by proposer conf if it is higher.
self.state.membership_switch(msg.mconf.clone()).await?;

View File

@@ -594,7 +594,7 @@ impl Timeline {
/// Cancel the timeline, requesting background activity to stop. Closing
/// the `self.gate` waits for that.
pub async fn cancel(&self) {
pub fn cancel(&self) {
info!("timeline {} shutting down", self.ttid);
self.cancel.cancel();
}
@@ -914,6 +914,13 @@ impl Timeline {
to: Configuration,
) -> Result<TimelineMembershipSwitchResponse> {
let mut state = self.write_shared_state().await;
// Ensure we don't race with exclude/delete requests by checking the cancellation
// token under the write_shared_state lock.
// Exclude/delete cancel the timeline under the shared state lock,
// so the timeline cannot be deleted in the middle of the membership switch.
if self.is_cancelled() {
bail!(TimelineError::Cancelled(self.ttid));
}
state.sk.membership_switch(to).await
}
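A hedged sketch of the locking discipline described in the comment above, with stand-in types (a plain Mutex and AtomicBool instead of the safekeeper's shared state and CancellationToken): delete/exclude cancels while holding the state lock, and membership_switch re-checks cancellation only after taking the same lock, so a switch can never interleave with a deletion.
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
struct TimelineSketch {
    shared_state: Mutex<u32>, // stands in for the safekeeper shared state (holds a generation)
    cancelled: AtomicBool,    // stands in for the CancellationToken
}
impl TimelineSketch {
    // Delete/exclude: cancel under the state lock so no switch can slip in between.
    fn delete(&self) {
        let _state = self.shared_state.lock().unwrap();
        self.cancelled.store(true, Ordering::SeqCst);
    }
    // Membership switch: check cancellation only after taking the same lock.
    fn membership_switch(&self, new_generation: u32) -> Result<(), &'static str> {
        let mut state = self.shared_state.lock().unwrap();
        if self.cancelled.load(Ordering::SeqCst) {
            return Err("timeline is cancelled");
        }
        *state = new_generation;
        Ok(())
    }
}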

View File

@@ -10,13 +10,13 @@ use std::time::{Duration, Instant};
use anyhow::{Context, Result, bail};
use camino::Utf8PathBuf;
use camino_tempfile::Utf8TempDir;
use safekeeper_api::membership::Configuration;
use safekeeper_api::membership::{Configuration, SafekeeperGeneration};
use safekeeper_api::models::{SafekeeperUtilization, TimelineDeleteResult};
use safekeeper_api::{ServerInfo, membership};
use tokio::fs;
use tracing::*;
use utils::crashsafe::{durable_rename, fsync_async_opt};
use utils::id::{TenantId, TenantTimelineId, TimelineId};
use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
use utils::lsn::Lsn;
use crate::defaults::DEFAULT_EVICTION_CONCURRENCY;
@@ -40,10 +40,17 @@ enum GlobalMapTimeline {
struct GlobalTimelinesState {
timelines: HashMap<TenantTimelineId, GlobalMapTimeline>,
// A tombstone indicates this timeline used to exist has been deleted. These are used to prevent
// on-demand timeline creation from recreating deleted timelines. This is only soft-enforced, as
// this map is dropped on restart.
tombstones: HashMap<TenantTimelineId, Instant>,
/// A tombstone indicates this timeline used to exist but has been deleted. These are used to prevent
/// on-demand timeline creation from recreating deleted timelines. This is only soft-enforced, as
/// this map is dropped on restart.
/// The timeline might also be locally deleted (excluded) via the safekeeper migration algorithm. In that case,
/// the tombstone contains the corresponding safekeeper generation. pull_timeline requests with a
/// higher generation ignore such tombstones and can recreate the timeline.
timeline_tombstones: HashMap<TenantTimelineId, TimelineTombstone>,
/// A tombstone indicates that the tenant used to exist but has been deleted.
/// These are created only by tenant_delete requests. They are always valid regardless of the
/// request generation.
/// This is only soft-enforced, as this map is dropped on restart.
tenant_tombstones: HashMap<TenantId, Instant>,
conf: Arc<SafeKeeperConf>,
@@ -79,7 +86,7 @@ impl GlobalTimelinesState {
Err(TimelineError::CreationInProgress(*ttid))
}
None => {
if self.has_tombstone(ttid) {
if self.has_tombstone(ttid, None) {
Err(TimelineError::Deleted(*ttid))
} else {
Err(TimelineError::NotFound(*ttid))
@@ -88,20 +95,46 @@ impl GlobalTimelinesState {
}
}
fn has_tombstone(&self, ttid: &TenantTimelineId) -> bool {
self.tombstones.contains_key(ttid) || self.tenant_tombstones.contains_key(&ttid.tenant_id)
fn has_timeline_tombstone(
&self,
ttid: &TenantTimelineId,
generation: Option<SafekeeperGeneration>,
) -> bool {
if let Some(generation) = generation {
self.timeline_tombstones
.get(ttid)
.is_some_and(|t| t.is_valid(generation))
} else {
self.timeline_tombstones.contains_key(ttid)
}
}
/// Removes all blocking tombstones for the given timeline ID.
fn has_tenant_tombstone(&self, tenant_id: &TenantId) -> bool {
self.tenant_tombstones.contains_key(tenant_id)
}
/// Check if the state has a tenant or a timeline tombstone.
/// If `generation` is provided, check only for timeline tombstones with the same or a higher generation.
/// If `generation` is `None`, check for any timeline tombstone.
/// Tenant tombstones are checked regardless of the generation.
fn has_tombstone(
&self,
ttid: &TenantTimelineId,
generation: Option<SafekeeperGeneration>,
) -> bool {
self.has_timeline_tombstone(ttid, generation) || self.has_tenant_tombstone(&ttid.tenant_id)
}
/// Removes the timeline tombstone for the given timeline ID.
/// Returns `true` if there have been actual changes.
fn remove_tombstone(&mut self, ttid: &TenantTimelineId) -> bool {
self.tombstones.remove(ttid).is_some()
|| self.tenant_tombstones.remove(&ttid.tenant_id).is_some()
fn remove_timeline_tombstone(&mut self, ttid: &TenantTimelineId) -> bool {
self.timeline_tombstones.remove(ttid).is_some()
}
fn delete(&mut self, ttid: TenantTimelineId) {
fn delete(&mut self, ttid: TenantTimelineId, generation: Option<SafekeeperGeneration>) {
self.timelines.remove(&ttid);
self.tombstones.insert(ttid, Instant::now());
self.timeline_tombstones
.insert(ttid, TimelineTombstone::new(generation));
}
fn add_tenant_tombstone(&mut self, tenant_id: TenantId) {
@@ -120,7 +153,7 @@ impl GlobalTimelines {
Self {
state: Mutex::new(GlobalTimelinesState {
timelines: HashMap::new(),
tombstones: HashMap::new(),
timeline_tombstones: HashMap::new(),
tenant_tombstones: HashMap::new(),
conf,
broker_active_set: Arc::new(TimelinesSet::default()),
@@ -261,6 +294,8 @@ impl GlobalTimelines {
start_lsn: Lsn,
commit_lsn: Lsn,
) -> Result<Arc<Timeline>> {
let generation = Some(mconf.generation);
let (conf, _, _, _) = {
let state = self.state.lock().unwrap();
if let Ok(timeline) = state.get(&ttid) {
@@ -268,8 +303,8 @@ impl GlobalTimelines {
return Ok(timeline);
}
if state.has_tombstone(&ttid) {
anyhow::bail!("Timeline {ttid} is deleted, refusing to recreate");
if state.has_tombstone(&ttid, generation) {
anyhow::bail!(TimelineError::Deleted(ttid));
}
state.get_dependencies()
@@ -284,7 +319,9 @@ impl GlobalTimelines {
// immediately initialize first WAL segment as well.
let state = TimelinePersistentState::new(&ttid, mconf, server_info, start_lsn, commit_lsn)?;
control_file::FileStorage::create_new(&tmp_dir_path, state, conf.no_sync).await?;
let timeline = self.load_temp_timeline(ttid, &tmp_dir_path, true).await?;
let timeline = self
.load_temp_timeline(ttid, &tmp_dir_path, generation)
.await?;
Ok(timeline)
}
@@ -303,7 +340,7 @@ impl GlobalTimelines {
&self,
ttid: TenantTimelineId,
tmp_path: &Utf8PathBuf,
check_tombstone: bool,
generation: Option<SafekeeperGeneration>,
) -> Result<Arc<Timeline>> {
// Check for existence and mark that we're creating it.
let (conf, broker_active_set, partial_backup_rate_limiter, wal_backup) = {
@@ -317,18 +354,18 @@ impl GlobalTimelines {
}
_ => {}
}
if check_tombstone {
if state.has_tombstone(&ttid) {
anyhow::bail!("timeline {ttid} is deleted, refusing to recreate");
}
} else {
// We may be have been asked to load a timeline that was previously deleted (e.g. from `pull_timeline.rs`). We trust
// that the human doing this manual intervention knows what they are doing, and remove its tombstone.
// It's also possible that we enter this when the tenant has been deleted, even if the timeline itself has never existed.
if state.remove_tombstone(&ttid) {
warn!("un-deleted timeline {ttid}");
}
if state.has_tombstone(&ttid, generation) {
// If the timeline is deleted, we refuse to recreate it.
// This is a safeguard against accidentally overwriting a timeline that was deleted
// by a concurrent request.
anyhow::bail!(TimelineError::Deleted(ttid));
}
// We might have an outdated tombstone with an older generation.
// Remove it unconditionally.
state.remove_timeline_tombstone(&ttid);
state
.timelines
.insert(ttid, GlobalMapTimeline::CreationInProgress);
@@ -503,11 +540,16 @@ impl GlobalTimelines {
ttid: &TenantTimelineId,
action: DeleteOrExclude,
) -> Result<TimelineDeleteResult, DeleteOrExcludeError> {
let generation = match &action {
DeleteOrExclude::Delete | DeleteOrExclude::DeleteLocal => None,
DeleteOrExclude::Exclude(mconf) => Some(mconf.generation),
};
let tli_res = {
let state = self.state.lock().unwrap();
// Do NOT check tenant tombstones here: those were set earlier
if state.tombstones.contains_key(ttid) {
if state.has_timeline_tombstone(ttid, generation) {
// Presence of a tombstone guarantees that a previous deletion has completed and there is no work to do.
info!("Timeline {ttid} was already deleted");
return Ok(TimelineDeleteResult { dir_existed: false });
@@ -528,6 +570,11 @@ impl GlobalTimelines {
// We would like to avoid holding the lock while waiting for the
// gate to finish as this is deadlock prone, so for actual
// deletion will take it second time.
//
// Canceling the timeline will block membership switch requests,
// ensuring that the timeline generation will not increase
// after this point, and we will not remove a timeline with a generation
// higher than the requested one.
if let DeleteOrExclude::Exclude(ref mconf) = action {
let shared_state = timeline.read_shared_state().await;
if shared_state.sk.state().mconf.generation > mconf.generation {
@@ -536,9 +583,9 @@ impl GlobalTimelines {
current: shared_state.sk.state().mconf.clone(),
});
}
timeline.cancel().await;
timeline.cancel();
} else {
timeline.cancel().await;
timeline.cancel();
}
timeline.close().await;
@@ -565,7 +612,7 @@ impl GlobalTimelines {
// Finalize deletion, by dropping Timeline objects and storing smaller tombstones. The tombstones
// are used to prevent still-running computes from re-creating the same timeline when they send data,
// and to speed up repeated deletion calls by avoiding re-listing objects.
self.state.lock().unwrap().delete(*ttid);
self.state.lock().unwrap().delete(*ttid, generation);
result
}
@@ -627,12 +674,16 @@ impl GlobalTimelines {
// may recreate a deleted timeline.
let now = Instant::now();
state
.tombstones
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
.timeline_tombstones
.retain(|_, v| now.duration_since(v.timestamp) < *tombstone_ttl);
state
.tenant_tombstones
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
}
pub fn get_sk_id(&self) -> NodeId {
self.state.lock().unwrap().conf.my_id
}
}
/// Action for delete_or_exclude.
@@ -673,6 +724,7 @@ pub async fn validate_temp_timeline(
conf: &SafeKeeperConf,
ttid: TenantTimelineId,
path: &Utf8PathBuf,
generation: Option<SafekeeperGeneration>,
) -> Result<(Lsn, Lsn)> {
let control_path = path.join("safekeeper.control");
@@ -681,6 +733,15 @@ pub async fn validate_temp_timeline(
bail!("wal_seg_size is not set");
}
if let Some(generation) = generation {
if control_store.mconf.generation > generation {
bail!(
"tmp timeline generation {} is higher than expected {generation}",
control_store.mconf.generation
);
}
}
let wal_store = wal_storage::PhysicalStorage::new(&ttid, path, &control_store, conf.no_sync)?;
let commit_lsn = control_store.commit_lsn;
@@ -688,3 +749,28 @@ pub async fn validate_temp_timeline(
Ok((commit_lsn, flush_lsn))
}
/// A tombstone for a deleted timeline.
/// The generation is passed with "exclude" request and stored in the tombstone.
/// We ignore the tombstone if the request generation is higher than
/// the tombstone generation.
/// If the tombstone doesn't have a generation, it's considered permanent,
/// e.g. after "delete" request.
struct TimelineTombstone {
timestamp: Instant,
generation: Option<SafekeeperGeneration>,
}
impl TimelineTombstone {
fn new(generation: Option<SafekeeperGeneration>) -> Self {
TimelineTombstone {
timestamp: Instant::now(),
generation,
}
}
/// Check if the tombstone is still valid (i.e. still blocks requests) for the given generation.
fn is_valid(&self, generation: SafekeeperGeneration) -> bool {
self.generation.is_none_or(|g| g >= generation)
}
}
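A minimal sketch of the tombstone rule spelled out above, using plain integers for generations (hypothetical helper; the real check is TimelineTombstone::is_valid combined with has_timeline_tombstone): an exclude-tombstone written at generation 4 blocks pulls carrying generation 4 or lower, a pull carrying generation 5 may recreate the timeline, and a tombstone without a generation (from a plain delete) blocks everything.
// Illustration only.
fn tombstone_blocks(tombstone_generation: Option<u32>, request_generation: Option<u32>) -> bool {
    match (tombstone_generation, request_generation) {
        (None, _) => true,            // permanent tombstone from a "delete" request
        (Some(_), None) => true,      // a request without a generation never overrides a tombstone
        (Some(t), Some(r)) => t >= r, // a stale or equal request generation is refused
    }
}
fn main() {
    assert!(tombstone_blocks(None, Some(7)));
    assert!(tombstone_blocks(Some(4), Some(4)));
    assert!(!tombstone_blocks(Some(4), Some(5)));
}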

View File

@@ -364,7 +364,12 @@ impl SafekeeperReconcilerInner {
http_hosts,
tenant_id: req.tenant_id,
timeline_id,
ignore_tombstone: Some(false),
// TODO(diko): get mconf from "timelines" table and pass it here.
// Currently we use pull_timeline reconciliation only for timeline creation,
// so this is not critical right now.
// It could be fixed together with other reconciliation issues:
// https://github.com/neondatabase/neon/issues/12189
mconf: None,
};
success = self
.reconcile_inner(

View File

@@ -991,6 +991,7 @@ impl Service {
timeline_id: TimelineId,
to_safekeepers: &[Safekeeper],
from_safekeepers: &[Safekeeper],
mconf: membership::Configuration,
) -> Result<(), ApiError> {
let http_hosts = from_safekeepers
.iter()
@@ -1009,14 +1010,11 @@ impl Service {
.collect::<Vec<_>>()
);
// TODO(diko): need to pass mconf/generation with the request
// to properly handle tombstones. Ignore tombstones for now.
// Worst case: we leave a timeline on a safekeeper which is not in the current set.
let req = PullTimelineRequest {
tenant_id,
timeline_id,
http_hosts,
ignore_tombstone: Some(true),
mconf: Some(mconf),
};
const SK_PULL_TIMELINE_RECONCILE_TIMEOUT: Duration = Duration::from_secs(30);
@@ -1336,6 +1334,7 @@ impl Service {
timeline_id,
&pull_to_safekeepers,
&cur_safekeepers,
joint_config.clone(),
)
.await?;

View File

@@ -1542,6 +1542,17 @@ class NeonEnv:
raise RuntimeError(f"Pageserver with ID {id} not found")
def get_safekeeper(self, id: int) -> Safekeeper:
"""
Look up a safekeeper by its ID.
"""
for sk in self.safekeepers:
if sk.id == id:
return sk
raise RuntimeError(f"Safekeeper with ID {id} not found")
def get_tenant_pageserver(self, tenant_id: TenantId | TenantShardId):
"""
Get the NeonPageserver where this tenant shard is currently attached, according
@@ -5403,15 +5414,24 @@ class Safekeeper(LogUtils):
return timeline_status.commit_lsn
def pull_timeline(
self, srcs: list[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId
self,
srcs: list[Safekeeper],
tenant_id: TenantId,
timeline_id: TimelineId,
mconf: MembershipConfiguration | None = None,
) -> dict[str, Any]:
"""
pull_timeline from srcs to self.
"""
src_https = [f"http://localhost:{sk.port.http}" for sk in srcs]
res = self.http_client().pull_timeline(
{"tenant_id": str(tenant_id), "timeline_id": str(timeline_id), "http_hosts": src_https}
)
body: dict[str, Any] = {
"tenant_id": str(tenant_id),
"timeline_id": str(timeline_id),
"http_hosts": src_https,
}
if mconf is not None:
body["mconf"] = mconf.__dict__
res = self.http_client().pull_timeline(body)
src_ids = [sk.id for sk in srcs]
log.info(f"finished pulling timeline from {src_ids} to {self.id}")
return res

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import re
from typing import TYPE_CHECKING
import pytest
@@ -12,7 +13,7 @@ if TYPE_CHECKING:
# TODO(diko): pageserver spams with various errors during safekeeper migration.
# Fix the code so it handles the migration better.
ALLOWED_PAGESERVER_ERRORS = [
PAGESERVER_ALLOWED_ERRORS = [
".*Timeline .* was cancelled and cannot be used anymore.*",
".*Timeline .* has been deleted.*",
".*Timeline .* was not found in global map.*",
@@ -35,7 +36,7 @@ def test_safekeeper_migration_simple(neon_env_builder: NeonEnvBuilder):
"timeline_safekeeper_count": 1,
}
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(ALLOWED_PAGESERVER_ERRORS)
env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS)
ep = env.endpoints.create("main", tenant_id=env.initial_tenant)
@@ -136,7 +137,7 @@ def test_safekeeper_migration_common_set_failpoints(neon_env_builder: NeonEnvBui
"timeline_safekeeper_count": 3,
}
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(ALLOWED_PAGESERVER_ERRORS)
env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS)
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
assert len(mconf["sk_set"]) == 3
@@ -196,3 +197,92 @@ def test_safekeeper_migration_common_set_failpoints(neon_env_builder: NeonEnvBui
assert (
f"timeline {env.initial_tenant}/{env.initial_timeline} deleted" in exc.value.response.text
)
def test_sk_generation_aware_tombstones(neon_env_builder: NeonEnvBuilder):
"""
Test that safekeeper respects generations:
1. Check that migration back and forth between two safekeepers works.
2. Check that sk refuses to execute requests with stale generation.
"""
neon_env_builder.num_safekeepers = 3
neon_env_builder.storage_controller_config = {
"timelines_onto_safekeepers": True,
"timeline_safekeeper_count": 1,
}
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS)
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
assert mconf["new_sk_set"] is None
assert len(mconf["sk_set"]) == 1
cur_sk = mconf["sk_set"][0]
second_sk, third_sk = [sk.id for sk in env.safekeepers if sk.id != cur_sk]
cur_gen = 1
# Pull the timeline manually to third_sk, so the timeline exists there with stale generation.
# This is needed for the test later.
env.get_safekeeper(third_sk).pull_timeline(
[env.get_safekeeper(cur_sk)], env.initial_tenant, env.initial_timeline
)
def expect_deleted(sk_id: int):
with pytest.raises(requests.exceptions.HTTPError, match="Not Found") as exc:
env.get_safekeeper(sk_id).http_client().timeline_status(
env.initial_tenant, env.initial_timeline
)
assert exc.value.response.status_code == 404
assert re.match(r".*timeline .* deleted.*", exc.value.response.text)
def get_mconf(sk_id: int):
status = (
env.get_safekeeper(sk_id)
.http_client()
.timeline_status(env.initial_tenant, env.initial_timeline)
)
assert status.mconf is not None
return status.mconf
def migrate():
nonlocal cur_sk, second_sk, cur_gen
env.storage_controller.migrate_safekeepers(
env.initial_tenant, env.initial_timeline, [second_sk]
)
cur_sk, second_sk = second_sk, cur_sk
cur_gen += 2
# Migrate the timeline back and forth between cur_sk and second_sk.
for _i in range(3):
migrate()
# Timeline should exist on cur_sk.
assert get_mconf(cur_sk).generation == cur_gen
# Timeline should be deleted on second_sk.
expect_deleted(second_sk)
# Remember current mconf.
mconf = get_mconf(cur_sk)
# Migrate the timeline one more time.
# It increases the generation by 2.
migrate()
# Check that sk refuses to execute the exclude request with the old mconf.
with pytest.raises(requests.exceptions.HTTPError, match="Conflict") as exc:
env.get_safekeeper(cur_sk).http_client().timeline_exclude(
env.initial_tenant, env.initial_timeline, mconf
)
assert re.match(r".*refused to switch into excluding mconf.*", exc.value.response.text)
# We shouldn't have deleted the timeline.
assert get_mconf(cur_sk).generation == cur_gen
# Check that sk refuses to execute the pull_timeline request with the old mconf.
# Note: we try to pull from third_sk, which has a timeline with stale generation.
# Thus, we bypass some preliminary generation checks and actually test tombstones.
with pytest.raises(requests.exceptions.HTTPError, match="Conflict") as exc:
env.get_safekeeper(second_sk).pull_timeline(
[env.get_safekeeper(third_sk)], env.initial_tenant, env.initial_timeline, mconf
)
assert re.match(r".*Timeline .* deleted.*", exc.value.response.text)
# The timeline should remain deleted.
expect_deleted(second_sk)