Compare commits

..

4 Commits

Author SHA1 Message Date
Arpad Müller
521a9590cf Merge branch 'arpad/rustls_everywhere' into arpad/musl_libc_v2 2024-05-25 02:30:55 +02:00
Arpad Müller
3edf932e80 Use tokio-epoll-uring patched for statx-sys 2024-05-25 02:30:27 +02:00
Arpad Müller
b3df345984 Deny postgres-native-tls 2024-05-25 02:09:34 +02:00
Arpad Müller
cdcfb504b6 Drop postgres-native-tls in favour of tokio-postgres-rustls 2024-05-25 02:07:20 +02:00
30 changed files with 209 additions and 340 deletions

38
Cargo.lock generated
View File

@@ -4105,17 +4105,6 @@ dependencies = [
"tokio-postgres",
]
[[package]]
name = "postgres-native-tls"
version = "0.5.0"
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#20031d7a9ee1addeae6e0968e3899ae6bf01cee2"
dependencies = [
"native-tls",
"tokio",
"tokio-native-tls",
"tokio-postgres",
]
[[package]]
name = "postgres-protocol"
version = "0.6.4"
@@ -4423,7 +4412,6 @@ dependencies = [
"md5",
"measured",
"metrics",
"native-tls",
"once_cell",
"opentelemetry",
"parking_lot 0.12.1",
@@ -4431,7 +4419,6 @@ dependencies = [
"parquet_derive",
"pbkdf2",
"pin-project-lite",
"postgres-native-tls",
"postgres-protocol",
"postgres_backend",
"pq_proto",
@@ -4479,7 +4466,7 @@ dependencies = [
"utils",
"uuid",
"walkdir",
"webpki-roots 0.25.2",
"webpki-roots 0.26.1",
"workspace_hack",
"x509-parser",
]
@@ -5232,20 +5219,20 @@ dependencies = [
"hex",
"histogram",
"itertools",
"native-tls",
"pageserver",
"pageserver_api",
"postgres-native-tls",
"postgres_ffi",
"rand 0.8.5",
"remote_storage",
"reqwest 0.12.4",
"rustls 0.22.4",
"serde",
"serde_json",
"serde_with",
"thiserror",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.25.0",
"tokio-stream",
"tokio-util",
@@ -5253,6 +5240,7 @@ dependencies = [
"tracing-appender",
"tracing-subscriber",
"utils",
"webpki-roots 0.26.1",
"workspace_hack",
]
@@ -5843,6 +5831,15 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "statx-sys"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69c325f46f705b7a66fb87f0ebb999524a7363f30f05d373277b4ef7f409fe87"
dependencies = [
"libc",
]
[[package]]
name = "storage_broker"
version = "0.1.0"
@@ -6266,7 +6263,7 @@ dependencies = [
[[package]]
name = "tokio-epoll-uring"
version = "0.1.0"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#342ddd197a060a8354e8f11f4d12994419fff939"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=arpad/statx_sys#ca8446b8edb5e0aef88520f2fc209a13a834fd25"
dependencies = [
"futures",
"nix 0.26.4",
@@ -6797,11 +6794,12 @@ dependencies = [
[[package]]
name = "uring-common"
version = "0.1.0"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#342ddd197a060a8354e8f11f4d12994419fff939"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=arpad/statx_sys#ca8446b8edb5e0aef88520f2fc209a13a834fd25"
dependencies = [
"bytes",
"io-uring",
"libc",
"statx-sys",
]
[[package]]
@@ -7629,9 +7627,9 @@ dependencies = [
[[package]]
name = "zeroize"
version = "1.6.0"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9"
checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
dependencies = [
"zeroize_derive",
]

View File

@@ -171,7 +171,7 @@ thiserror = "1.0"
tikv-jemallocator = "0.5"
tikv-jemalloc-ctl = "0.5"
tokio = { version = "1.17", features = ["macros"] }
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "arpad/statx_sys" }
tokio-io-timeout = "1.2.0"
tokio-postgres-rustls = "0.11.0"
tokio-rustls = "0.25"
@@ -191,7 +191,7 @@ url = "2.2"
urlencoding = "2.1"
uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"
webpki-roots = "0.25"
webpki-roots = "0.26"
x509-parser = "0.15"
## TODO replace this with tracing
@@ -200,7 +200,6 @@ log = "0.4"
## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" }
postgres-native-tls = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" }
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" }
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" }
@@ -241,8 +240,7 @@ tonic-build = "0.9"
[patch.crates-io]
# This is only needed for proxy's tests.
# TODO: we should probably fork `tokio-postgres-rustls` instead.
# Needed to get `tokio-postgres-rustls` to depend on our fork.
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" }
# bug fixes for UUID

View File

@@ -99,6 +99,10 @@ name = "async-executor"
[[bans.deny]]
name = "smol"
[[bans.deny]]
# We want to use rustls instead of native-tls.
name = "postgres-native-tls"
# This section is considered when running `cargo deny check sources`.
# More documentation about the 'sources' section can be found here:
# https://embarkstudios.github.io/cargo-deny/checks/sources/cfg.html

View File

@@ -9,33 +9,6 @@ use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use tracing::*;
/// Declare a failpoint that can use the `pause` failpoint action.
/// We don't want to block the executor thread, hence, spawn_blocking + await.
#[macro_export]
macro_rules! pausable_failpoint {
($name:literal) => {
if cfg!(feature = "testing") {
tokio::task::spawn_blocking({
let current = tracing::Span::current();
move || {
let _entered = current.entered();
tracing::info!("at failpoint {}", $name);
fail::fail_point!($name);
}
})
.await
.expect("spawn_blocking");
}
};
($name:literal, $cond:expr) => {
if cfg!(feature = "testing") {
if $cond {
pausable_failpoint!($name)
}
}
};
}
/// use with fail::cfg("$name", "return(2000)")
///
/// The effect is similar to a "sleep(2000)" action, i.e. we sleep for the

View File

@@ -42,7 +42,6 @@ use utils::completion;
use utils::crashsafe::path_with_suffix_extension;
use utils::failpoint_support;
use utils::fs_ext;
use utils::pausable_failpoint;
use utils::sync::gate::Gate;
use utils::sync::gate::GateGuard;
use utils::timeout::timeout_cancellable;
@@ -123,6 +122,32 @@ use utils::{
lsn::{Lsn, RecordLsn},
};
/// Declare a failpoint that can use the `pause` failpoint action.
/// We don't want to block the executor thread, hence, spawn_blocking + await.
macro_rules! pausable_failpoint {
($name:literal) => {
if cfg!(feature = "testing") {
tokio::task::spawn_blocking({
let current = tracing::Span::current();
move || {
let _entered = current.entered();
tracing::info!("at failpoint {}", $name);
fail::fail_point!($name);
}
})
.await
.expect("spawn_blocking");
}
};
($name:literal, $cond:expr) => {
if cfg!(feature = "testing") {
if $cond {
pausable_failpoint!($name)
}
}
};
}
pub mod blob_io;
pub mod block_io;
pub mod vectored_blob_io;

View File

@@ -8,7 +8,7 @@ use tokio::sync::OwnedMutexGuard;
use tokio_util::sync::CancellationToken;
use tracing::{error, instrument, Instrument};
use utils::{backoff, completion, crashsafe, fs_ext, id::TimelineId, pausable_failpoint};
use utils::{backoff, completion, crashsafe, fs_ext, id::TimelineId};
use crate::{
config::PageServerConf,

View File

@@ -197,7 +197,6 @@ pub(crate) use upload::upload_initdb_dir;
use utils::backoff::{
self, exponential_backoff, DEFAULT_BASE_BACKOFF_SECONDS, DEFAULT_MAX_BACKOFF_SECONDS,
};
use utils::pausable_failpoint;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU32, Ordering};

View File

@@ -9,7 +9,7 @@ use std::time::SystemTime;
use tokio::fs::{self, File};
use tokio::io::AsyncSeekExt;
use tokio_util::sync::CancellationToken;
use utils::{backoff, pausable_failpoint};
use utils::backoff;
use super::Generation;
use crate::tenant::remote_timeline_client::{

View File

@@ -17,7 +17,7 @@ use crate::tenant::{Tenant, TenantState};
use rand::Rng;
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::{backoff, completion, pausable_failpoint};
use utils::{backoff, completion};
static CONCURRENT_BACKGROUND_TASKS: once_cell::sync::Lazy<tokio::sync::Semaphore> =
once_cell::sync::Lazy::new(|| {

View File

@@ -41,7 +41,7 @@ use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::{
bin_ser::BeSer,
fs_ext, pausable_failpoint,
fs_ext,
sync::gate::{Gate, GateGuard},
vec_map::VecMap,
};

View File

@@ -7,7 +7,7 @@ use anyhow::Context;
use pageserver_api::{models::TimelineState, shard::TenantShardId};
use tokio::sync::OwnedMutexGuard;
use tracing::{error, info, instrument, Instrument};
use utils::{crashsafe, fs_ext, id::TimelineId, pausable_failpoint};
use utils::{crashsafe, fs_ext, id::TimelineId};
use crate::{
config::PageServerConf,

View File

@@ -125,6 +125,13 @@ typedef struct
* - WL_EXIT_ON_PM_DEATH.
*/
WaitEventSet *wes_read;
/*---
* WaitEventSet containing:
* - WL_SOCKET_WRITABLE on 'conn'
* - WL_LATCH_SET on MyLatch, and
* - WL_EXIT_ON_PM_DEATH.
*/
WaitEventSet *wes_write;
} PageServer;
static PageServer page_servers[MAX_SHARDS];
@@ -329,6 +336,11 @@ CLEANUP_AND_DISCONNECT(PageServer *shard)
FreeWaitEventSet(shard->wes_read);
shard->wes_read = NULL;
}
if (shard->wes_write)
{
FreeWaitEventSet(shard->wes_write);
shard->wes_write = NULL;
}
if (shard->conn)
{
PQfinish(shard->conn);
@@ -424,6 +436,22 @@ pageserver_connect(shardno_t shard_no, int elevel)
return false;
}
shard->wes_read = CreateWaitEventSet(TopMemoryContext, 3);
AddWaitEventToSet(shard->wes_read, WL_LATCH_SET, PGINVALID_SOCKET,
MyLatch, NULL);
AddWaitEventToSet(shard->wes_read, WL_EXIT_ON_PM_DEATH, PGINVALID_SOCKET,
NULL, NULL);
AddWaitEventToSet(shard->wes_read, WL_SOCKET_READABLE, PQsocket(shard->conn), NULL, NULL);
shard->wes_write = CreateWaitEventSet(TopMemoryContext, 3);
AddWaitEventToSet(shard->wes_write, WL_LATCH_SET, PGINVALID_SOCKET,
MyLatch, NULL);
AddWaitEventToSet(shard->wes_write, WL_EXIT_ON_PM_DEATH, PGINVALID_SOCKET,
NULL, NULL);
AddWaitEventToSet(shard->wes_write, WL_SOCKET_READABLE | WL_SOCKET_WRITEABLE,
PQsocket(shard->conn),
NULL, NULL);
shard->state = PS_Connecting_Startup;
/* fallthrough */
}
@@ -432,12 +460,13 @@ pageserver_connect(shardno_t shard_no, int elevel)
char *pagestream_query;
int ps_send_query_ret;
bool connected = false;
int poll_result = PGRES_POLLING_WRITING;
neon_shard_log(shard_no, DEBUG5, "Connection state: Connecting_Startup");
do
{
WaitEvent event;
int poll_result = PQconnectPoll(shard->conn);
switch (poll_result)
{
@@ -468,45 +497,25 @@ pageserver_connect(shardno_t shard_no, int elevel)
}
case PGRES_POLLING_READING:
/* Sleep until there's something to do */
while (true)
{
int rc = WaitLatchOrSocket(MyLatch,
WL_EXIT_ON_PM_DEATH | WL_LATCH_SET | WL_SOCKET_READABLE,
PQsocket(shard->conn),
0,
PG_WAIT_EXTENSION);
elog(DEBUG5, "PGRES_POLLING_READING=>%d", rc);
if (rc & WL_LATCH_SET)
{
ResetLatch(MyLatch);
/* query cancellation, backend shutdown */
CHECK_FOR_INTERRUPTS();
}
if (rc & WL_SOCKET_READABLE)
break;
}
(void) WaitEventSetWait(shard->wes_read, -1L, &event, 1,
PG_WAIT_EXTENSION);
ResetLatch(MyLatch);
/* query cancellation, backend shutdown */
CHECK_FOR_INTERRUPTS();
/* PQconnectPoll() handles the socket polling state updates */
break;
case PGRES_POLLING_WRITING:
/* Sleep until there's something to do */
while (true)
{
int rc = WaitLatchOrSocket(MyLatch,
WL_EXIT_ON_PM_DEATH | WL_LATCH_SET | WL_SOCKET_WRITEABLE,
PQsocket(shard->conn),
0,
PG_WAIT_EXTENSION);
elog(DEBUG5, "PGRES_POLLING_WRITING=>%d", rc);
if (rc & WL_LATCH_SET)
{
ResetLatch(MyLatch);
/* query cancellation, backend shutdown */
CHECK_FOR_INTERRUPTS();
}
if (rc & WL_SOCKET_WRITEABLE)
break;
}
(void) WaitEventSetWait(shard->wes_write, -1L, &event, 1,
PG_WAIT_EXTENSION);
ResetLatch(MyLatch);
/* query cancellation, backend shutdown */
CHECK_FOR_INTERRUPTS();
/* PQconnectPoll() handles the socket polling state updates */
break;
@@ -515,22 +524,12 @@ pageserver_connect(shardno_t shard_no, int elevel)
connected = true;
break;
}
poll_result = PQconnectPoll(shard->conn);
elog(DEBUG5, "PQconnectPoll=>%d", poll_result);
}
while (!connected);
/* No more polling needed; connection succeeded */
shard->last_connect_time = GetCurrentTimestamp();
shard->wes_read = CreateWaitEventSet(TopMemoryContext, 3);
AddWaitEventToSet(shard->wes_read, WL_LATCH_SET, PGINVALID_SOCKET,
MyLatch, NULL);
AddWaitEventToSet(shard->wes_read, WL_EXIT_ON_PM_DEATH, PGINVALID_SOCKET,
NULL, NULL);
AddWaitEventToSet(shard->wes_read, WL_SOCKET_READABLE, PQsocket(shard->conn), NULL, NULL);
switch (neon_protocol_version)
{
case 2:

View File

@@ -584,9 +584,9 @@ prefetch_read(PrefetchRequest *slot)
slot->response != NULL ||
slot->my_ring_index != MyPState->ring_receive)
neon_shard_log(slot->shard_no, ERROR,
"Incorrect prefetch read: status=%d response=%p my=%lu receive=%lu",
slot->status, slot->response,
(long)slot->my_ring_index, (long)MyPState->ring_receive);
"Incorrect prefetch read: status=%d response=%llx my=%llu receive=%llu",
slot->status, (size_t) (void *) slot->response,
slot->my_ring_index, MyPState->ring_receive);
old = MemoryContextSwitchTo(MyPState->errctx);
response = (NeonResponse *) page_server->receive(slot->shard_no);
@@ -606,8 +606,8 @@ prefetch_read(PrefetchRequest *slot)
else
{
neon_shard_log(slot->shard_no, WARNING,
"No response from reading prefetch entry %lu: %u/%u/%u.%u block %u. This can be caused by a concurrent disconnect",
(long)slot->my_ring_index,
"No response from reading prefetch entry %llu: %u/%u/%u.%u block %u. This can be caused by a concurrent disconnect",
slot->my_ring_index,
RelFileInfoFmt(BufTagGetNRelFileInfo(slot->buftag)),
slot->buftag.forkNum, slot->buftag.blockNum);
return false;

View File

@@ -82,6 +82,7 @@ thiserror.workspace = true
tikv-jemallocator.workspace = true
tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
tokio-postgres.workspace = true
tokio-postgres-rustls.workspace = true
tokio-rustls.workspace = true
tokio-util.workspace = true
tokio = { workspace = true, features = ["signal"] }
@@ -96,8 +97,6 @@ utils.workspace = true
uuid.workspace = true
webpki-roots.workspace = true
x509-parser.workspace = true
native-tls.workspace = true
postgres-native-tls.workspace = true
postgres-protocol.workspace = true
redis.workspace = true

View File

@@ -9,7 +9,6 @@ use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled};
use rustls::pki_types::PrivateKeyDer;
use tokio::net::TcpListener;
@@ -66,8 +65,6 @@ async fn main() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
let args = cli().get_matches();
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;

View File

@@ -11,10 +11,12 @@ use crate::{
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
use std::{io, net::SocketAddr, time::Duration};
use rustls::{client::danger::ServerCertVerifier, pki_types::InvalidDnsNameError};
use std::{io, net::SocketAddr, sync::Arc, time::Duration};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres_rustls::MakeRustlsConnect;
use tracing::{error, info, warn};
const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -30,7 +32,7 @@ pub enum ConnectionError {
CouldNotConnect(#[from] io::Error),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] native_tls::Error),
TlsError(#[from] InvalidDnsNameError),
#[error("{COULD_NOT_CONNECT}: {0}")]
WakeComputeError(#[from] WakeComputeError),
@@ -257,7 +259,7 @@ pub struct PostgresConnection {
/// Socket connected to a compute node.
pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
tokio::net::TcpStream,
postgres_native_tls::TlsStream<tokio::net::TcpStream>,
tokio_postgres_rustls::RustlsStream<tokio::net::TcpStream>,
>,
/// PostgreSQL connection parameters.
pub params: std::collections::HashMap<String, String>,
@@ -282,12 +284,24 @@ impl ConnCfg {
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
drop(pause);
let tls_connector = native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(allow_self_signed_compute)
.build()
.unwrap();
let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
let client_config = if allow_self_signed_compute {
let verifier = Arc::new(AcceptEverythingVerifier) as Arc<dyn ServerCertVerifier>;
rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(verifier)
} else {
let root_store = rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
};
rustls::ClientConfig::builder().with_root_certificates(root_store)
};
let client_config = client_config.with_no_client_auth();
let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
host,
)?;
// connect_raw() will not use TLS if sslmode is "disable"
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
@@ -340,6 +354,50 @@ fn filtered_options(params: &StartupMessageParams) -> Option<String> {
Some(options)
}
#[derive(Debug)]
struct AcceptEverythingVerifier;
impl ServerCertVerifier for AcceptEverythingVerifier {
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
use rustls::SignatureScheme::*;
// The schemes for which `SignatureScheme::supported_in_tls13` returns true.
vec![
ECDSA_NISTP521_SHA512,
ECDSA_NISTP384_SHA384,
ECDSA_NISTP256_SHA256,
RSA_PSS_SHA512,
RSA_PSS_SHA384,
RSA_PSS_SHA256,
ED25519,
]
}
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -51,10 +51,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
) -> Poll<io::Result<usize>> {
let this = self.project();
let mut stream = this.stream;
this.send.put(buf);
ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
this.send.put(buf);
match stream.as_mut().start_send(Frame::binary(this.send.split())) {
Ok(()) => Poll::Ready(Ok(buf.len())),
Err(e) => Poll::Ready(Err(io_error(e))),

View File

@@ -22,8 +22,7 @@ serde_with.workspace = true
workspace_hack.workspace = true
utils.workspace = true
async-stream.workspace = true
native-tls.workspace = true
postgres-native-tls.workspace = true
tokio-postgres-rustls.workspace = true
postgres_ffi.workspace = true
tokio-stream.workspace = true
tokio-postgres.workspace = true
@@ -31,6 +30,8 @@ tokio-util = { workspace = true }
futures-util.workspace = true
itertools.workspace = true
camino.workspace = true
rustls.workspace = true
webpki-roots.workspace = true
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
chrono = { workspace = true, default-features = false, features = ["clock", "serde"] }

View File

@@ -71,8 +71,13 @@ pub async fn scan_safekeeper_metadata(
bucket_config.bucket, bucket_config.region, dump_db_table
);
// Use the native TLS implementation (Neon requires TLS)
let tls_connector =
postgres_native_tls::MakeTlsConnector::new(native_tls::TlsConnector::new().unwrap());
let root_store = rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
};
let client_config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let tls_connector = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
let (client, connection) = tokio_postgres::connect(&dump_db_connstr, tls_connector).await?;
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.

View File

@@ -287,26 +287,6 @@ async fn timeline_files_handler(request: Request<Body>) -> Result<Response<Body>
.map_err(|e| ApiError::InternalServerError(e.into()))
}
/// Force persist control file and remove old WAL.
async fn timeline_checkpoint_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permission(&request, None)?;
let ttid = TenantTimelineId::new(
parse_request_param(&request, "tenant_id")?,
parse_request_param(&request, "timeline_id")?,
);
let tli = GlobalTimelines::get(ttid)?;
tli.maybe_persist_control_file(true)
.await
.map_err(ApiError::InternalServerError)?;
tli.remove_old_wal()
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
/// Deactivates the timeline and removes its data directory.
async fn timeline_delete_handler(mut request: Request<Body>) -> Result<Response<Body>, ApiError> {
let ttid = TenantTimelineId::new(
@@ -573,10 +553,6 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError>
"/v1/tenant/:tenant_id/timeline/:timeline_id/control_file",
|r| request_span(r, patch_control_file_handler),
)
.post(
"/v1/tenant/:tenant_id/timeline/:timeline_id/checkpoint",
|r| request_span(r, timeline_checkpoint_handler),
)
// for tests
.post("/v1/record_safekeeper_info/:tenant_id/:timeline_id", |r| {
request_span(r, record_safekeeper_info)

View File

@@ -11,7 +11,6 @@ use tracing::info;
use utils::{
id::{TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
pausable_failpoint,
};
use crate::{
@@ -163,8 +162,6 @@ async fn pull_timeline(status: TimelineStatus, host: String) -> Result<Response>
filenames.remove(control_file_index);
filenames.insert(0, "safekeeper.control".to_string());
pausable_failpoint!("sk-pull-timeline-after-list-pausable");
info!(
"downloading {} files from safekeeper {}",
filenames.len(),
@@ -186,13 +183,6 @@ async fn pull_timeline(status: TimelineStatus, host: String) -> Result<Response>
let mut file = tokio::fs::File::create(&file_path).await?;
let mut response = client.get(&http_url).send().await?;
if response.status() != reqwest::StatusCode::OK {
bail!(
"pulling file {} failed: status is {}",
filename,
response.status()
);
}
while let Some(chunk) = response.chunk().await? {
file.write_all(&chunk).await?;
file.flush().await?;

View File

@@ -15,7 +15,7 @@ pub async fn task_main(_conf: SafeKeeperConf) -> anyhow::Result<()> {
for tli in &tlis {
let ttid = tli.ttid;
async {
if let Err(e) = tli.maybe_persist_control_file(false).await {
if let Err(e) = tli.maybe_persist_control_file().await {
warn!("failed to persist control file: {e}");
}
if let Err(e) = tli.remove_old_wal().await {

View File

@@ -827,9 +827,9 @@ where
/// Persist control file if there is something to save and enough time
/// passed after the last save.
pub async fn maybe_persist_inmem_control_file(&mut self, force: bool) -> Result<bool> {
pub async fn maybe_persist_inmem_control_file(&mut self) -> Result<bool> {
const CF_SAVE_INTERVAL: Duration = Duration::from_secs(300);
if !force && self.state.pers.last_persist_at().elapsed() < CF_SAVE_INTERVAL {
if self.state.pers.last_persist_at().elapsed() < CF_SAVE_INTERVAL {
return Ok(false);
}
let need_persist = self.state.inmem.commit_lsn > self.state.commit_lsn

View File

@@ -821,9 +821,9 @@ impl Timeline {
/// passed after the last save. This helps to keep remote_consistent_lsn up
/// to date so that storage nodes restart doesn't cause many pageserver ->
/// safekeeper reconnections.
pub async fn maybe_persist_control_file(self: &Arc<Self>, force: bool) -> Result<()> {
pub async fn maybe_persist_control_file(self: &Arc<Self>) -> Result<()> {
let mut guard = self.write_shared_state().await;
let changed = guard.sk.maybe_persist_inmem_control_file(force).await?;
let changed = guard.sk.maybe_persist_inmem_control_file().await?;
guard.skip_update = !changed;
Ok(())
}

View File

@@ -106,7 +106,7 @@ pub async fn main_task(
if !is_active {
// TODO: maybe use tokio::spawn?
if let Err(e) = tli.maybe_persist_control_file(false).await {
if let Err(e) = tli.maybe_persist_control_file().await {
warn!("control file save in update_status failed: {:?}", e);
}
}

View File

@@ -5,8 +5,6 @@ from typing import Any, Type, TypeVar, Union
T = TypeVar("T", bound="Id")
DEFAULT_WAL_SEG_SIZE = 16 * 1024 * 1024
@total_ordering
class Lsn:
@@ -69,9 +67,6 @@ class Lsn:
def as_int(self) -> int:
return self.lsn_int
def segment_lsn(self, seg_sz: int = DEFAULT_WAL_SEG_SIZE) -> "Lsn":
return Lsn(self.lsn_int - (self.lsn_int % seg_sz))
@dataclass(frozen=True)
class Key:

View File

@@ -3771,7 +3771,7 @@ class SafekeeperPort:
@dataclass
class Safekeeper(LogUtils):
class Safekeeper:
"""An object representing a running safekeeper daemon."""
env: NeonEnv
@@ -3779,13 +3779,6 @@ class Safekeeper(LogUtils):
id: int
running: bool = False
def __init__(self, env: NeonEnv, port: SafekeeperPort, id: int, running: bool = False):
self.env = env
self.port = port
self.id = id
self.running = running
self.logfile = Path(self.data_dir) / f"safekeeper-{id}.log"
def start(self, extra_opts: Optional[List[str]] = None) -> "Safekeeper":
assert self.running is False
self.env.neon_cli.safekeeper_start(self.id, extra_opts=extra_opts)
@@ -3846,38 +3839,11 @@ class Safekeeper(LogUtils):
port=self.port.http, auth_token=auth_token, is_testing_enabled=is_testing_enabled
)
def get_timeline_start_lsn(self, tenant_id: TenantId, timeline_id: TimelineId) -> Lsn:
timeline_status = self.http_client().timeline_status(tenant_id, timeline_id)
timeline_start_lsn = timeline_status.timeline_start_lsn
log.info(f"sk {self.id} timeline start LSN: {timeline_start_lsn}")
return timeline_start_lsn
def data_dir(self) -> str:
return os.path.join(self.env.repo_dir, "safekeepers", f"sk{self.id}")
def get_flush_lsn(self, tenant_id: TenantId, timeline_id: TimelineId) -> Lsn:
timeline_status = self.http_client().timeline_status(tenant_id, timeline_id)
flush_lsn = timeline_status.flush_lsn
log.info(f"sk {self.id} flush LSN: {flush_lsn}")
return flush_lsn
def pull_timeline(
self, srcs: list[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId
) -> Dict[str, Any]:
"""
pull_timeline from srcs to self.
"""
src_https = [f"http://localhost:{sk.port.http}" for sk in srcs]
res = self.http_client().pull_timeline(
{"tenant_id": str(tenant_id), "timeline_id": str(timeline_id), "http_hosts": src_https}
)
src_ids = [sk.id for sk in srcs]
log.info(f"finished pulling timeline from {src_ids} to {self.id}")
return res
@property
def data_dir(self) -> Path:
return self.env.repo_dir / "safekeepers" / f"sk{self.id}"
def timeline_dir(self, tenant_id, timeline_id) -> Path:
return self.data_dir / str(tenant_id) / str(timeline_id)
def timeline_dir(self, tenant_id, timeline_id) -> str:
return os.path.join(self.data_dir(), str(tenant_id), str(timeline_id))
def list_segments(self, tenant_id, timeline_id) -> List[str]:
"""
@@ -3890,35 +3856,6 @@ class Safekeeper(LogUtils):
segments.sort()
return segments
def checkpoint_up_to(self, tenant_id: TenantId, timeline_id: TimelineId, lsn: Lsn):
"""
Assuming pageserver(s) uploaded to s3 up to `lsn`,
1) wait for remote_consistent_lsn and wal_backup_lsn on safekeeper to reach it.
2) checkpoint timeline on safekeeper, which should remove WAL before this LSN.
"""
cli = self.http_client()
def are_lsns_advanced():
stat = cli.timeline_status(tenant_id, timeline_id)
log.info(
f"waiting for remote_consistent_lsn and backup_lsn on sk {self.id} to reach {lsn}, currently remote_consistent_lsn={stat.remote_consistent_lsn}, backup_lsn={stat.backup_lsn}"
)
assert stat.remote_consistent_lsn >= lsn and stat.backup_lsn >= lsn.segment_lsn()
# xxx: max wait is long because we might be waiting for reconnection from
# pageserver to this safekeeper
wait_until(30, 1, are_lsns_advanced)
cli.checkpoint(tenant_id, timeline_id)
def wait_until_paused(self, failpoint: str):
msg = f"at failpoint {failpoint}"
def paused():
log.info(f"waiting for hitting failpoint {failpoint}")
self.assert_log_contains(msg)
wait_until(20, 0.5, paused)
class S3Scrubber:
def __init__(self, env: NeonEnvBuilder, log_dir: Optional[Path] = None):

View File

@@ -177,13 +177,6 @@ class SafekeeperHttpClient(requests.Session):
)
res.raise_for_status()
def checkpoint(self, tenant_id: TenantId, timeline_id: TimelineId):
res = self.post(
f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/checkpoint",
json={},
)
res.raise_for_status()
# only_local doesn't remove segments in the remote storage.
def timeline_delete(
self, tenant_id: TenantId, timeline_id: TimelineId, only_local: bool = False

View File

@@ -196,7 +196,7 @@ def query_scalar(cur: cursor, query: str) -> Any:
# Traverse directory to get total size.
def get_dir_size(path: Path) -> int:
def get_dir_size(path: str) -> int:
"""Return size in bytes."""
totalbytes = 0
for root, _dirs, files in os.walk(path):
@@ -560,25 +560,3 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: Set[str
elapsed = time.time() - started_at
log.info(f"assert_pageserver_backups_equal completed in {elapsed}s")
class PropagatingThread(threading.Thread):
_target: Any
_args: Any
_kwargs: Any
"""
Simple Thread wrapper with join() propagating the possible exception in the thread.
"""
def run(self):
self.exc = None
try:
self.ret = self._target(*self._args, **self._kwargs)
except BaseException as e:
self.exc = e
def join(self, timeout=None):
super(PropagatingThread, self).join(timeout)
if self.exc:
raise self.exc
return self.ret

View File

@@ -23,6 +23,7 @@ from fixtures.log_helper import log
from fixtures.metrics import parse_metrics
from fixtures.neon_fixtures import (
Endpoint,
NeonEnv,
NeonEnvBuilder,
NeonPageserver,
PgBin,
@@ -47,7 +48,7 @@ from fixtures.remote_storage import (
)
from fixtures.safekeeper.http import SafekeeperHttpClient
from fixtures.safekeeper.utils import are_walreceivers_absent
from fixtures.utils import PropagatingThread, get_dir_size, query_scalar, start_in_background
from fixtures.utils import get_dir_size, query_scalar, start_in_background
def wait_lsn_force_checkpoint(
@@ -359,7 +360,7 @@ def test_wal_removal(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
# We will wait for first segment removal. Make sure they exist for starter.
first_segments = [
sk.timeline_dir(tenant_id, timeline_id) / "000000010000000000000001"
os.path.join(sk.data_dir(), str(tenant_id), str(timeline_id), "000000010000000000000001")
for sk in env.safekeepers
]
assert all(os.path.exists(p) for p in first_segments)
@@ -444,7 +445,7 @@ def is_flush_lsn_caught_up(sk: Safekeeper, tenant_id: TenantId, timeline_id: Tim
def is_wal_trimmed(sk: Safekeeper, tenant_id: TenantId, timeline_id: TimelineId, target_size_mb):
http_cli = sk.http_client()
tli_status = http_cli.timeline_status(tenant_id, timeline_id)
sk_wal_size = get_dir_size(sk.timeline_dir(tenant_id, timeline_id))
sk_wal_size = get_dir_size(os.path.join(sk.data_dir(), str(tenant_id), str(timeline_id)))
sk_wal_size_mb = sk_wal_size / 1024 / 1024
log.info(f"Safekeeper id={sk.id} wal_size={sk_wal_size_mb:.2f}MB status={tli_status}")
return sk_wal_size_mb <= target_size_mb
@@ -590,10 +591,10 @@ def test_s3_wal_replay(neon_env_builder: NeonEnvBuilder):
# save the last (partial) file to put it back after recreation; others will be fetched from s3
sk = env.safekeepers[0]
tli_dir = Path(sk.data_dir) / str(tenant_id) / str(timeline_id)
tli_dir = Path(sk.data_dir()) / str(tenant_id) / str(timeline_id)
f_partial = Path([f for f in os.listdir(tli_dir) if f.endswith(".partial")][0])
f_partial_path = tli_dir / f_partial
f_partial_saved = Path(sk.data_dir) / f_partial.name
f_partial_saved = Path(sk.data_dir()) / f_partial.name
f_partial_path.rename(f_partial_saved)
pg_version = sk.http_client().timeline_status(tenant_id, timeline_id).pg_version
@@ -615,7 +616,7 @@ def test_s3_wal_replay(neon_env_builder: NeonEnvBuilder):
cli = sk.http_client()
cli.timeline_create(tenant_id, timeline_id, pg_version, last_lsn)
f_partial_path = (
Path(sk.data_dir) / str(tenant_id) / str(timeline_id) / f_partial_saved.name
Path(sk.data_dir()) / str(tenant_id) / str(timeline_id) / f_partial_saved.name
)
shutil.copy(f_partial_saved, f_partial_path)
@@ -1131,8 +1132,8 @@ def cmp_sk_wal(sks: List[Safekeeper], tenant_id: TenantId, timeline_id: Timeline
)
for f in mismatch:
f1 = sk0.timeline_dir(tenant_id, timeline_id) / f
f2 = sk.timeline_dir(tenant_id, timeline_id) / f
f1 = os.path.join(sk0.timeline_dir(tenant_id, timeline_id), f)
f2 = os.path.join(sk.timeline_dir(tenant_id, timeline_id), f)
stdout_filename = f"{f2}.filediff"
with open(stdout_filename, "w") as stdout_f:
@@ -1630,7 +1631,7 @@ def test_delete_force(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
with conn.cursor() as cur:
cur.execute("CREATE TABLE t(key int primary key)")
sk = env.safekeepers[0]
sk_data_dir = sk.data_dir
sk_data_dir = Path(sk.data_dir())
if not auth_enabled:
sk_http = sk.http_client()
sk_http_other = sk_http
@@ -1723,6 +1724,9 @@ def test_delete_force(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
def test_pull_timeline(neon_env_builder: NeonEnvBuilder):
def safekeepers_guc(env: NeonEnv, sk_names: List[int]) -> str:
return ",".join([f"localhost:{sk.port.pg}" for sk in env.safekeepers if sk.id in sk_names])
def execute_payload(endpoint: Endpoint):
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
@@ -1808,65 +1812,6 @@ def test_pull_timeline(neon_env_builder: NeonEnvBuilder):
show_statuses(env.safekeepers, tenant_id, timeline_id)
# Test pull_timeline while concurrently gc'ing WAL on safekeeper:
# 1) Start pull_timeline, listing files to fetch.
# 2) Write segment, do gc.
# 3) Finish pull_timeline.
# 4) Do some write, verify integrity with timeline_digest.
# Expected to fail while holding off WAL gc plus fetching commit_lsn WAL
# segment is not implemented.
@pytest.mark.xfail
def test_pull_timeline_gc(neon_env_builder: NeonEnvBuilder):
neon_env_builder.num_safekeepers = 3
neon_env_builder.enable_safekeeper_remote_storage(default_remote_storage())
env = neon_env_builder.init_start()
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline
(src_sk, dst_sk) = (env.safekeepers[0], env.safekeepers[2])
log.info("use only first 2 safekeepers, 3rd will be seeded")
endpoint = env.endpoints.create("main")
endpoint.active_safekeepers = [1, 2]
endpoint.start()
endpoint.safe_psql("create table t(key int, value text)")
endpoint.safe_psql("insert into t select generate_series(1, 1000), 'pear'")
src_flush_lsn = src_sk.get_flush_lsn(tenant_id, timeline_id)
log.info(f"flush_lsn on src before pull_timeline: {src_flush_lsn}")
dst_http = dst_sk.http_client()
# run pull_timeline which will halt before downloading files
dst_http.configure_failpoints(("sk-pull-timeline-after-list-pausable", "pause"))
pt_handle = PropagatingThread(
target=dst_sk.pull_timeline, args=([src_sk], tenant_id, timeline_id)
)
pt_handle.start()
dst_sk.wait_until_paused("sk-pull-timeline-after-list-pausable")
# ensure segment exists
endpoint.safe_psql("insert into t select generate_series(1, 180000), 'papaya'")
lsn = last_flush_lsn_upload(env, endpoint, tenant_id, timeline_id)
assert lsn > Lsn("0/2000000")
# Checkpoint timeline beyond lsn.
src_sk.checkpoint_up_to(tenant_id, timeline_id, lsn)
first_segment_p = src_sk.timeline_dir(tenant_id, timeline_id) / "000000010000000000000001"
log.info(f"first segment exist={os.path.exists(first_segment_p)}")
dst_http.configure_failpoints(("sk-pull-timeline-after-list-pausable", "off"))
pt_handle.join()
timeline_start_lsn = src_sk.get_timeline_start_lsn(tenant_id, timeline_id)
dst_flush_lsn = dst_sk.get_flush_lsn(tenant_id, timeline_id)
log.info(f"flush_lsn on dst after pull_timeline: {dst_flush_lsn}")
assert dst_flush_lsn >= src_flush_lsn
digests = [
sk.http_client().timeline_digest(tenant_id, timeline_id, timeline_start_lsn, dst_flush_lsn)
for sk in [src_sk, dst_sk]
]
assert digests[0] == digests[1], f"digest on src is {digests[0]} but on dst is {digests[1]}"
# In this test we check for excessive START_REPLICATION and START_WAL_PUSH queries
# when compute is active, but there are no writes to the timeline. In that case
# pageserver should maintain a single connection to safekeeper and don't attempt