Compare commits

..

2 Commits

Author SHA1 Message Date
Anna Khanova
e39f0a2347 Change default timeout 2024-02-10 11:49:45 +01:00
Anna Khanova
8bf1aacb24 Create timeout on proxy<->cplane communication 2024-02-10 00:45:36 +01:00
156 changed files with 1262 additions and 7192 deletions

View File

@@ -17,7 +17,6 @@ concurrency:
jobs:
actionlint:
if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

View File

@@ -26,8 +26,8 @@ env:
jobs:
check-permissions:
if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }}
runs-on: ubuntu-latest
steps:
- name: Disallow PRs from forks
if: |

View File

@@ -117,7 +117,6 @@ jobs:
check-linux-arm-build:
timeout-minutes: 90
if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }}
runs-on: [ self-hosted, dev, arm64 ]
env:
@@ -238,7 +237,6 @@ jobs:
check-codestyle-rust-arm:
timeout-minutes: 90
if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }}
runs-on: [ self-hosted, dev, arm64 ]
container:

32
Cargo.lock generated
View File

@@ -1639,22 +1639,6 @@ dependencies = [
"rusticata-macros",
]
[[package]]
name = "desim"
version = "0.1.0"
dependencies = [
"anyhow",
"bytes",
"hex",
"parking_lot 0.12.1",
"rand 0.8.5",
"scopeguard",
"smallvec",
"tracing",
"utils",
"workspace_hack",
]
[[package]]
name = "diesel"
version = "2.1.4"
@@ -2263,11 +2247,11 @@ dependencies = [
[[package]]
name = "hashlink"
version = "0.8.4"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7"
checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa"
dependencies = [
"hashbrown 0.14.0",
"hashbrown 0.13.2",
]
[[package]]
@@ -3952,7 +3936,6 @@ dependencies = [
"pin-project-lite",
"postgres-protocol",
"rand 0.8.5",
"serde",
"thiserror",
"tokio",
"tracing",
@@ -4844,7 +4827,6 @@ dependencies = [
"clap",
"const_format",
"crc32c",
"desim",
"fail",
"fs2",
"futures",
@@ -4860,7 +4842,6 @@ dependencies = [
"postgres_backend",
"postgres_ffi",
"pq_proto",
"rand 0.8.5",
"regex",
"remote_storage",
"reqwest",
@@ -4881,10 +4862,8 @@ dependencies = [
"tokio-util",
"toml_edit",
"tracing",
"tracing-subscriber",
"url",
"utils",
"walproposer",
"workspace_hack",
]
@@ -5761,7 +5740,7 @@ dependencies = [
[[package]]
name = "tokio-epoll-uring"
version = "0.1.0"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#868d2c42b5d54ca82fead6e8f2f233b69a540d3e"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc"
dependencies = [
"futures",
"nix 0.26.4",
@@ -6286,9 +6265,8 @@ dependencies = [
[[package]]
name = "uring-common"
version = "0.1.0"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#868d2c42b5d54ca82fead6e8f2f233b69a540d3e"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc"
dependencies = [
"bytes",
"io-uring",
"libc",
]

View File

@@ -18,7 +18,6 @@ members = [
"libs/pageserver_api",
"libs/postgres_ffi",
"libs/safekeeper_api",
"libs/desim",
"libs/utils",
"libs/consumption_metrics",
"libs/postgres_backend",
@@ -81,7 +80,7 @@ futures-core = "0.3"
futures-util = "0.3"
git-version = "0.3"
hashbrown = "0.13"
hashlink = "0.8.4"
hashlink = "0.8.1"
hdrhistogram = "7.5.2"
hex = "0.4"
hex-literal = "0.4"
@@ -204,7 +203,6 @@ postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" }
pq_proto = { version = "0.1", path = "./libs/pq_proto/" }
remote_storage = { version = "0.1", path = "./libs/remote_storage/" }
safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" }
desim = { version = "0.1", path = "./libs/desim" }
storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy.
tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" }
tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::env;
use std::fs;
use std::io::BufRead;
use std::os::unix::fs::{symlink, PermissionsExt};
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::{Command, Stdio};
use std::str::FromStr;
@@ -634,48 +634,6 @@ impl ComputeNode {
// Update pg_hba.conf received with basebackup.
update_pg_hba(pgdata_path)?;
// Place pg_dynshmem under /dev/shm. This allows us to use
// 'dynamic_shared_memory_type = mmap' so that the files are placed in
// /dev/shm, similar to how 'dynamic_shared_memory_type = posix' works.
//
// Why on earth don't we just stick to the 'posix' default, you might
// ask. It turns out that making large allocations with 'posix' doesn't
// work very well with autoscaling. The behavior we want is that:
//
// 1. You can make large DSM allocations, larger than the current RAM
// size of the VM, without errors
//
// 2. If the allocated memory is really used, the VM is scaled up
// automatically to accommodate that
//
// We try to make that possible by having swap in the VM. But with the
// default 'posix' DSM implementation, we fail step 1, even when there's
// plenty of swap available. PostgreSQL uses posix_fallocate() to create
// the shmem segment, which is really just a file in /dev/shm in Linux,
// but posix_fallocate() on tmpfs returns ENOMEM if the size is larger
// than available RAM.
//
// Using 'dynamic_shared_memory_type = mmap' works around that, because
// the Postgres 'mmap' DSM implementation doesn't use
// posix_fallocate(). Instead, it uses repeated calls to write(2) to
// fill the file with zeros. It's weird that that differs between
// 'posix' and 'mmap', but we take advantage of it. When the file is
// filled slowly with write(2), the kernel allows it to grow larger, as
// long as there's swap available.
//
// In short, using 'dynamic_shared_memory_type = mmap' allows us one DSM
// segment to be larger than currently available RAM. But because we
// don't want to store it on a real file, which the kernel would try to
// flush to disk, so symlink pg_dynshm to /dev/shm.
//
// We don't set 'dynamic_shared_memory_type = mmap' here, we let the
// control plane control that option. If 'mmap' is not used, this
// symlink doesn't affect anything.
//
// See https://github.com/neondatabase/autoscaling/issues/800
std::fs::remove_dir(pgdata_path.join("pg_dynshmem"))?;
symlink("/dev/shm/", pgdata_path.join("pg_dynshmem"))?;
match spec.mode {
ComputeMode::Primary => {}
ComputeMode::Replica | ComputeMode::Static(..) => {

View File

@@ -280,12 +280,6 @@ async fn handle_node_list(req: Request<Body>) -> Result<Response<Body>, ApiError
json_response(StatusCode::OK, state.service.node_list().await?)
}
async fn handle_node_drop(req: Request<Body>) -> Result<Response<Body>, ApiError> {
let state = get_state(&req);
let node_id: NodeId = parse_request_param(&req, "node_id")?;
json_response(StatusCode::OK, state.service.node_drop(node_id).await?)
}
async fn handle_node_configure(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let node_id: NodeId = parse_request_param(&req, "node_id")?;
let config_req = json_request::<NodeConfigureRequest>(&mut req).await?;
@@ -326,13 +320,6 @@ async fn handle_tenant_shard_migrate(
)
}
async fn handle_tenant_drop(req: Request<Body>) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
let state = get_state(&req);
json_response(StatusCode::OK, state.service.tenant_drop(tenant_id).await?)
}
/// Status endpoint is just used for checking that our HTTP listener is up
async fn handle_status(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, ())
@@ -415,12 +402,6 @@ pub fn make_router(
request_span(r, handle_attach_hook)
})
.post("/debug/v1/inspect", |r| request_span(r, handle_inspect))
.post("/debug/v1/tenant/:tenant_id/drop", |r| {
request_span(r, handle_tenant_drop)
})
.post("/debug/v1/node/:node_id/drop", |r| {
request_span(r, handle_node_drop)
})
.get("/control/v1/tenant/:tenant_id/locate", |r| {
tenant_service_handler(r, handle_tenant_locate)
})

View File

@@ -260,6 +260,7 @@ impl Persistence {
/// Ordering: call this _after_ deleting the tenant on pageservers, but _before_ dropping state for
/// the tenant from memory on this server.
#[allow(unused)]
pub(crate) async fn delete_tenant(&self, del_tenant_id: TenantId) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
@@ -272,18 +273,6 @@ impl Persistence {
.await
}
pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
diesel::delete(nodes)
.filter(node_id.eq(del_node_id.0 as i64))
.execute(conn)?;
Ok(())
})
.await
}
/// When a tenant invokes the /re-attach API, this function is responsible for doing an efficient
/// batched increment of the generations of all tenants whose generation_pageserver is equal to
/// the node that called /re-attach.

View File

@@ -1804,45 +1804,6 @@ impl Service {
Ok(TenantShardMigrateResponse {})
}
/// This is for debug/support only: we simply drop all state for a tenant, without
/// detaching or deleting it on pageservers.
pub(crate) async fn tenant_drop(&self, tenant_id: TenantId) -> Result<(), ApiError> {
self.persistence.delete_tenant(tenant_id).await?;
let mut locked = self.inner.write().unwrap();
let mut shards = Vec::new();
for (tenant_shard_id, _) in locked.tenants.range(TenantShardId::tenant_range(tenant_id)) {
shards.push(*tenant_shard_id);
}
for shard in shards {
locked.tenants.remove(&shard);
}
Ok(())
}
/// This is for debug/support only: we simply drop all state for a tenant, without
/// detaching or deleting it on pageservers. We do not try and re-schedule any
/// tenants that were on this node.
///
/// TODO: proper node deletion API that unhooks things more gracefully
pub(crate) async fn node_drop(&self, node_id: NodeId) -> Result<(), ApiError> {
self.persistence.delete_node(node_id).await?;
let mut locked = self.inner.write().unwrap();
for shard in locked.tenants.values_mut() {
shard.deref_node(node_id);
}
let mut nodes = (*locked.nodes).clone();
nodes.remove(&node_id);
locked.nodes = Arc::new(nodes);
Ok(())
}
pub(crate) async fn node_list(&self) -> Result<Vec<NodePersistence>, ApiError> {
// It is convenient to avoid taking the big lock and converting Node to a serializable
// structure, by fetching from storage instead of reading in-memory state.

View File

@@ -534,18 +534,4 @@ impl TenantState {
seq: self.sequence,
})
}
// If we had any state at all referring to this node ID, drop it. Does not
// attempt to reschedule.
pub(crate) fn deref_node(&mut self, node_id: NodeId) {
if self.intent.attached == Some(node_id) {
self.intent.attached = None;
}
self.intent.secondary.retain(|n| n != &node_id);
self.observed.locations.remove(&node_id);
debug_assert!(!self.intent.all_pageservers().contains(&node_id));
}
}

View File

@@ -1,18 +0,0 @@
[package]
name = "desim"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
rand.workspace = true
tracing.workspace = true
bytes.workspace = true
utils.workspace = true
parking_lot.workspace = true
hex.workspace = true
scopeguard.workspace = true
smallvec = { workspace = true, features = ["write"] }
workspace_hack.workspace = true

View File

@@ -1,7 +0,0 @@
# Discrete Event SIMulator
This is a library for running simulations of distributed systems. The main idea is borrowed from [FoundationDB](https://www.youtube.com/watch?v=4fFDFbi3toc).
Each node runs as a separate thread. This library was not optimized for speed yet, but it's already much faster than running usual intergration tests in real time, because it uses virtual simulation time and can fast-forward time to skip intervals where all nodes are doing nothing but sleeping or waiting for something.
The original purpose for this library is to test walproposer and safekeeper implementation working together, in a scenarios close to the real world environment. This simulator is determenistic and can inject failures in networking without waiting minutes of wall-time to trigger timeout, which makes it easier to find bugs in our consensus implementation compared to using integration tests.

View File

@@ -1,108 +0,0 @@
use std::{collections::VecDeque, sync::Arc};
use parking_lot::{Mutex, MutexGuard};
use crate::executor::{self, PollSome, Waker};
/// FIFO channel with blocking send and receive. Can be cloned and shared between threads.
/// Blocking functions should be used only from threads that are managed by the executor.
pub struct Chan<T> {
shared: Arc<State<T>>,
}
impl<T> Clone for Chan<T> {
fn clone(&self) -> Self {
Chan {
shared: self.shared.clone(),
}
}
}
impl<T> Default for Chan<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Chan<T> {
pub fn new() -> Chan<T> {
Chan {
shared: Arc::new(State {
queue: Mutex::new(VecDeque::new()),
waker: Waker::new(),
}),
}
}
/// Get a message from the front of the queue, block if the queue is empty.
/// If not called from the executor thread, it can block forever.
pub fn recv(&self) -> T {
self.shared.recv()
}
/// Panic if the queue is empty.
pub fn must_recv(&self) -> T {
self.shared
.try_recv()
.expect("message should've been ready")
}
/// Get a message from the front of the queue, return None if the queue is empty.
/// Never blocks.
pub fn try_recv(&self) -> Option<T> {
self.shared.try_recv()
}
/// Send a message to the back of the queue.
pub fn send(&self, t: T) {
self.shared.send(t);
}
}
struct State<T> {
queue: Mutex<VecDeque<T>>,
waker: Waker,
}
impl<T> State<T> {
fn send(&self, t: T) {
self.queue.lock().push_back(t);
self.waker.wake_all();
}
fn try_recv(&self) -> Option<T> {
let mut q = self.queue.lock();
q.pop_front()
}
fn recv(&self) -> T {
// interrupt the receiver to prevent consuming everything at once
executor::yield_me(0);
let mut queue = self.queue.lock();
if let Some(t) = queue.pop_front() {
return t;
}
loop {
self.waker.wake_me_later();
if let Some(t) = queue.pop_front() {
return t;
}
MutexGuard::unlocked(&mut queue, || {
executor::yield_me(-1);
});
}
}
}
impl<T> PollSome for Chan<T> {
/// Schedules a wakeup for the current thread.
fn wake_me(&self) {
self.shared.waker.wake_me_later();
}
/// Checks if chan has any pending messages.
fn has_some(&self) -> bool {
!self.shared.queue.lock().is_empty()
}
}

View File

@@ -1,483 +0,0 @@
use std::{
panic::AssertUnwindSafe,
sync::{
atomic::{AtomicBool, AtomicU32, AtomicU8, Ordering},
mpsc, Arc, OnceLock,
},
thread::JoinHandle,
};
use tracing::{debug, error, trace};
use crate::time::Timing;
/// Stores status of the running threads. Threads are registered in the runtime upon creation
/// and deregistered upon termination.
pub struct Runtime {
// stores handles to all threads that are currently running
threads: Vec<ThreadHandle>,
// stores current time and pending wakeups
clock: Arc<Timing>,
// thread counter
thread_counter: AtomicU32,
// Thread step counter -- how many times all threads has been actually
// stepped (note that all world/time/executor/thread have slightly different
// meaning of steps). For observability.
pub step_counter: u64,
}
impl Runtime {
/// Init new runtime, no running threads.
pub fn new(clock: Arc<Timing>) -> Self {
Self {
threads: Vec::new(),
clock,
thread_counter: AtomicU32::new(0),
step_counter: 0,
}
}
/// Spawn a new thread and register it in the runtime.
pub fn spawn<F>(&mut self, f: F) -> ExternalHandle
where
F: FnOnce() + Send + 'static,
{
let (tx, rx) = mpsc::channel();
let clock = self.clock.clone();
let tid = self.thread_counter.fetch_add(1, Ordering::SeqCst);
debug!("spawning thread-{}", tid);
let join = std::thread::spawn(move || {
let _guard = tracing::info_span!("", tid).entered();
let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
with_thread_context(|ctx| {
assert!(ctx.clock.set(clock).is_ok());
ctx.id.store(tid, Ordering::SeqCst);
tx.send(ctx.clone()).expect("failed to send thread context");
// suspend thread to put it to `threads` in sleeping state
ctx.yield_me(0);
});
// start user-provided function
f();
}));
debug!("thread finished");
if let Err(e) = res {
with_thread_context(|ctx| {
if !ctx.allow_panic.load(std::sync::atomic::Ordering::SeqCst) {
error!("thread panicked, terminating the process: {:?}", e);
std::process::exit(1);
}
debug!("thread panicked: {:?}", e);
let mut result = ctx.result.lock();
if result.0 == -1 {
*result = (256, format!("thread panicked: {:?}", e));
}
});
}
with_thread_context(|ctx| {
ctx.finish_me();
});
});
let ctx = rx.recv().expect("failed to receive thread context");
let handle = ThreadHandle::new(ctx.clone(), join);
self.threads.push(handle);
ExternalHandle { ctx }
}
/// Returns true if there are any unfinished activity, such as running thread or pending events.
/// Otherwise returns false, which means all threads are blocked forever.
pub fn step(&mut self) -> bool {
trace!("runtime step");
// have we run any thread?
let mut ran = false;
self.threads.retain(|thread: &ThreadHandle| {
let res = thread.ctx.wakeup.compare_exchange(
PENDING_WAKEUP,
NO_WAKEUP,
Ordering::SeqCst,
Ordering::SeqCst,
);
if res.is_err() {
// thread has no pending wakeups, leaving as is
return true;
}
ran = true;
trace!("entering thread-{}", thread.ctx.tid());
let status = thread.step();
self.step_counter += 1;
trace!(
"out of thread-{} with status {:?}",
thread.ctx.tid(),
status
);
if status == Status::Sleep {
true
} else {
trace!("thread has finished");
// removing the thread from the list
false
}
});
if !ran {
trace!("no threads were run, stepping clock");
if let Some(ctx_to_wake) = self.clock.step() {
trace!("waking up thread-{}", ctx_to_wake.tid());
ctx_to_wake.inc_wake();
} else {
return false;
}
}
true
}
/// Kill all threads. This is done by setting a flag in each thread context and waking it up.
pub fn crash_all_threads(&mut self) {
for thread in self.threads.iter() {
thread.ctx.crash_stop();
}
// all threads should be finished after a few steps
while !self.threads.is_empty() {
self.step();
}
}
}
impl Drop for Runtime {
fn drop(&mut self) {
debug!("dropping the runtime");
self.crash_all_threads();
}
}
#[derive(Clone)]
pub struct ExternalHandle {
ctx: Arc<ThreadContext>,
}
impl ExternalHandle {
/// Returns true if thread has finished execution.
pub fn is_finished(&self) -> bool {
let status = self.ctx.mutex.lock();
*status == Status::Finished
}
/// Returns exitcode and message, which is available after thread has finished execution.
pub fn result(&self) -> (i32, String) {
let result = self.ctx.result.lock();
result.clone()
}
/// Returns thread id.
pub fn id(&self) -> u32 {
self.ctx.id.load(Ordering::SeqCst)
}
/// Sets a flag to crash thread on the next wakeup.
pub fn crash_stop(&self) {
self.ctx.crash_stop();
}
}
struct ThreadHandle {
ctx: Arc<ThreadContext>,
_join: JoinHandle<()>,
}
impl ThreadHandle {
/// Create a new [`ThreadHandle`] and wait until thread will enter [`Status::Sleep`] state.
fn new(ctx: Arc<ThreadContext>, join: JoinHandle<()>) -> Self {
let mut status = ctx.mutex.lock();
// wait until thread will go into the first yield
while *status != Status::Sleep {
ctx.condvar.wait(&mut status);
}
drop(status);
Self { ctx, _join: join }
}
/// Allows thread to execute one step of its execution.
/// Returns [`Status`] of the thread after the step.
fn step(&self) -> Status {
let mut status = self.ctx.mutex.lock();
assert!(matches!(*status, Status::Sleep));
*status = Status::Running;
self.ctx.condvar.notify_all();
while *status == Status::Running {
self.ctx.condvar.wait(&mut status);
}
*status
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Status {
/// Thread is running.
Running,
/// Waiting for event to complete, will be resumed by the executor step, once wakeup flag is set.
Sleep,
/// Thread finished execution.
Finished,
}
const NO_WAKEUP: u8 = 0;
const PENDING_WAKEUP: u8 = 1;
pub struct ThreadContext {
id: AtomicU32,
// used to block thread until it is woken up
mutex: parking_lot::Mutex<Status>,
condvar: parking_lot::Condvar,
// used as a flag to indicate runtime that thread is ready to be woken up
wakeup: AtomicU8,
clock: OnceLock<Arc<Timing>>,
// execution result, set by exit() call
result: parking_lot::Mutex<(i32, String)>,
// determines if process should be killed on receiving panic
allow_panic: AtomicBool,
// acts as a signal that thread should crash itself on the next wakeup
crash_request: AtomicBool,
}
impl ThreadContext {
pub(crate) fn new() -> Self {
Self {
id: AtomicU32::new(0),
mutex: parking_lot::Mutex::new(Status::Running),
condvar: parking_lot::Condvar::new(),
wakeup: AtomicU8::new(NO_WAKEUP),
clock: OnceLock::new(),
result: parking_lot::Mutex::new((-1, String::new())),
allow_panic: AtomicBool::new(false),
crash_request: AtomicBool::new(false),
}
}
}
// Functions for executor to control thread execution.
impl ThreadContext {
/// Set atomic flag to indicate that thread is ready to be woken up.
fn inc_wake(&self) {
self.wakeup.store(PENDING_WAKEUP, Ordering::SeqCst);
}
/// Internal function used for event queues.
pub(crate) fn schedule_wakeup(self: &Arc<Self>, after_ms: u64) {
self.clock
.get()
.unwrap()
.schedule_wakeup(after_ms, self.clone());
}
fn tid(&self) -> u32 {
self.id.load(Ordering::SeqCst)
}
fn crash_stop(&self) {
let status = self.mutex.lock();
if *status == Status::Finished {
debug!(
"trying to crash thread-{}, which is already finished",
self.tid()
);
return;
}
assert!(matches!(*status, Status::Sleep));
drop(status);
self.allow_panic.store(true, Ordering::SeqCst);
self.crash_request.store(true, Ordering::SeqCst);
// set a wakeup
self.inc_wake();
// it will panic on the next wakeup
}
}
// Internal functions.
impl ThreadContext {
/// Blocks thread until it's woken up by the executor. If `after_ms` is 0, is will be
/// woken on the next step. If `after_ms` > 0, wakeup is scheduled after that time.
/// Otherwise wakeup is not scheduled inside `yield_me`, and should be arranged before
/// calling this function.
fn yield_me(self: &Arc<Self>, after_ms: i64) {
let mut status = self.mutex.lock();
assert!(matches!(*status, Status::Running));
match after_ms.cmp(&0) {
std::cmp::Ordering::Less => {
// block until something wakes us up
}
std::cmp::Ordering::Equal => {
// tell executor that we are ready to be woken up
self.inc_wake();
}
std::cmp::Ordering::Greater => {
// schedule wakeup
self.clock
.get()
.unwrap()
.schedule_wakeup(after_ms as u64, self.clone());
}
}
*status = Status::Sleep;
self.condvar.notify_all();
// wait until executor wakes us up
while *status != Status::Running {
self.condvar.wait(&mut status);
}
if self.crash_request.load(Ordering::SeqCst) {
panic!("crashed by request");
}
}
/// Called only once, exactly before thread finishes execution.
fn finish_me(&self) {
let mut status = self.mutex.lock();
assert!(matches!(*status, Status::Running));
*status = Status::Finished;
{
let mut result = self.result.lock();
if result.0 == -1 {
*result = (0, "finished normally".to_owned());
}
}
self.condvar.notify_all();
}
}
/// Invokes the given closure with a reference to the current thread [`ThreadContext`].
#[inline(always)]
fn with_thread_context<T>(f: impl FnOnce(&Arc<ThreadContext>) -> T) -> T {
thread_local!(static THREAD_DATA: Arc<ThreadContext> = Arc::new(ThreadContext::new()));
THREAD_DATA.with(f)
}
/// Waker is used to wake up threads that are blocked on condition.
/// It keeps track of contexts [`Arc<ThreadContext>`] and can increment the counter
/// of several contexts to send a notification.
pub struct Waker {
// contexts that are waiting for a notification
contexts: parking_lot::Mutex<smallvec::SmallVec<[Arc<ThreadContext>; 8]>>,
}
impl Default for Waker {
fn default() -> Self {
Self::new()
}
}
impl Waker {
pub fn new() -> Self {
Self {
contexts: parking_lot::Mutex::new(smallvec::SmallVec::new()),
}
}
/// Subscribe current thread to receive a wake notification later.
pub fn wake_me_later(&self) {
with_thread_context(|ctx| {
self.contexts.lock().push(ctx.clone());
});
}
/// Wake up all threads that are waiting for a notification and clear the list.
pub fn wake_all(&self) {
let mut v = self.contexts.lock();
for ctx in v.iter() {
ctx.inc_wake();
}
v.clear();
}
}
/// See [`ThreadContext::yield_me`].
pub fn yield_me(after_ms: i64) {
with_thread_context(|ctx| ctx.yield_me(after_ms))
}
/// Get current time.
pub fn now() -> u64 {
with_thread_context(|ctx| ctx.clock.get().unwrap().now())
}
pub fn exit(code: i32, msg: String) {
with_thread_context(|ctx| {
ctx.allow_panic.store(true, Ordering::SeqCst);
let mut result = ctx.result.lock();
*result = (code, msg);
panic!("exit");
});
}
pub(crate) fn get_thread_ctx() -> Arc<ThreadContext> {
with_thread_context(|ctx| ctx.clone())
}
/// Trait for polling channels until they have something.
pub trait PollSome {
/// Schedule wakeup for message arrival.
fn wake_me(&self);
/// Check if channel has a ready message.
fn has_some(&self) -> bool;
}
/// Blocks current thread until one of the channels has a ready message. Returns
/// index of the channel that has a message. If timeout is reached, returns None.
///
/// Negative timeout means block forever. Zero timeout means check channels and return
/// immediately. Positive timeout means block until timeout is reached.
pub fn epoll_chans(chans: &[Box<dyn PollSome>], timeout: i64) -> Option<usize> {
let deadline = if timeout < 0 {
0
} else {
now() + timeout as u64
};
loop {
for chan in chans {
chan.wake_me()
}
for (i, chan) in chans.iter().enumerate() {
if chan.has_some() {
return Some(i);
}
}
if timeout < 0 {
// block until wakeup
yield_me(-1);
} else {
let current_time = now();
if current_time >= deadline {
return None;
}
yield_me((deadline - current_time) as i64);
}
}
}

View File

@@ -1,8 +0,0 @@
pub mod chan;
pub mod executor;
pub mod network;
pub mod node_os;
pub mod options;
pub mod proto;
pub mod time;
pub mod world;

View File

@@ -1,451 +0,0 @@
use std::{
cmp::Ordering,
collections::{BinaryHeap, VecDeque},
fmt::{self, Debug},
ops::DerefMut,
sync::{mpsc, Arc},
};
use parking_lot::{
lock_api::{MappedMutexGuard, MutexGuard},
Mutex, RawMutex,
};
use rand::rngs::StdRng;
use tracing::debug;
use crate::{
executor::{self, ThreadContext},
options::NetworkOptions,
proto::NetEvent,
proto::NodeEvent,
};
use super::{chan::Chan, proto::AnyMessage};
pub struct NetworkTask {
options: Arc<NetworkOptions>,
connections: Mutex<Vec<VirtualConnection>>,
/// min-heap of connections having something to deliver.
events: Mutex<BinaryHeap<Event>>,
task_context: Arc<ThreadContext>,
}
impl NetworkTask {
pub fn start_new(options: Arc<NetworkOptions>, tx: mpsc::Sender<Arc<NetworkTask>>) {
let ctx = executor::get_thread_ctx();
let task = Arc::new(Self {
options,
connections: Mutex::new(Vec::new()),
events: Mutex::new(BinaryHeap::new()),
task_context: ctx,
});
// send the task upstream
tx.send(task.clone()).unwrap();
// start the task
task.start();
}
pub fn start_new_connection(self: &Arc<Self>, rng: StdRng, dst_accept: Chan<NodeEvent>) -> TCP {
let now = executor::now();
let connection_id = self.connections.lock().len();
let vc = VirtualConnection {
connection_id,
dst_accept,
dst_sockets: [Chan::new(), Chan::new()],
state: Mutex::new(ConnectionState {
buffers: [NetworkBuffer::new(None), NetworkBuffer::new(Some(now))],
rng,
}),
};
vc.schedule_timeout(self);
vc.send_connect(self);
let recv_chan = vc.dst_sockets[0].clone();
self.connections.lock().push(vc);
TCP {
net: self.clone(),
conn_id: connection_id,
dir: 0,
recv_chan,
}
}
}
// private functions
impl NetworkTask {
/// Schedule to wakeup network task (self) `after_ms` later to deliver
/// messages of connection `id`.
fn schedule(&self, id: usize, after_ms: u64) {
self.events.lock().push(Event {
time: executor::now() + after_ms,
conn_id: id,
});
self.task_context.schedule_wakeup(after_ms);
}
/// Get locked connection `id`.
fn get(&self, id: usize) -> MappedMutexGuard<'_, RawMutex, VirtualConnection> {
MutexGuard::map(self.connections.lock(), |connections| {
connections.get_mut(id).unwrap()
})
}
fn collect_pending_events(&self, now: u64, vec: &mut Vec<Event>) {
vec.clear();
let mut events = self.events.lock();
while let Some(event) = events.peek() {
if event.time > now {
break;
}
let event = events.pop().unwrap();
vec.push(event);
}
}
fn start(self: &Arc<Self>) {
debug!("started network task");
let mut events = Vec::new();
loop {
let now = executor::now();
self.collect_pending_events(now, &mut events);
for event in events.drain(..) {
let conn = self.get(event.conn_id);
conn.process(self);
}
// block until wakeup
executor::yield_me(-1);
}
}
}
// 0 - from node(0) to node(1)
// 1 - from node(1) to node(0)
type MessageDirection = u8;
fn sender_str(dir: MessageDirection) -> &'static str {
match dir {
0 => "client",
1 => "server",
_ => unreachable!(),
}
}
fn receiver_str(dir: MessageDirection) -> &'static str {
match dir {
0 => "server",
1 => "client",
_ => unreachable!(),
}
}
/// Virtual connection between two nodes.
/// Node 0 is the creator of the connection (client),
/// and node 1 is the acceptor (server).
struct VirtualConnection {
connection_id: usize,
/// one-off chan, used to deliver Accept message to dst
dst_accept: Chan<NodeEvent>,
/// message sinks
dst_sockets: [Chan<NetEvent>; 2],
state: Mutex<ConnectionState>,
}
struct ConnectionState {
buffers: [NetworkBuffer; 2],
rng: StdRng,
}
impl VirtualConnection {
/// Notify the future about the possible timeout.
fn schedule_timeout(&self, net: &NetworkTask) {
if let Some(timeout) = net.options.keepalive_timeout {
net.schedule(self.connection_id, timeout);
}
}
/// Send the handshake (Accept) to the server.
fn send_connect(&self, net: &NetworkTask) {
let now = executor::now();
let mut state = self.state.lock();
let delay = net.options.connect_delay.delay(&mut state.rng);
let buffer = &mut state.buffers[0];
assert!(buffer.buf.is_empty());
assert!(!buffer.recv_closed);
assert!(!buffer.send_closed);
assert!(buffer.last_recv.is_none());
let delay = if let Some(ms) = delay {
ms
} else {
debug!("NET: TCP #{} dropped connect", self.connection_id);
buffer.send_closed = true;
return;
};
// Send a message into the future.
buffer
.buf
.push_back((now + delay, AnyMessage::InternalConnect));
net.schedule(self.connection_id, delay);
}
/// Transmit some of the messages from the buffer to the nodes.
fn process(&self, net: &Arc<NetworkTask>) {
let now = executor::now();
let mut state = self.state.lock();
for direction in 0..2 {
self.process_direction(
net,
state.deref_mut(),
now,
direction as MessageDirection,
&self.dst_sockets[direction ^ 1],
);
}
// Close the one side of the connection by timeout if the node
// has not received any messages for a long time.
if let Some(timeout) = net.options.keepalive_timeout {
let mut to_close = [false, false];
for direction in 0..2 {
let buffer = &mut state.buffers[direction];
if buffer.recv_closed {
continue;
}
if let Some(last_recv) = buffer.last_recv {
if now - last_recv >= timeout {
debug!(
"NET: connection {} timed out at {}",
self.connection_id,
receiver_str(direction as MessageDirection)
);
let node_idx = direction ^ 1;
to_close[node_idx] = true;
}
}
}
drop(state);
for (node_idx, should_close) in to_close.iter().enumerate() {
if *should_close {
self.close(node_idx);
}
}
}
}
/// Process messages in the buffer in the given direction.
fn process_direction(
&self,
net: &Arc<NetworkTask>,
state: &mut ConnectionState,
now: u64,
direction: MessageDirection,
to_socket: &Chan<NetEvent>,
) {
let buffer = &mut state.buffers[direction as usize];
if buffer.recv_closed {
assert!(buffer.buf.is_empty());
}
while !buffer.buf.is_empty() && buffer.buf.front().unwrap().0 <= now {
let msg = buffer.buf.pop_front().unwrap().1;
buffer.last_recv = Some(now);
self.schedule_timeout(net);
if let AnyMessage::InternalConnect = msg {
// TODO: assert to_socket is the server
let server_to_client = TCP {
net: net.clone(),
conn_id: self.connection_id,
dir: direction ^ 1,
recv_chan: to_socket.clone(),
};
// special case, we need to deliver new connection to a separate channel
self.dst_accept.send(NodeEvent::Accept(server_to_client));
} else {
to_socket.send(NetEvent::Message(msg));
}
}
}
/// Try to send a message to the buffer, optionally dropping it and
/// determining delivery timestamp.
fn send(&self, net: &NetworkTask, direction: MessageDirection, msg: AnyMessage) {
let now = executor::now();
let mut state = self.state.lock();
let (delay, close) = if let Some(ms) = net.options.send_delay.delay(&mut state.rng) {
(ms, false)
} else {
(0, true)
};
let buffer = &mut state.buffers[direction as usize];
if buffer.send_closed {
debug!(
"NET: TCP #{} dropped message {:?} (broken pipe)",
self.connection_id, msg
);
return;
}
if close {
debug!(
"NET: TCP #{} dropped message {:?} (pipe just broke)",
self.connection_id, msg
);
buffer.send_closed = true;
return;
}
if buffer.recv_closed {
debug!(
"NET: TCP #{} dropped message {:?} (recv closed)",
self.connection_id, msg
);
return;
}
// Send a message into the future.
buffer.buf.push_back((now + delay, msg));
net.schedule(self.connection_id, delay);
}
/// Close the connection. Only one side of the connection will be closed,
/// and no further messages will be delivered. The other side will not be notified.
fn close(&self, node_idx: usize) {
let mut state = self.state.lock();
let recv_buffer = &mut state.buffers[1 ^ node_idx];
if recv_buffer.recv_closed {
debug!(
"NET: TCP #{} closed twice at {}",
self.connection_id,
sender_str(node_idx as MessageDirection),
);
return;
}
debug!(
"NET: TCP #{} closed at {}",
self.connection_id,
sender_str(node_idx as MessageDirection),
);
recv_buffer.recv_closed = true;
for msg in recv_buffer.buf.drain(..) {
debug!(
"NET: TCP #{} dropped message {:?} (closed)",
self.connection_id, msg
);
}
let send_buffer = &mut state.buffers[node_idx];
send_buffer.send_closed = true;
drop(state);
// TODO: notify the other side?
self.dst_sockets[node_idx].send(NetEvent::Closed);
}
}
struct NetworkBuffer {
/// Messages paired with time of delivery
buf: VecDeque<(u64, AnyMessage)>,
/// True if the connection is closed on the receiving side,
/// i.e. no more messages from the buffer will be delivered.
recv_closed: bool,
/// True if the connection is closed on the sending side,
/// i.e. no more messages will be added to the buffer.
send_closed: bool,
/// Last time a message was delivered from the buffer.
/// If None, it means that the server is the receiver and
/// it has not yet aware of this connection (i.e. has not
/// received the Accept).
last_recv: Option<u64>,
}
impl NetworkBuffer {
fn new(last_recv: Option<u64>) -> Self {
Self {
buf: VecDeque::new(),
recv_closed: false,
send_closed: false,
last_recv,
}
}
}
/// Single end of a bidirectional network stream without reordering (TCP-like).
/// Reads are implemented using channels, writes go to the buffer inside VirtualConnection.
pub struct TCP {
net: Arc<NetworkTask>,
conn_id: usize,
dir: MessageDirection,
recv_chan: Chan<NetEvent>,
}
impl Debug for TCP {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "TCP #{} ({})", self.conn_id, sender_str(self.dir),)
}
}
impl TCP {
/// Send a message to the other side. It's guaranteed that it will not arrive
/// before the arrival of all messages sent earlier.
pub fn send(&self, msg: AnyMessage) {
let conn = self.net.get(self.conn_id);
conn.send(&self.net, self.dir, msg);
}
/// Get a channel to receive incoming messages.
pub fn recv_chan(&self) -> Chan<NetEvent> {
self.recv_chan.clone()
}
pub fn connection_id(&self) -> usize {
self.conn_id
}
pub fn close(&self) {
let conn = self.net.get(self.conn_id);
conn.close(self.dir as usize);
}
}
struct Event {
time: u64,
conn_id: usize,
}
// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
// to get that.
impl PartialOrd for Event {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Event {
fn cmp(&self, other: &Self) -> Ordering {
(other.time, other.conn_id).cmp(&(self.time, self.conn_id))
}
}
impl PartialEq for Event {
fn eq(&self, other: &Self) -> bool {
(other.time, other.conn_id) == (self.time, self.conn_id)
}
}
impl Eq for Event {}

View File

@@ -1,54 +0,0 @@
use std::sync::Arc;
use rand::Rng;
use crate::proto::NodeEvent;
use super::{
chan::Chan,
network::TCP,
world::{Node, NodeId, World},
};
/// Abstraction with all functions (aka syscalls) available to the node.
#[derive(Clone)]
pub struct NodeOs {
world: Arc<World>,
internal: Arc<Node>,
}
impl NodeOs {
pub fn new(world: Arc<World>, internal: Arc<Node>) -> NodeOs {
NodeOs { world, internal }
}
/// Get the node id.
pub fn id(&self) -> NodeId {
self.internal.id
}
/// Opens a bidirectional connection with the other node. Always successful.
pub fn open_tcp(&self, dst: NodeId) -> TCP {
self.world.open_tcp(dst)
}
/// Returns a channel to receive node events (socket Accept and internal messages).
pub fn node_events(&self) -> Chan<NodeEvent> {
self.internal.node_events()
}
/// Get current time.
pub fn now(&self) -> u64 {
self.world.now()
}
/// Generate a random number in range [0, max).
pub fn random(&self, max: u64) -> u64 {
self.internal.rng.lock().gen_range(0..max)
}
/// Append a new event to the world event log.
pub fn log_event(&self, data: String) {
self.internal.log_event(data)
}
}

View File

@@ -1,50 +0,0 @@
use rand::{rngs::StdRng, Rng};
/// Describes random delays and failures. Delay will be uniformly distributed in [min, max].
/// Connection failure will occur with the probablity fail_prob.
#[derive(Clone, Debug)]
pub struct Delay {
pub min: u64,
pub max: u64,
pub fail_prob: f64, // [0; 1]
}
impl Delay {
/// Create a struct with no delay, no failures.
pub fn empty() -> Delay {
Delay {
min: 0,
max: 0,
fail_prob: 0.0,
}
}
/// Create a struct with a fixed delay.
pub fn fixed(ms: u64) -> Delay {
Delay {
min: ms,
max: ms,
fail_prob: 0.0,
}
}
/// Generate a random delay in range [min, max]. Return None if the
/// message should be dropped.
pub fn delay(&self, rng: &mut StdRng) -> Option<u64> {
if rng.gen_bool(self.fail_prob) {
return None;
}
Some(rng.gen_range(self.min..=self.max))
}
}
/// Describes network settings. All network packets will be subjected to the same delays and failures.
#[derive(Clone, Debug)]
pub struct NetworkOptions {
/// Connection will be automatically closed after this timeout if no data is received.
pub keepalive_timeout: Option<u64>,
/// New connections will be delayed by this amount of time.
pub connect_delay: Delay,
/// Each message will be delayed by this amount of time.
pub send_delay: Delay,
}

View File

@@ -1,63 +0,0 @@
use std::fmt::Debug;
use bytes::Bytes;
use utils::lsn::Lsn;
use crate::{network::TCP, world::NodeId};
/// Internal node events.
#[derive(Debug)]
pub enum NodeEvent {
Accept(TCP),
Internal(AnyMessage),
}
/// Events that are coming from a network socket.
#[derive(Clone, Debug)]
pub enum NetEvent {
Message(AnyMessage),
Closed,
}
/// Custom events generated throughout the simulation. Can be used by the test to verify the correctness.
#[derive(Debug)]
pub struct SimEvent {
pub time: u64,
pub node: NodeId,
pub data: String,
}
/// Umbrella type for all possible flavours of messages. These events can be sent over network
/// or to an internal node events channel.
#[derive(Clone)]
pub enum AnyMessage {
/// Not used, empty placeholder.
None,
/// Used internally for notifying node about new incoming connection.
InternalConnect,
Just32(u32),
ReplCell(ReplCell),
Bytes(Bytes),
LSN(u64),
}
impl Debug for AnyMessage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AnyMessage::None => write!(f, "None"),
AnyMessage::InternalConnect => write!(f, "InternalConnect"),
AnyMessage::Just32(v) => write!(f, "Just32({})", v),
AnyMessage::ReplCell(v) => write!(f, "ReplCell({:?})", v),
AnyMessage::Bytes(v) => write!(f, "Bytes({})", hex::encode(v)),
AnyMessage::LSN(v) => write!(f, "LSN({})", Lsn(*v)),
}
}
}
/// Used in reliable_copy_test.rs
#[derive(Clone, Debug)]
pub struct ReplCell {
pub value: u32,
pub client_id: u32,
pub seqno: u32,
}

View File

@@ -1,129 +0,0 @@
use std::{
cmp::Ordering,
collections::BinaryHeap,
ops::DerefMut,
sync::{
atomic::{AtomicU32, AtomicU64},
Arc,
},
};
use parking_lot::Mutex;
use tracing::trace;
use crate::executor::ThreadContext;
/// Holds current time and all pending wakeup events.
pub struct Timing {
/// Current world's time.
current_time: AtomicU64,
/// Pending timers.
queue: Mutex<BinaryHeap<Pending>>,
/// Global nonce. Makes picking events from binary heap queue deterministic
/// by appending a number to events with the same timestamp.
nonce: AtomicU32,
/// Used to schedule fake events.
fake_context: Arc<ThreadContext>,
}
impl Default for Timing {
fn default() -> Self {
Self::new()
}
}
impl Timing {
/// Create a new empty clock with time set to 0.
pub fn new() -> Timing {
Timing {
current_time: AtomicU64::new(0),
queue: Mutex::new(BinaryHeap::new()),
nonce: AtomicU32::new(0),
fake_context: Arc::new(ThreadContext::new()),
}
}
/// Return the current world's time.
pub fn now(&self) -> u64 {
self.current_time.load(std::sync::atomic::Ordering::SeqCst)
}
/// Tick-tock the global clock. Return the event ready to be processed
/// or move the clock forward and then return the event.
pub(crate) fn step(&self) -> Option<Arc<ThreadContext>> {
let mut queue = self.queue.lock();
if queue.is_empty() {
// no future events
return None;
}
if !self.is_event_ready(queue.deref_mut()) {
let next_time = queue.peek().unwrap().time;
self.current_time
.store(next_time, std::sync::atomic::Ordering::SeqCst);
trace!("rewind time to {}", next_time);
assert!(self.is_event_ready(queue.deref_mut()));
}
Some(queue.pop().unwrap().wake_context)
}
/// Append an event to the queue, to wakeup the thread in `ms` milliseconds.
pub(crate) fn schedule_wakeup(&self, ms: u64, wake_context: Arc<ThreadContext>) {
self.nonce.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
let nonce = self.nonce.load(std::sync::atomic::Ordering::SeqCst);
self.queue.lock().push(Pending {
time: self.now() + ms,
nonce,
wake_context,
})
}
/// Append a fake event to the queue, to prevent clocks from skipping this time.
pub fn schedule_fake(&self, ms: u64) {
self.queue.lock().push(Pending {
time: self.now() + ms,
nonce: 0,
wake_context: self.fake_context.clone(),
});
}
/// Return true if there is a ready event.
fn is_event_ready(&self, queue: &mut BinaryHeap<Pending>) -> bool {
queue.peek().map_or(false, |x| x.time <= self.now())
}
/// Clear all pending events.
pub(crate) fn clear(&self) {
self.queue.lock().clear();
}
}
struct Pending {
time: u64,
nonce: u32,
wake_context: Arc<ThreadContext>,
}
// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
// to get that.
impl PartialOrd for Pending {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Pending {
fn cmp(&self, other: &Self) -> Ordering {
(other.time, other.nonce).cmp(&(self.time, self.nonce))
}
}
impl PartialEq for Pending {
fn eq(&self, other: &Self) -> bool {
(other.time, other.nonce) == (self.time, self.nonce)
}
}
impl Eq for Pending {}

View File

@@ -1,180 +0,0 @@
use parking_lot::Mutex;
use rand::{rngs::StdRng, SeedableRng};
use std::{
ops::DerefMut,
sync::{mpsc, Arc},
};
use crate::{
executor::{ExternalHandle, Runtime},
network::NetworkTask,
options::NetworkOptions,
proto::{NodeEvent, SimEvent},
time::Timing,
};
use super::{chan::Chan, network::TCP, node_os::NodeOs};
pub type NodeId = u32;
/// World contains simulation state.
pub struct World {
nodes: Mutex<Vec<Arc<Node>>>,
/// Random number generator.
rng: Mutex<StdRng>,
/// Internal event log.
events: Mutex<Vec<SimEvent>>,
/// Separate task that processes all network messages.
network_task: Arc<NetworkTask>,
/// Runtime for running threads and moving time.
runtime: Mutex<Runtime>,
/// To get current time.
timing: Arc<Timing>,
}
impl World {
pub fn new(seed: u64, options: Arc<NetworkOptions>) -> World {
let timing = Arc::new(Timing::new());
let mut runtime = Runtime::new(timing.clone());
let (tx, rx) = mpsc::channel();
runtime.spawn(move || {
// create and start network background thread, and send it back via the channel
NetworkTask::start_new(options, tx)
});
// wait for the network task to start
while runtime.step() {}
let network_task = rx.recv().unwrap();
World {
nodes: Mutex::new(Vec::new()),
rng: Mutex::new(StdRng::seed_from_u64(seed)),
events: Mutex::new(Vec::new()),
network_task,
runtime: Mutex::new(runtime),
timing,
}
}
pub fn step(&self) -> bool {
self.runtime.lock().step()
}
pub fn get_thread_step_count(&self) -> u64 {
self.runtime.lock().step_counter
}
/// Create a new random number generator.
pub fn new_rng(&self) -> StdRng {
let mut rng = self.rng.lock();
StdRng::from_rng(rng.deref_mut()).unwrap()
}
/// Create a new node.
pub fn new_node(self: &Arc<Self>) -> Arc<Node> {
let mut nodes = self.nodes.lock();
let id = nodes.len() as NodeId;
let node = Arc::new(Node::new(id, self.clone(), self.new_rng()));
nodes.push(node.clone());
node
}
/// Get an internal node state by id.
fn get_node(&self, id: NodeId) -> Option<Arc<Node>> {
let nodes = self.nodes.lock();
let num = id as usize;
if num < nodes.len() {
Some(nodes[num].clone())
} else {
None
}
}
pub fn stop_all(&self) {
self.runtime.lock().crash_all_threads();
}
/// Returns a writable end of a TCP connection, to send src->dst messages.
pub fn open_tcp(self: &Arc<World>, dst: NodeId) -> TCP {
// TODO: replace unwrap() with /dev/null socket.
let dst = self.get_node(dst).unwrap();
let dst_accept = dst.node_events.lock().clone();
let rng = self.new_rng();
self.network_task.start_new_connection(rng, dst_accept)
}
/// Get current time.
pub fn now(&self) -> u64 {
self.timing.now()
}
/// Get a copy of the internal clock.
pub fn clock(&self) -> Arc<Timing> {
self.timing.clone()
}
pub fn add_event(&self, node: NodeId, data: String) {
let time = self.now();
self.events.lock().push(SimEvent { time, node, data });
}
pub fn take_events(&self) -> Vec<SimEvent> {
let mut events = self.events.lock();
let mut res = Vec::new();
std::mem::swap(&mut res, &mut events);
res
}
pub fn deallocate(&self) {
self.stop_all();
self.timing.clear();
self.nodes.lock().clear();
}
}
/// Internal node state.
pub struct Node {
pub id: NodeId,
node_events: Mutex<Chan<NodeEvent>>,
world: Arc<World>,
pub(crate) rng: Mutex<StdRng>,
}
impl Node {
pub fn new(id: NodeId, world: Arc<World>, rng: StdRng) -> Node {
Node {
id,
node_events: Mutex::new(Chan::new()),
world,
rng: Mutex::new(rng),
}
}
/// Spawn a new thread with this node context.
pub fn launch(self: &Arc<Self>, f: impl FnOnce(NodeOs) + Send + 'static) -> ExternalHandle {
let node = self.clone();
let world = self.world.clone();
self.world.runtime.lock().spawn(move || {
f(NodeOs::new(world, node.clone()));
})
}
/// Returns a channel to receive Accepts and internal messages.
pub fn node_events(&self) -> Chan<NodeEvent> {
self.node_events.lock().clone()
}
/// This will drop all in-flight Accept messages.
pub fn replug_node_events(&self, chan: Chan<NodeEvent>) {
*self.node_events.lock() = chan;
}
/// Append event to the world's log.
pub fn log_event(&self, data: String) {
self.world.add_event(self.id, data)
}
}

View File

@@ -1,244 +0,0 @@
//! Simple test to verify that simulator is working.
#[cfg(test)]
mod reliable_copy_test {
use anyhow::Result;
use desim::executor::{self, PollSome};
use desim::options::{Delay, NetworkOptions};
use desim::proto::{NetEvent, NodeEvent, ReplCell};
use desim::world::{NodeId, World};
use desim::{node_os::NodeOs, proto::AnyMessage};
use parking_lot::Mutex;
use std::sync::Arc;
use tracing::info;
/// Disk storage trait and implementation.
pub trait Storage<T> {
fn flush_pos(&self) -> u32;
fn flush(&mut self) -> Result<()>;
fn write(&mut self, t: T);
}
#[derive(Clone)]
pub struct SharedStorage<T> {
pub state: Arc<Mutex<InMemoryStorage<T>>>,
}
impl<T> SharedStorage<T> {
pub fn new() -> Self {
Self {
state: Arc::new(Mutex::new(InMemoryStorage::new())),
}
}
}
impl<T> Storage<T> for SharedStorage<T> {
fn flush_pos(&self) -> u32 {
self.state.lock().flush_pos
}
fn flush(&mut self) -> Result<()> {
executor::yield_me(0);
self.state.lock().flush()
}
fn write(&mut self, t: T) {
executor::yield_me(0);
self.state.lock().write(t);
}
}
pub struct InMemoryStorage<T> {
pub data: Vec<T>,
pub flush_pos: u32,
}
impl<T> InMemoryStorage<T> {
pub fn new() -> Self {
Self {
data: Vec::new(),
flush_pos: 0,
}
}
pub fn flush(&mut self) -> Result<()> {
self.flush_pos = self.data.len() as u32;
Ok(())
}
pub fn write(&mut self, t: T) {
self.data.push(t);
}
}
/// Server implementation.
pub fn run_server(os: NodeOs, mut storage: Box<dyn Storage<u32>>) {
info!("started server");
let node_events = os.node_events();
let mut epoll_vec: Vec<Box<dyn PollSome>> = vec![Box::new(node_events.clone())];
let mut sockets = vec![];
loop {
let index = executor::epoll_chans(&epoll_vec, -1).unwrap();
if index == 0 {
let node_event = node_events.must_recv();
info!("got node event: {:?}", node_event);
if let NodeEvent::Accept(tcp) = node_event {
tcp.send(AnyMessage::Just32(storage.flush_pos()));
epoll_vec.push(Box::new(tcp.recv_chan()));
sockets.push(tcp);
}
continue;
}
let recv_chan = sockets[index - 1].recv_chan();
let socket = &sockets[index - 1];
let event = recv_chan.must_recv();
info!("got event: {:?}", event);
if let NetEvent::Message(AnyMessage::ReplCell(cell)) = event {
if cell.seqno != storage.flush_pos() {
info!("got out of order data: {:?}", cell);
continue;
}
storage.write(cell.value);
storage.flush().unwrap();
socket.send(AnyMessage::Just32(storage.flush_pos()));
}
}
}
/// Client copies all data from array to the remote node.
pub fn run_client(os: NodeOs, data: &[ReplCell], dst: NodeId) {
info!("started client");
let mut delivered = 0;
let mut sock = os.open_tcp(dst);
let mut recv_chan = sock.recv_chan();
while delivered < data.len() {
let num = &data[delivered];
info!("sending data: {:?}", num.clone());
sock.send(AnyMessage::ReplCell(num.clone()));
// loop {
let event = recv_chan.recv();
match event {
NetEvent::Message(AnyMessage::Just32(flush_pos)) => {
if flush_pos == 1 + delivered as u32 {
delivered += 1;
}
}
NetEvent::Closed => {
info!("connection closed, reestablishing");
sock = os.open_tcp(dst);
recv_chan = sock.recv_chan();
}
_ => {}
}
// }
}
let sock = os.open_tcp(dst);
for num in data {
info!("sending data: {:?}", num.clone());
sock.send(AnyMessage::ReplCell(num.clone()));
}
info!("sent all data and finished client");
}
/// Run test simulations.
#[test]
fn sim_example_reliable_copy() {
utils::logging::init(
utils::logging::LogFormat::Test,
utils::logging::TracingErrorLayerEnablement::Disabled,
utils::logging::Output::Stdout,
)
.expect("logging init failed");
let delay = Delay {
min: 1,
max: 60,
fail_prob: 0.4,
};
let network = NetworkOptions {
keepalive_timeout: Some(50),
connect_delay: delay.clone(),
send_delay: delay.clone(),
};
for seed in 0..20 {
let u32_data: [u32; 5] = [1, 2, 3, 4, 5];
let data = u32_to_cells(&u32_data, 1);
let world = Arc::new(World::new(seed, Arc::new(network.clone())));
start_simulation(Options {
world,
time_limit: 1_000_000,
client_fn: Box::new(move |os, server_id| run_client(os, &data, server_id)),
u32_data,
});
}
}
pub struct Options {
pub world: Arc<World>,
pub time_limit: u64,
pub u32_data: [u32; 5],
pub client_fn: Box<dyn FnOnce(NodeOs, u32) + Send + 'static>,
}
pub fn start_simulation(options: Options) {
let world = options.world;
let client_node = world.new_node();
let server_node = world.new_node();
let server_id = server_node.id;
// start the client thread
client_node.launch(move |os| {
let client_fn = options.client_fn;
client_fn(os, server_id);
});
// start the server thread
let shared_storage = SharedStorage::new();
let server_storage = shared_storage.clone();
server_node.launch(move |os| run_server(os, Box::new(server_storage)));
while world.step() && world.now() < options.time_limit {}
let disk_data = shared_storage.state.lock().data.clone();
assert!(verify_data(&disk_data, &options.u32_data[..]));
}
pub fn u32_to_cells(data: &[u32], client_id: u32) -> Vec<ReplCell> {
let mut res = Vec::new();
for (i, _) in data.iter().enumerate() {
res.push(ReplCell {
client_id,
seqno: i as u32,
value: data[i],
});
}
res
}
fn verify_data(disk_data: &[u32], data: &[u32]) -> bool {
if disk_data.len() != data.len() {
return false;
}
for i in 0..data.len() {
if disk_data[i] != data[i] {
return false;
}
}
true
}
}

View File

@@ -494,8 +494,6 @@ pub struct TimelineInfo {
pub current_logical_size: u64,
pub current_logical_size_is_accurate: bool,
pub directory_entries_counts: Vec<u64>,
/// Sum of the size of all layer files.
/// If a layer is present in both local FS and S3, it counts only once.
pub current_physical_size: Option<u64>, // is None when timeline is Unloaded

View File

@@ -124,7 +124,6 @@ impl RelTag {
Ord,
strum_macros::EnumIter,
strum_macros::FromRepr,
enum_map::Enum,
)]
#[repr(u8)]
pub enum SlruKind {

View File

@@ -431,11 +431,11 @@ pub fn generate_wal_segment(segno: u64, system_id: u64, lsn: Lsn) -> Result<Byte
#[repr(C)]
#[derive(Serialize)]
pub struct XlLogicalMessage {
pub db_id: Oid,
pub transactional: uint32, // bool, takes 4 bytes due to alignment in C structures
pub prefix_size: uint64,
pub message_size: uint64,
struct XlLogicalMessage {
db_id: Oid,
transactional: uint32, // bool, takes 4 bytes due to alignment in C structures
prefix_size: uint64,
message_size: uint64,
}
impl XlLogicalMessage {

View File

@@ -13,6 +13,5 @@ rand.workspace = true
tokio.workspace = true
tracing.workspace = true
thiserror.workspace = true
serde.workspace = true
workspace_hack.workspace = true

View File

@@ -7,7 +7,6 @@ pub mod framed;
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap, fmt, io, str};
// re-export for use in utils pageserver_feedback.rs
@@ -124,7 +123,7 @@ impl StartupMessageParams {
}
}
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct CancelKeyData {
pub backend_pid: i32,
pub cancel_key: i32,

View File

@@ -1,7 +1,7 @@
use std::{
borrow::Cow,
fs::{self, File},
io::{self, Write},
io,
};
use camino::{Utf8Path, Utf8PathBuf};
@@ -161,48 +161,6 @@ pub async fn durable_rename(
Ok(())
}
/// Writes a file to the specified `final_path` in a crash safe fasion, using [`std::fs`].
///
/// The file is first written to the specified `tmp_path`, and in a second
/// step, the `tmp_path` is renamed to the `final_path`. Intermediary fsync
/// and atomic rename guarantee that, if we crash at any point, there will never
/// be a partially written file at `final_path` (but maybe at `tmp_path`).
///
/// Callers are responsible for serializing calls of this function for a given `final_path`.
/// If they don't, there may be an error due to conflicting `tmp_path`, or there will
/// be no error and the content of `final_path` will be the "winner" caller's `content`.
/// I.e., the atomticity guarantees still hold.
pub fn overwrite(
final_path: &Utf8Path,
tmp_path: &Utf8Path,
content: &[u8],
) -> std::io::Result<()> {
let Some(final_path_parent) = final_path.parent() else {
return Err(std::io::Error::from_raw_os_error(
nix::errno::Errno::EINVAL as i32,
));
};
std::fs::remove_file(tmp_path).or_else(crate::fs_ext::ignore_not_found)?;
let mut file = std::fs::OpenOptions::new()
.write(true)
// Use `create_new` so that, if we race with ourselves or something else,
// we bail out instead of causing damage.
.create_new(true)
.open(tmp_path)?;
file.write_all(content)?;
file.sync_all()?;
drop(file); // don't keep the fd open for longer than we have to
std::fs::rename(tmp_path, final_path)?;
let final_parent_dirfd = std::fs::OpenOptions::new()
.read(true)
.open(final_path_parent)?;
final_parent_dirfd.sync_all()?;
Ok(())
}
#[cfg(test)]
mod tests {

View File

@@ -54,10 +54,12 @@ impl Generation {
}
#[track_caller]
pub fn get_suffix(&self) -> impl std::fmt::Display {
pub fn get_suffix(&self) -> String {
match self {
Self::Valid(v) => GenerationFileSuffix(Some(*v)),
Self::None => GenerationFileSuffix(None),
Self::Valid(v) => {
format!("-{:08x}", v)
}
Self::None => "".into(),
Self::Broken => {
panic!("Tried to use a broken generation");
}
@@ -88,7 +90,6 @@ impl Generation {
}
}
#[track_caller]
pub fn next(&self) -> Generation {
match self {
Self::Valid(n) => Self::Valid(*n + 1),
@@ -106,18 +107,6 @@ impl Generation {
}
}
struct GenerationFileSuffix(Option<u32>);
impl std::fmt::Display for GenerationFileSuffix {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(g) = self.0 {
write!(f, "-{g:08x}")
} else {
Ok(())
}
}
}
impl Serialize for Generation {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
@@ -175,24 +164,4 @@ mod test {
assert!(Generation::none() < Generation::new(0));
assert!(Generation::none() < Generation::new(1));
}
#[test]
fn suffix_is_stable() {
use std::fmt::Write as _;
// the suffix must remain stable through-out the pageserver remote storage evolution and
// not be changed accidentially without thinking about migration
let examples = [
(line!(), Generation::None, ""),
(line!(), Generation::Valid(0), "-00000000"),
(line!(), Generation::Valid(u32::MAX), "-ffffffff"),
];
let mut s = String::new();
for (line, gen, expected) in examples {
s.clear();
write!(s, "{}", &gen.get_suffix()).expect("string grows");
assert_eq!(s, expected, "example on {line}");
}
}
}

View File

@@ -69,44 +69,37 @@ impl<T> OnceCell<T> {
F: FnOnce(InitPermit) -> Fut,
Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
{
loop {
let sem = {
let guard = self.inner.lock().unwrap();
if guard.value.is_some() {
return Ok(Guard(guard));
}
guard.init_semaphore.clone()
};
{
let permit = {
// increment the count for the duration of queued
let _guard = CountWaitingInitializers::start(self);
sem.acquire().await
};
let Ok(permit) = permit else {
let guard = self.inner.lock().unwrap();
if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
// there was a take_and_deinit in between
continue;
}
assert!(
guard.value.is_some(),
"semaphore got closed, must be initialized"
);
return Ok(Guard(guard));
};
permit.forget();
}
let permit = InitPermit(sem);
let (value, _permit) = factory(permit).await?;
let sem = {
let guard = self.inner.lock().unwrap();
if guard.value.is_some() {
return Ok(Guard(guard));
}
guard.init_semaphore.clone()
};
return Ok(Self::set0(value, guard));
let permit = {
// increment the count for the duration of queued
let _guard = CountWaitingInitializers::start(self);
sem.acquire_owned().await
};
match permit {
Ok(permit) => {
let permit = InitPermit(permit);
let (value, _permit) = factory(permit).await?;
let guard = self.inner.lock().unwrap();
Ok(Self::set0(value, guard))
}
Err(_closed) => {
let guard = self.inner.lock().unwrap();
assert!(
guard.value.is_some(),
"semaphore got closed, must be initialized"
);
return Ok(Guard(guard));
}
}
}
@@ -204,41 +197,27 @@ impl<'a, T> Guard<'a, T> {
/// [`OnceCell::get_or_init`] will wait on it to complete.
pub fn take_and_deinit(&mut self) -> (T, InitPermit) {
let mut swapped = Inner::default();
let sem = swapped.init_semaphore.clone();
// acquire and forget right away, moving the control over to InitPermit
sem.try_acquire().expect("we just created this").forget();
let permit = swapped
.init_semaphore
.clone()
.try_acquire_owned()
.expect("we just created this");
std::mem::swap(&mut *self.0, &mut swapped);
swapped
.value
.map(|v| (v, InitPermit(sem)))
.map(|v| (v, InitPermit(permit)))
.expect("guard is not created unless value has been initialized")
}
}
/// Type held by OnceCell (de)initializing task.
///
/// On drop, this type will return the permit.
pub struct InitPermit(Arc<tokio::sync::Semaphore>);
impl Drop for InitPermit {
fn drop(&mut self) {
assert_eq!(
self.0.available_permits(),
0,
"InitPermit should only exist as the unique permit"
);
self.0.add_permits(1);
}
}
pub struct InitPermit(tokio::sync::OwnedSemaphorePermit);
#[cfg(test)]
mod tests {
use futures::Future;
use super::*;
use std::{
convert::Infallible,
pin::{pin, Pin},
sync::atomic::{AtomicUsize, Ordering},
time::Duration,
};
@@ -401,85 +380,4 @@ mod tests {
.unwrap();
assert_eq!(*g, "now initialized");
}
#[tokio::test(start_paused = true)]
async fn reproduce_init_take_deinit_race() {
init_take_deinit_scenario(|cell, factory| {
Box::pin(async {
cell.get_or_init(factory).await.unwrap();
})
})
.await;
}
type BoxedInitFuture<T, E> = Pin<Box<dyn Future<Output = Result<(T, InitPermit), E>>>>;
type BoxedInitFunction<T, E> = Box<dyn Fn(InitPermit) -> BoxedInitFuture<T, E>>;
/// Reproduce an assertion failure.
///
/// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`.
/// We currently only have one, but the structure is kept.
async fn init_take_deinit_scenario<F>(init_way: F)
where
F: for<'a> Fn(
&'a OnceCell<&'static str>,
BoxedInitFunction<&'static str, Infallible>,
) -> Pin<Box<dyn Future<Output = ()> + 'a>>,
{
let cell = OnceCell::default();
// acquire the init_semaphore only permit to drive initializing tasks in order to waiting
// on the same semaphore.
let permit = cell
.inner
.lock()
.unwrap()
.init_semaphore
.clone()
.try_acquire_owned()
.unwrap();
let mut t1 = pin!(init_way(
&cell,
Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })),
));
let mut t2 = pin!(init_way(
&cell,
Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })),
));
// drive t2 first to the init_semaphore -- the timeout will be hit once t2 future can
// no longer make progress
tokio::select! {
_ = &mut t2 => unreachable!("it cannot get permit"),
_ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
}
// followed by t1 in the init_semaphore
tokio::select! {
_ = &mut t1 => unreachable!("it cannot get permit"),
_ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
}
// now let t2 proceed and initialize
drop(permit);
t2.await;
let (s, permit) = { cell.get().unwrap().take_and_deinit() };
assert_eq!("t2", s);
// now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from
// the new one.
tokio::select! {
_ = &mut t1 => unreachable!("it cannot get permit"),
_ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
}
// only now we get to initialize it
drop(permit);
t1.await;
assert_eq!("t1", *cell.get().unwrap());
}
}

View File

@@ -34,9 +34,6 @@ fn main() -> anyhow::Result<()> {
println!("cargo:rustc-link-lib=static=walproposer");
println!("cargo:rustc-link-search={walproposer_lib_search_str}");
// Rebuild crate when libwalproposer.a changes
println!("cargo:rerun-if-changed={walproposer_lib_search_str}/libwalproposer.a");
let pg_config_bin = pg_install_abs.join("v16").join("bin").join("pg_config");
let inc_server_path: String = if pg_config_bin.exists() {
let output = Command::new(pg_config_bin)
@@ -82,7 +79,6 @@ fn main() -> anyhow::Result<()> {
.allowlist_function("WalProposerBroadcast")
.allowlist_function("WalProposerPoll")
.allowlist_function("WalProposerFree")
.allowlist_function("SafekeeperStateDesiredEvents")
.allowlist_var("DEBUG5")
.allowlist_var("DEBUG4")
.allowlist_var("DEBUG3")

View File

@@ -22,7 +22,6 @@ use crate::bindings::WalProposerExecStatusType;
use crate::bindings::WalproposerShmemState;
use crate::bindings::XLogRecPtr;
use crate::walproposer::ApiImpl;
use crate::walproposer::StreamingCallback;
use crate::walproposer::WaitResult;
extern "C" fn get_shmem_state(wp: *mut WalProposer) -> *mut WalproposerShmemState {
@@ -37,8 +36,7 @@ extern "C" fn start_streaming(wp: *mut WalProposer, startpos: XLogRecPtr) {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
let callback = StreamingCallback::new(wp);
(*api).start_streaming(startpos, &callback);
(*api).start_streaming(startpos)
}
}
@@ -136,18 +134,19 @@ extern "C" fn conn_async_read(
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
let (res, result) = (*api).conn_async_read(&mut (*sk));
// This function has guarantee that returned buf will be valid until
// the next call. So we can store a Vec in each Safekeeper and reuse
// it on the next call.
let mut inbuf = take_vec_u8(&mut (*sk).inbuf).unwrap_or_default();
inbuf.clear();
let result = (*api).conn_async_read(&mut (*sk), &mut inbuf);
inbuf.clear();
inbuf.extend_from_slice(res);
// Put a Vec back to sk->inbuf and return data ptr.
*amount = inbuf.len() as i32;
*buf = store_vec_u8(&mut (*sk).inbuf, inbuf);
*amount = res.len() as i32;
result
}
@@ -183,10 +182,6 @@ extern "C" fn recovery_download(wp: *mut WalProposer, sk: *mut Safekeeper) -> bo
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
// currently `recovery_download` is always called right after election
(*api).after_election(&mut (*wp));
(*api).recovery_download(&mut (*wp), &mut (*sk))
}
}
@@ -282,8 +277,7 @@ extern "C" fn wait_event_set(
}
WaitResult::Timeout => {
*event_sk = std::ptr::null_mut();
// WaitEventSetWait returns 0 for timeout.
*events = 0;
*events = crate::bindings::WL_TIMEOUT;
0
}
WaitResult::Network(sk, event_mask) => {
@@ -346,7 +340,7 @@ extern "C" fn log_internal(
}
}
#[derive(Debug, PartialEq)]
#[derive(Debug)]
pub enum Level {
Debug5,
Debug4,

View File

@@ -1,13 +1,13 @@
use std::ffi::CString;
use postgres_ffi::WAL_SEGMENT_SIZE;
use utils::{id::TenantTimelineId, lsn::Lsn};
use utils::id::TenantTimelineId;
use crate::{
api_bindings::{create_api, take_vec_u8, Level},
bindings::{
NeonWALReadResult, Safekeeper, WalProposer, WalProposerBroadcast, WalProposerConfig,
WalProposerCreate, WalProposerFree, WalProposerPoll, WalProposerStart,
NeonWALReadResult, Safekeeper, WalProposer, WalProposerConfig, WalProposerCreate,
WalProposerFree, WalProposerStart,
},
};
@@ -16,11 +16,11 @@ use crate::{
///
/// Refer to `pgxn/neon/walproposer.h` for documentation.
pub trait ApiImpl {
fn get_shmem_state(&self) -> *mut crate::bindings::WalproposerShmemState {
fn get_shmem_state(&self) -> &mut crate::bindings::WalproposerShmemState {
todo!()
}
fn start_streaming(&self, _startpos: u64, _callback: &StreamingCallback) {
fn start_streaming(&self, _startpos: u64) {
todo!()
}
@@ -70,11 +70,7 @@ pub trait ApiImpl {
todo!()
}
fn conn_async_read(
&self,
_sk: &mut Safekeeper,
_vec: &mut Vec<u8>,
) -> crate::bindings::PGAsyncReadResult {
fn conn_async_read(&self, _sk: &mut Safekeeper) -> (&[u8], crate::bindings::PGAsyncReadResult) {
todo!()
}
@@ -155,14 +151,12 @@ pub trait ApiImpl {
}
}
#[derive(Debug)]
pub enum WaitResult {
Latch,
Timeout,
Network(*mut Safekeeper, u32),
}
#[derive(Clone)]
pub struct Config {
/// Tenant and timeline id
pub ttid: TenantTimelineId,
@@ -248,24 +242,6 @@ impl Drop for Wrapper {
}
}
pub struct StreamingCallback {
wp: *mut WalProposer,
}
impl StreamingCallback {
pub fn new(wp: *mut WalProposer) -> StreamingCallback {
StreamingCallback { wp }
}
pub fn broadcast(&self, startpos: Lsn, endpos: Lsn) {
unsafe { WalProposerBroadcast(self.wp, startpos.0, endpos.0) }
}
pub fn poll(&self) {
unsafe { WalProposerPoll(self.wp) }
}
}
#[cfg(test)]
mod tests {
use core::panic;
@@ -368,13 +344,14 @@ mod tests {
fn conn_async_read(
&self,
_: &mut crate::bindings::Safekeeper,
vec: &mut Vec<u8>,
) -> crate::bindings::PGAsyncReadResult {
) -> (&[u8], crate::bindings::PGAsyncReadResult) {
println!("conn_async_read");
let reply = self.next_safekeeper_reply();
println!("conn_async_read result: {:?}", reply);
vec.extend_from_slice(reply);
crate::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS
(
reply,
crate::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS,
)
}
fn conn_blocking_write(&self, _: &mut crate::bindings::Safekeeper, buf: &[u8]) -> bool {

View File

@@ -7,14 +7,13 @@
//! logging what happens when a sequential scan is requested on a small table, then picking out two
//! suitable from logs.
use std::sync::Arc;
use std::sync::{Arc, Barrier};
use bytes::{Buf, Bytes};
use pageserver::{
config::PageServerConf, repository::Key, walrecord::NeonWalRecord, walredo::PostgresRedoManager,
};
use pageserver_api::shard::TenantShardId;
use tokio_util::sync::CancellationToken;
use utils::{id::TenantId, lsn::Lsn};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
@@ -40,11 +39,11 @@ fn redo_scenarios(c: &mut Criterion) {
.build()
.unwrap();
tracing::info!("executing first");
rt.block_on(short().execute(&manager)).unwrap();
short().execute(rt.handle(), &manager).unwrap();
tracing::info!("first executed");
}
let thread_counts = [1, 2, 4, 8, 16, 32, 64, 128];
let thread_counts = [1, 2, 4, 8, 16];
let mut group = c.benchmark_group("short");
group.sampling_mode(criterion::SamplingMode::Flat);
@@ -75,85 +74,116 @@ fn redo_scenarios(c: &mut Criterion) {
drop(group);
}
/// Sets up a multi-threaded tokio runtime with default worker thread count,
/// then, spawn `requesters` tasks that repeatedly:
/// - get input from `input_factor()`
/// - call `manager.request_redo()` with their input
///
/// This stress-tests the scalability of a single walredo manager at high tokio-level concurrency.
///
/// Using tokio's default worker thread count means the results will differ on machines
/// with different core countrs. We don't care about that, the performance will always
/// be different on different hardware. To compare performance of different software versions,
/// use the same hardware.
/// Sets up `threads` number of requesters to `request_redo`, with the given input.
fn add_multithreaded_walredo_requesters(
b: &mut criterion::Bencher,
requesters: usize,
threads: u32,
manager: &Arc<PostgresRedoManager>,
input_factory: fn() -> Request,
) {
assert_ne!(requesters, 0);
assert_ne!(threads, 0);
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
if threads == 1 {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let handle = rt.handle();
b.iter_batched_ref(
|| Some(input_factory()),
|input| execute_all(input.take(), handle, manager),
criterion::BatchSize::PerIteration,
);
} else {
let (work_tx, work_rx) = std::sync::mpsc::sync_channel(threads as usize);
let cancel = CancellationToken::new();
let work_rx = std::sync::Arc::new(std::sync::Mutex::new(work_rx));
let barrier = Arc::new(tokio::sync::Barrier::new(requesters + 1));
let barrier = Arc::new(Barrier::new(threads as usize + 1));
let jhs = (0..requesters)
.map(|_| {
let manager = manager.clone();
let barrier = barrier.clone();
let cancel = cancel.clone();
rt.spawn(async move {
let work_loop = async move {
loop {
let input = input_factory();
barrier.wait().await;
let page = input.execute(&manager).await.unwrap();
assert_eq!(page.remaining(), 8192);
barrier.wait().await;
let jhs = (0..threads)
.map(|_| {
std::thread::spawn({
let manager = manager.clone();
let barrier = barrier.clone();
let work_rx = work_rx.clone();
move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let handle = rt.handle();
loop {
// queue up and wait if we want to go another round
if work_rx.lock().unwrap().recv().is_err() {
break;
}
let input = Some(input_factory());
barrier.wait();
execute_all(input, handle, &manager).unwrap();
barrier.wait();
}
}
};
tokio::select! {
_ = work_loop => {},
_ = cancel.cancelled() => { }
}
})
})
})
.collect::<Vec<_>>();
.collect::<Vec<_>>();
let _jhs = JoinOnDrop(jhs);
b.iter_batched(
|| {
for _ in 0..threads {
work_tx.send(()).unwrap()
}
},
|()| {
// start the work
barrier.wait();
b.iter_batched(
|| (),
|()| {
// start the work
rt.block_on(async {
barrier.wait().await;
// wait for work to complete
barrier.wait().await;
});
},
criterion::BatchSize::PerIteration,
);
barrier.wait();
},
criterion::BatchSize::PerIteration,
);
cancel.cancel();
let jhs = JoinOnDrop(jhs);
rt.block_on(jhs.join_all());
drop(work_tx);
}
}
struct JoinOnDrop(Vec<tokio::task::JoinHandle<()>>);
struct JoinOnDrop(Vec<std::thread::JoinHandle<()>>);
impl JoinOnDrop {
/// Checks all the futures for panics.
pub async fn join_all(self) {
futures::future::try_join_all(self.0).await.unwrap();
impl Drop for JoinOnDrop {
// it's not really needless because we want join all then check for panicks
#[allow(clippy::needless_collect)]
fn drop(&mut self) {
// first join all
let results = self.0.drain(..).map(|jh| jh.join()).collect::<Vec<_>>();
// then check the results; panicking here is not great, but it does get the message across
// to the user, and sets an exit value.
results.into_iter().try_for_each(|res| res).unwrap();
}
}
fn execute_all<I>(
input: I,
handle: &tokio::runtime::Handle,
manager: &PostgresRedoManager,
) -> anyhow::Result<()>
where
I: IntoIterator<Item = Request>,
{
// just fire all requests as fast as possible
input.into_iter().try_for_each(|req| {
let page = req.execute(handle, manager)?;
assert_eq!(page.remaining(), 8192);
anyhow::Ok(())
})
}
criterion_group!(benches, redo_scenarios);
criterion_main!(benches);
@@ -463,7 +493,11 @@ struct Request {
}
impl Request {
async fn execute(self, manager: &PostgresRedoManager) -> anyhow::Result<Bytes> {
fn execute(
self,
rt: &tokio::runtime::Handle,
manager: &PostgresRedoManager,
) -> anyhow::Result<Bytes> {
let Request {
key,
lsn,
@@ -472,8 +506,6 @@ impl Request {
pg_version,
} = self;
manager
.request_redo(key, lsn, base_img, records, pg_version)
.await
rt.block_on(manager.request_redo(key, lsn, base_img, records, pg_version))
}
}

View File

@@ -234,7 +234,7 @@ impl DeletionHeader {
let header_bytes = serde_json::to_vec(self).context("serialize deletion header")?;
let header_path = conf.deletion_header_path();
let temp_path = path_with_suffix_extension(&header_path, TEMP_SUFFIX);
VirtualFile::crashsafe_overwrite(header_path, temp_path, header_bytes)
VirtualFile::crashsafe_overwrite(&header_path, &temp_path, &header_bytes)
.await
.maybe_fatal_err("save deletion header")?;
@@ -325,8 +325,7 @@ impl DeletionList {
let temp_path = path_with_suffix_extension(&path, TEMP_SUFFIX);
let bytes = serde_json::to_vec(self).expect("Failed to serialize deletion list");
VirtualFile::crashsafe_overwrite(path, temp_path, bytes)
VirtualFile::crashsafe_overwrite(&path, &temp_path, &bytes)
.await
.maybe_fatal_err("save deletion list")
.map_err(Into::into)

View File

@@ -422,7 +422,6 @@ async fn build_timeline_info_common(
tenant::timeline::logical_size::Accuracy::Approximate => false,
tenant::timeline::logical_size::Accuracy::Exact => true,
},
directory_entries_counts: timeline.get_directory_metrics().to_vec(),
current_physical_size,
current_logical_size_non_incremental: None,
timeline_dir_layer_file_size_sum: None,
@@ -489,9 +488,7 @@ async fn timeline_create_handler(
let state = get_state(&request);
async {
let tenant = state
.tenant_manager
.get_attached_tenant_shard(tenant_shard_id, false)?;
let tenant = state.tenant_manager.get_attached_tenant_shard(tenant_shard_id, false)?;
tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?;
@@ -501,62 +498,48 @@ async fn timeline_create_handler(
tracing::info!("bootstrapping");
}
match tenant
.create_timeline(
new_timeline_id,
request_data.ancestor_timeline_id,
request_data.ancestor_start_lsn,
request_data.pg_version.unwrap_or(crate::DEFAULT_PG_VERSION),
request_data.existing_initdb_timeline_id,
state.broker_client.clone(),
&ctx,
)
.await
{
match tenant.create_timeline(
new_timeline_id,
request_data.ancestor_timeline_id.map(TimelineId::from),
request_data.ancestor_start_lsn,
request_data.pg_version.unwrap_or(crate::DEFAULT_PG_VERSION),
request_data.existing_initdb_timeline_id,
state.broker_client.clone(),
&ctx,
)
.await {
Ok(new_timeline) => {
// Created. Construct a TimelineInfo for it.
let timeline_info = build_timeline_info_common(
&new_timeline,
&ctx,
tenant::timeline::GetLogicalSizePriority::User,
)
.await
.map_err(ApiError::InternalServerError)?;
let timeline_info = build_timeline_info_common(&new_timeline, &ctx, tenant::timeline::GetLogicalSizePriority::User)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::CREATED, timeline_info)
}
Err(_) if tenant.cancel.is_cancelled() => {
// In case we get some ugly error type during shutdown, cast it into a clean 503.
json_response(
StatusCode::SERVICE_UNAVAILABLE,
HttpErrorBody::from_msg("Tenant shutting down".to_string()),
)
json_response(StatusCode::SERVICE_UNAVAILABLE, HttpErrorBody::from_msg("Tenant shutting down".to_string()))
}
Err(tenant::CreateTimelineError::Conflict | tenant::CreateTimelineError::AlreadyCreating) => {
json_response(StatusCode::CONFLICT, ())
}
Err(tenant::CreateTimelineError::AncestorLsn(err)) => {
json_response(StatusCode::NOT_ACCEPTABLE, HttpErrorBody::from_msg(
format!("{err:#}")
))
}
Err(e @ tenant::CreateTimelineError::AncestorNotActive) => {
json_response(StatusCode::SERVICE_UNAVAILABLE, HttpErrorBody::from_msg(e.to_string()))
}
Err(tenant::CreateTimelineError::ShuttingDown) => {
json_response(StatusCode::SERVICE_UNAVAILABLE,HttpErrorBody::from_msg("tenant shutting down".to_string()))
}
Err(
tenant::CreateTimelineError::Conflict
| tenant::CreateTimelineError::AlreadyCreating,
) => json_response(StatusCode::CONFLICT, ()),
Err(tenant::CreateTimelineError::AncestorLsn(err)) => json_response(
StatusCode::NOT_ACCEPTABLE,
HttpErrorBody::from_msg(format!("{err:#}")),
),
Err(e @ tenant::CreateTimelineError::AncestorNotActive) => json_response(
StatusCode::SERVICE_UNAVAILABLE,
HttpErrorBody::from_msg(e.to_string()),
),
Err(tenant::CreateTimelineError::ShuttingDown) => json_response(
StatusCode::SERVICE_UNAVAILABLE,
HttpErrorBody::from_msg("tenant shutting down".to_string()),
),
Err(tenant::CreateTimelineError::Other(err)) => Err(ApiError::InternalServerError(err)),
}
}
.instrument(info_span!("timeline_create",
tenant_id = %tenant_shard_id.tenant_id,
shard_id = %tenant_shard_id.shard_slug(),
timeline_id = %new_timeline_id,
lsn=?request_data.ancestor_start_lsn,
pg_version=?request_data.pg_version
))
timeline_id = %new_timeline_id, lsn=?request_data.ancestor_start_lsn, pg_version=?request_data.pg_version))
.await
}

View File

@@ -602,15 +602,6 @@ pub(crate) mod initial_logical_size {
});
}
static DIRECTORY_ENTRIES_COUNT: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_directory_entries_count",
"Sum of the entries in pageserver-stored directory listings",
&["tenant_id", "shard_id", "timeline_id"]
)
.expect("failed to define a metric")
});
pub(crate) static TENANT_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_tenant_states_count",
@@ -1818,7 +1809,6 @@ pub(crate) struct TimelineMetrics {
resident_physical_size_gauge: UIntGauge,
/// copy of LayeredTimeline.current_logical_size
pub current_logical_size_gauge: UIntGauge,
pub directory_entries_count_gauge: Lazy<UIntGauge, Box<dyn Send + Fn() -> UIntGauge>>,
pub num_persistent_files_created: IntCounter,
pub persistent_bytes_written: IntCounter,
pub evictions: IntCounter,
@@ -1828,12 +1818,12 @@ pub(crate) struct TimelineMetrics {
impl TimelineMetrics {
pub fn new(
tenant_shard_id: &TenantShardId,
timeline_id_raw: &TimelineId,
timeline_id: &TimelineId,
evictions_with_low_residence_duration_builder: EvictionsWithLowResidenceDurationBuilder,
) -> Self {
let tenant_id = tenant_shard_id.tenant_id.to_string();
let shard_id = format!("{}", tenant_shard_id.shard_slug());
let timeline_id = timeline_id_raw.to_string();
let timeline_id = timeline_id.to_string();
let flush_time_histo = StorageTimeMetrics::new(
StorageTimeOperation::LayerFlush,
&tenant_id,
@@ -1886,22 +1876,6 @@ impl TimelineMetrics {
let current_logical_size_gauge = CURRENT_LOGICAL_SIZE
.get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id])
.unwrap();
// TODO use impl Trait syntax here once we have ability to use it: https://github.com/rust-lang/rust/issues/63065
let directory_entries_count_gauge_closure = {
let tenant_shard_id = *tenant_shard_id;
let timeline_id_raw = *timeline_id_raw;
move || {
let tenant_id = tenant_shard_id.tenant_id.to_string();
let shard_id = format!("{}", tenant_shard_id.shard_slug());
let timeline_id = timeline_id_raw.to_string();
let gauge: UIntGauge = DIRECTORY_ENTRIES_COUNT
.get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id])
.unwrap();
gauge
}
};
let directory_entries_count_gauge: Lazy<UIntGauge, Box<dyn Send + Fn() -> UIntGauge>> =
Lazy::new(Box::new(directory_entries_count_gauge_closure));
let num_persistent_files_created = NUM_PERSISTENT_FILES_CREATED
.get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id])
.unwrap();
@@ -1928,7 +1902,6 @@ impl TimelineMetrics {
last_record_gauge,
resident_physical_size_gauge,
current_logical_size_gauge,
directory_entries_count_gauge,
num_persistent_files_created,
persistent_bytes_written,
evictions,
@@ -1971,9 +1944,6 @@ impl Drop for TimelineMetrics {
RESIDENT_PHYSICAL_SIZE.remove_label_values(&[tenant_id, &shard_id, timeline_id]);
}
let _ = CURRENT_LOGICAL_SIZE.remove_label_values(&[tenant_id, &shard_id, timeline_id]);
if let Some(metric) = Lazy::get(&DIRECTORY_ENTRIES_COUNT) {
let _ = metric.remove_label_values(&[tenant_id, &shard_id, timeline_id]);
}
let _ =
NUM_PERSISTENT_FILES_CREATED.remove_label_values(&[tenant_id, &shard_id, timeline_id]);
let _ = PERSISTENT_BYTES_WRITTEN.remove_label_values(&[tenant_id, &shard_id, timeline_id]);

View File

@@ -91,8 +91,8 @@ const ACTIVE_TENANT_TIMEOUT: Duration = Duration::from_millis(30000);
/// `tokio_tar` already read the first such block. Read the second all-zeros block,
/// and check that there is no more data after the EOF marker.
///
/// 'tar' command can also write extra blocks of zeros, up to a record
/// size, controlled by the --record-size argument. Ignore them too.
/// XXX: Currently, any trailing data after the EOF marker prints a warning.
/// Perhaps it should be a hard error?
async fn read_tar_eof(mut reader: (impl AsyncRead + Unpin)) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
let mut buf = [0u8; 512];
@@ -113,24 +113,17 @@ async fn read_tar_eof(mut reader: (impl AsyncRead + Unpin)) -> anyhow::Result<()
anyhow::bail!("invalid tar EOF marker");
}
// Drain any extra zero-blocks after the EOF marker
// Drain any data after the EOF marker
let mut trailing_bytes = 0;
let mut seen_nonzero_bytes = false;
loop {
let nbytes = reader.read(&mut buf).await?;
trailing_bytes += nbytes;
if !buf.iter().all(|&x| x == 0) {
seen_nonzero_bytes = true;
}
if nbytes == 0 {
break;
}
}
if seen_nonzero_bytes {
anyhow::bail!("unexpected non-zero bytes after the tar archive");
}
if trailing_bytes % 512 != 0 {
anyhow::bail!("unexpected number of zeros ({trailing_bytes}), not divisible by tar block size (512 bytes), after the tar archive");
if trailing_bytes > 0 {
warn!("ignored {trailing_bytes} unexpected bytes after the tar archive");
}
Ok(())
}

View File

@@ -14,7 +14,6 @@ use crate::span::debug_assert_current_span_has_tenant_and_timeline_id_no_shard_i
use crate::walrecord::NeonWalRecord;
use anyhow::{ensure, Context};
use bytes::{Buf, Bytes, BytesMut};
use enum_map::Enum;
use pageserver_api::key::{
dbdir_key_range, is_rel_block_key, is_slru_block_key, rel_block_to_key, rel_dir_to_key,
rel_key_range, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key,
@@ -156,8 +155,6 @@ impl Timeline {
pending_updates: HashMap::new(),
pending_deletions: Vec::new(),
pending_nblocks: 0,
pending_aux_files: None,
pending_directory_entries: Vec::new(),
lsn,
}
}
@@ -871,15 +868,6 @@ pub struct DatadirModification<'a> {
pending_updates: HashMap<Key, Vec<(Lsn, Value)>>,
pending_deletions: Vec<(Range<Key>, Lsn)>,
pending_nblocks: i64,
// If we already wrote any aux file changes in this modification, stash the latest dir. If set,
// [`Self::put_file`] may assume that it is safe to emit a delta rather than checking
// if AUX_FILES_KEY is already set.
pending_aux_files: Option<AuxFilesDirectory>,
/// For special "directory" keys that store key-value maps, track the size of the map
/// if it was updated in this modification.
pending_directory_entries: Vec<(DirectoryKind, usize)>,
}
impl<'a> DatadirModification<'a> {
@@ -911,7 +899,6 @@ impl<'a> DatadirModification<'a> {
let buf = DbDirectory::ser(&DbDirectory {
dbdirs: HashMap::new(),
})?;
self.pending_directory_entries.push((DirectoryKind::Db, 0));
self.put(DBDIR_KEY, Value::Image(buf.into()));
// Create AuxFilesDirectory
@@ -920,24 +907,16 @@ impl<'a> DatadirModification<'a> {
let buf = TwoPhaseDirectory::ser(&TwoPhaseDirectory {
xids: HashSet::new(),
})?;
self.pending_directory_entries
.push((DirectoryKind::TwoPhase, 0));
self.put(TWOPHASEDIR_KEY, Value::Image(buf.into()));
let buf: Bytes = SlruSegmentDirectory::ser(&SlruSegmentDirectory::default())?.into();
let empty_dir = Value::Image(buf);
self.put(slru_dir_to_key(SlruKind::Clog), empty_dir.clone());
self.pending_directory_entries
.push((DirectoryKind::SlruSegment(SlruKind::Clog), 0));
self.put(
slru_dir_to_key(SlruKind::MultiXactMembers),
empty_dir.clone(),
);
self.pending_directory_entries
.push((DirectoryKind::SlruSegment(SlruKind::Clog), 0));
self.put(slru_dir_to_key(SlruKind::MultiXactOffsets), empty_dir);
self.pending_directory_entries
.push((DirectoryKind::SlruSegment(SlruKind::MultiXactOffsets), 0));
Ok(())
}
@@ -1038,7 +1017,6 @@ impl<'a> DatadirModification<'a> {
let buf = RelDirectory::ser(&RelDirectory {
rels: HashSet::new(),
})?;
self.pending_directory_entries.push((DirectoryKind::Rel, 0));
self.put(
rel_dir_to_key(spcnode, dbnode),
Value::Image(Bytes::from(buf)),
@@ -1061,8 +1039,6 @@ impl<'a> DatadirModification<'a> {
if !dir.xids.insert(xid) {
anyhow::bail!("twophase file for xid {} already exists", xid);
}
self.pending_directory_entries
.push((DirectoryKind::TwoPhase, dir.xids.len()));
self.put(
TWOPHASEDIR_KEY,
Value::Image(Bytes::from(TwoPhaseDirectory::ser(&dir)?)),
@@ -1098,8 +1074,6 @@ impl<'a> DatadirModification<'a> {
let mut dir = DbDirectory::des(&buf)?;
if dir.dbdirs.remove(&(spcnode, dbnode)).is_some() {
let buf = DbDirectory::ser(&dir)?;
self.pending_directory_entries
.push((DirectoryKind::Db, dir.dbdirs.len()));
self.put(DBDIR_KEY, Value::Image(buf.into()));
} else {
warn!(
@@ -1137,8 +1111,6 @@ impl<'a> DatadirModification<'a> {
// Didn't exist. Update dbdir
dbdir.dbdirs.insert((rel.spcnode, rel.dbnode), false);
let buf = DbDirectory::ser(&dbdir).context("serialize db")?;
self.pending_directory_entries
.push((DirectoryKind::Db, dbdir.dbdirs.len()));
self.put(DBDIR_KEY, Value::Image(buf.into()));
// and create the RelDirectory
@@ -1153,10 +1125,6 @@ impl<'a> DatadirModification<'a> {
if !rel_dir.rels.insert((rel.relnode, rel.forknum)) {
return Err(RelationError::AlreadyExists);
}
self.pending_directory_entries
.push((DirectoryKind::Rel, rel_dir.rels.len()));
self.put(
rel_dir_key,
Value::Image(Bytes::from(
@@ -1248,9 +1216,6 @@ impl<'a> DatadirModification<'a> {
let buf = self.get(dir_key, ctx).await?;
let mut dir = RelDirectory::des(&buf)?;
self.pending_directory_entries
.push((DirectoryKind::Rel, dir.rels.len()));
if dir.rels.remove(&(rel.relnode, rel.forknum)) {
self.put(dir_key, Value::Image(Bytes::from(RelDirectory::ser(&dir)?)));
} else {
@@ -1286,8 +1251,6 @@ impl<'a> DatadirModification<'a> {
if !dir.segments.insert(segno) {
anyhow::bail!("slru segment {kind:?}/{segno} already exists");
}
self.pending_directory_entries
.push((DirectoryKind::SlruSegment(kind), dir.segments.len()));
self.put(
dir_key,
Value::Image(Bytes::from(SlruSegmentDirectory::ser(&dir)?)),
@@ -1332,8 +1295,6 @@ impl<'a> DatadirModification<'a> {
if !dir.segments.remove(&segno) {
warn!("slru segment {:?}/{} does not exist", kind, segno);
}
self.pending_directory_entries
.push((DirectoryKind::SlruSegment(kind), dir.segments.len()));
self.put(
dir_key,
Value::Image(Bytes::from(SlruSegmentDirectory::ser(&dir)?)),
@@ -1364,8 +1325,6 @@ impl<'a> DatadirModification<'a> {
if !dir.xids.remove(&xid) {
warn!("twophase file for xid {} does not exist", xid);
}
self.pending_directory_entries
.push((DirectoryKind::TwoPhase, dir.xids.len()));
self.put(
TWOPHASEDIR_KEY,
Value::Image(Bytes::from(TwoPhaseDirectory::ser(&dir)?)),
@@ -1381,8 +1340,6 @@ impl<'a> DatadirModification<'a> {
let buf = AuxFilesDirectory::ser(&AuxFilesDirectory {
files: HashMap::new(),
})?;
self.pending_directory_entries
.push((DirectoryKind::AuxFiles, 0));
self.put(AUX_FILES_KEY, Value::Image(Bytes::from(buf)));
Ok(())
}
@@ -1393,76 +1350,28 @@ impl<'a> DatadirModification<'a> {
content: &[u8],
ctx: &RequestContext,
) -> anyhow::Result<()> {
let file_path = path.to_string();
let content = if content.is_empty() {
None
} else {
Some(Bytes::copy_from_slice(content))
};
let dir = if let Some(mut dir) = self.pending_aux_files.take() {
// We already updated aux files in `self`: emit a delta and update our latest value
self.put(
AUX_FILES_KEY,
Value::WalRecord(NeonWalRecord::AuxFile {
file_path: file_path.clone(),
content: content.clone(),
}),
);
dir.upsert(file_path, content);
dir
} else {
// Check if the AUX_FILES_KEY is initialized
match self.get(AUX_FILES_KEY, ctx).await {
Ok(dir_bytes) => {
let mut dir = AuxFilesDirectory::des(&dir_bytes)?;
// Key is already set, we may append a delta
self.put(
AUX_FILES_KEY,
Value::WalRecord(NeonWalRecord::AuxFile {
file_path: file_path.clone(),
content: content.clone(),
}),
);
dir.upsert(file_path, content);
dir
}
Err(
e @ (PageReconstructError::AncestorStopping(_)
| PageReconstructError::Cancelled
| PageReconstructError::AncestorLsnTimeout(_)),
) => {
// Important that we do not interpret a shutdown error as "not found" and thereby
// reset the map.
return Err(e.into());
}
// FIXME: PageReconstructError doesn't have an explicit variant for key-not-found, so
// we are assuming that all _other_ possible errors represents a missing key. If some
// other error occurs, we may incorrectly reset the map of aux files.
Err(PageReconstructError::Other(_) | PageReconstructError::WalRedo(_)) => {
// Key is missing, we must insert an image as the basis for subsequent deltas.
let mut dir = AuxFilesDirectory {
files: HashMap::new(),
};
dir.upsert(file_path, content);
self.put(
AUX_FILES_KEY,
Value::Image(Bytes::from(
AuxFilesDirectory::ser(&dir).context("serialize")?,
)),
);
dir
let mut dir = match self.get(AUX_FILES_KEY, ctx).await {
Ok(buf) => AuxFilesDirectory::des(&buf)?,
Err(e) => {
// This is expected: historical databases do not have the key.
debug!("Failed to get info about AUX files: {}", e);
AuxFilesDirectory {
files: HashMap::new(),
}
}
};
self.pending_directory_entries
.push((DirectoryKind::AuxFiles, dir.files.len()));
self.pending_aux_files = Some(dir);
let path = path.to_string();
if content.is_empty() {
dir.files.remove(&path);
} else {
dir.files.insert(path, Bytes::copy_from_slice(content));
}
self.put(
AUX_FILES_KEY,
Value::Image(Bytes::from(
AuxFilesDirectory::ser(&dir).context("serialize")?,
)),
);
Ok(())
}
@@ -1518,10 +1427,6 @@ impl<'a> DatadirModification<'a> {
self.pending_nblocks = 0;
}
for (kind, count) in std::mem::take(&mut self.pending_directory_entries) {
writer.update_directory_entries_count(kind, count as u64);
}
Ok(())
}
@@ -1559,10 +1464,6 @@ impl<'a> DatadirModification<'a> {
writer.update_current_logical_size(pending_nblocks * i64::from(BLCKSZ));
}
for (kind, count) in std::mem::take(&mut self.pending_directory_entries) {
writer.update_directory_entries_count(kind, count as u64);
}
Ok(())
}
@@ -1672,18 +1573,8 @@ struct RelDirectory {
}
#[derive(Debug, Serialize, Deserialize, Default)]
pub(crate) struct AuxFilesDirectory {
pub(crate) files: HashMap<String, Bytes>,
}
impl AuxFilesDirectory {
pub(crate) fn upsert(&mut self, key: String, value: Option<Bytes>) {
if let Some(value) = value {
self.files.insert(key, value);
} else {
self.files.remove(&key);
}
}
struct AuxFilesDirectory {
files: HashMap<String, Bytes>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -1697,82 +1588,13 @@ struct SlruSegmentDirectory {
segments: HashSet<u32>,
}
#[derive(Copy, Clone, PartialEq, Eq, Debug, enum_map::Enum)]
#[repr(u8)]
pub(crate) enum DirectoryKind {
Db,
TwoPhase,
Rel,
AuxFiles,
SlruSegment(SlruKind),
}
impl DirectoryKind {
pub(crate) const KINDS_NUM: usize = <DirectoryKind as Enum>::LENGTH;
pub(crate) fn offset(&self) -> usize {
self.into_usize()
}
}
static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]);
#[allow(clippy::bool_assert_comparison)]
#[cfg(test)]
mod tests {
use hex_literal::hex;
use utils::id::TimelineId;
use super::*;
use crate::{tenant::harness::TenantHarness, DEFAULT_PG_VERSION};
/// Test a round trip of aux file updates, from DatadirModification to reading back from the Timeline
#[tokio::test]
async fn aux_files_round_trip() -> anyhow::Result<()> {
let name = "aux_files_round_trip";
let harness = TenantHarness::create(name)?;
pub const TIMELINE_ID: TimelineId =
TimelineId::from_array(hex!("11223344556677881122334455667788"));
let (tenant, ctx) = harness.load().await;
let tline = tenant
.create_empty_timeline(TIMELINE_ID, Lsn(0), DEFAULT_PG_VERSION, &ctx)
.await?;
let tline = tline.raw_timeline().unwrap();
// First modification: insert two keys
let mut modification = tline.begin_modification(Lsn(0x1000));
modification.put_file("foo/bar1", b"content1", &ctx).await?;
modification.set_lsn(Lsn(0x1008))?;
modification.put_file("foo/bar2", b"content2", &ctx).await?;
modification.commit(&ctx).await?;
let expect_1008 = HashMap::from([
("foo/bar1".to_string(), Bytes::from_static(b"content1")),
("foo/bar2".to_string(), Bytes::from_static(b"content2")),
]);
let readback = tline.list_aux_files(Lsn(0x1008), &ctx).await?;
assert_eq!(readback, expect_1008);
// Second modification: update one key, remove the other
let mut modification = tline.begin_modification(Lsn(0x2000));
modification.put_file("foo/bar1", b"content3", &ctx).await?;
modification.set_lsn(Lsn(0x2008))?;
modification.put_file("foo/bar2", b"", &ctx).await?;
modification.commit(&ctx).await?;
let expect_2008 =
HashMap::from([("foo/bar1".to_string(), Bytes::from_static(b"content3"))]);
let readback = tline.list_aux_files(Lsn(0x2008), &ctx).await?;
assert_eq!(readback, expect_2008);
// Reading back in time works
let readback = tline.list_aux_files(Lsn(0x1008), &ctx).await?;
assert_eq!(readback, expect_1008);
Ok(())
}
//use super::repo_harness::*;
//use super::*;
/*
fn assert_current_logical_size<R: Repository>(timeline: &DatadirTimeline<R>, lsn: Lsn) {

View File

@@ -28,6 +28,7 @@ use remote_storage::GenericRemoteStorage;
use std::fmt;
use storage_broker::BrokerClientChannel;
use tokio::io::BufReader;
use tokio::runtime::Handle;
use tokio::sync::watch;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -643,10 +644,10 @@ impl Tenant {
// The attach task will carry a GateGuard, so that shutdown() reliably waits for it to drop out if
// we shut down while attaching.
let attach_gate_guard = tenant
.gate
.enter()
.expect("We just created the Tenant: nothing else can have shut it down yet");
let Ok(attach_gate_guard) = tenant.gate.enter() else {
// We just created the Tenant: nothing else can have shut it down yet
unreachable!();
};
// Do all the hard work in the background
let tenant_clone = Arc::clone(&tenant);
@@ -754,27 +755,36 @@ impl Tenant {
AttachType::Normal
};
let preload = match (&mode, &remote_storage) {
(SpawnMode::Create, _) => {
let preload_timer = TENANT.preload.start_timer();
let preload = match mode {
SpawnMode::Create => {
// Don't count the skipped preload into the histogram of preload durations
preload_timer.stop_and_discard();
None
},
(SpawnMode::Normal, Some(remote_storage)) => {
let _preload_timer = TENANT.preload.start_timer();
let res = tenant_clone
.preload(remote_storage, task_mgr::shutdown_token())
.await;
match res {
Ok(p) => Some(p),
Err(e) => {
make_broken(&tenant_clone, anyhow::anyhow!(e));
return Ok(());
}
SpawnMode::Normal => {
match &remote_storage {
Some(remote_storage) => Some(
match tenant_clone
.preload(remote_storage, task_mgr::shutdown_token())
.instrument(
tracing::info_span!(parent: None, "attach_preload", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()),
)
.await {
Ok(p) => {
preload_timer.observe_duration();
p
}
,
Err(e) => {
make_broken(&tenant_clone, anyhow::anyhow!(e));
return Ok(());
}
},
),
None => None,
}
}
(SpawnMode::Normal, None) => {
let _preload_timer = TENANT.preload.start_timer();
None
}
};
// Remote preload is complete.
@@ -810,37 +820,36 @@ impl Tenant {
info!("ready for backgound jobs barrier");
}
let deleted = DeleteTenantFlow::resume_from_attach(
match DeleteTenantFlow::resume_from_attach(
deletion,
&tenant_clone,
preload,
tenants,
&ctx,
)
.await;
if let Err(e) = deleted {
make_broken(&tenant_clone, anyhow::anyhow!(e));
.await
{
Err(err) => {
make_broken(&tenant_clone, anyhow::anyhow!(err));
return Ok(());
}
Ok(()) => return Ok(()),
}
return Ok(());
}
// We will time the duration of the attach phase unless this is a creation (attach will do no work)
let attached = {
let _attach_timer = match mode {
SpawnMode::Create => None,
SpawnMode::Normal => {Some(TENANT.attach.start_timer())}
};
tenant_clone.attach(preload, mode, &ctx).await
let attach_timer = match mode {
SpawnMode::Create => None,
SpawnMode::Normal => {Some(TENANT.attach.start_timer())}
};
match attached {
match tenant_clone.attach(preload, mode, &ctx).await {
Ok(()) => {
info!("attach finished, activating");
if let Some(t)= attach_timer {t.observe_duration();}
tenant_clone.activate(broker_client, None, &ctx);
}
Err(e) => {
if let Some(t)= attach_timer {t.observe_duration();}
make_broken(&tenant_clone, anyhow::anyhow!(e));
}
}
@@ -853,26 +862,34 @@ impl Tenant {
// logical size calculations: if logical size calculation semaphore is saturated,
// then warmup will wait for that before proceeding to the next tenant.
if let AttachType::Warmup(_permit) = attach_type {
let mut futs: FuturesUnordered<_> = tenant_clone.timelines.lock().unwrap().values().cloned().map(|t| t.await_initial_logical_size()).collect();
let mut futs = FuturesUnordered::new();
let timelines: Vec<_> = tenant_clone.timelines.lock().unwrap().values().cloned().collect();
for t in timelines {
futs.push(t.await_initial_logical_size())
}
tracing::info!("Waiting for initial logical sizes while warming up...");
while futs.next().await.is_some() {}
while futs.next().await.is_some() {
}
tracing::info!("Warm-up complete");
}
Ok(())
}
.instrument(tracing::info_span!(parent: None, "attach", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), gen=?generation)),
.instrument({
let span = tracing::info_span!(parent: None, "attach", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), gen=?generation);
span.follows_from(Span::current());
span
}),
);
Ok(tenant)
}
#[instrument(skip_all)]
pub(crate) async fn preload(
self: &Arc<Tenant>,
remote_storage: &GenericRemoteStorage,
cancel: CancellationToken,
) -> anyhow::Result<TenantPreload> {
span::debug_assert_current_span_has_tenant_id();
// Get list of remote timelines
// download index files for every tenant timeline
info!("listing remote timelines");
@@ -2877,10 +2894,17 @@ impl Tenant {
let tenant_shard_id = *tenant_shard_id;
let config_path = config_path.to_owned();
let conf_content = conf_content.into_bytes();
VirtualFile::crashsafe_overwrite(config_path.clone(), temp_path, conf_content)
.await
.with_context(|| format!("write tenant {tenant_shard_id} config to {config_path}"))?;
tokio::task::spawn_blocking(move || {
Handle::current().block_on(async move {
let conf_content = conf_content.as_bytes();
VirtualFile::crashsafe_overwrite(&config_path, &temp_path, conf_content)
.await
.with_context(|| {
format!("write tenant {tenant_shard_id} config to {config_path}")
})
})
})
.await??;
Ok(())
}
@@ -2907,12 +2931,17 @@ impl Tenant {
let tenant_shard_id = *tenant_shard_id;
let target_config_path = target_config_path.to_owned();
let conf_content = conf_content.into_bytes();
VirtualFile::crashsafe_overwrite(target_config_path.clone(), temp_path, conf_content)
.await
.with_context(|| {
format!("write tenant {tenant_shard_id} config to {target_config_path}")
})?;
tokio::task::spawn_blocking(move || {
Handle::current().block_on(async move {
let conf_content = conf_content.as_bytes();
VirtualFile::crashsafe_overwrite(&target_config_path, &temp_path, conf_content)
.await
.with_context(|| {
format!("write tenant {tenant_shard_id} config to {target_config_path}")
})
})
})
.await??;
Ok(())
}
@@ -3901,7 +3930,6 @@ pub(crate) mod harness {
use utils::lsn::Lsn;
use crate::deletion_queue::mock::MockDeletionQueue;
use crate::walredo::apply_neon;
use crate::{
config::PageServerConf, repository::Key, tenant::Tenant, walrecord::NeonWalRecord,
};
@@ -3954,8 +3982,6 @@ pub(crate) mod harness {
}
}
#[cfg(test)]
#[derive(Debug)]
enum LoadMode {
Local,
Remote,
@@ -4038,7 +4064,7 @@ pub(crate) mod harness {
info_span!("TenantHarness", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug())
}
pub(crate) async fn load(&self) -> (Arc<Tenant>, RequestContext) {
pub async fn load(&self) -> (Arc<Tenant>, RequestContext) {
let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error);
(
self.try_load(&ctx)
@@ -4048,74 +4074,6 @@ pub(crate) mod harness {
)
}
/// For tests that specifically want to exercise the local load path, which does
/// not use remote storage.
pub(crate) async fn try_load_local(
&self,
ctx: &RequestContext,
) -> anyhow::Result<Arc<Tenant>> {
self.do_try_load(ctx, LoadMode::Local).await
}
/// The 'load' in this function is either a local load or a normal attachment,
pub(crate) async fn try_load(&self, ctx: &RequestContext) -> anyhow::Result<Arc<Tenant>> {
// If we have nothing in remote storage, must use load_local instead of attach: attach
// will error out if there are no timelines.
//
// See https://github.com/neondatabase/neon/issues/5456 for how we will eliminate
// this weird state of a Tenant which exists but doesn't have any timelines.
let mode = match self.remote_empty() {
true => LoadMode::Local,
false => LoadMode::Remote,
};
self.do_try_load(ctx, mode).await
}
#[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), ?mode))]
async fn do_try_load(
&self,
ctx: &RequestContext,
mode: LoadMode,
) -> anyhow::Result<Arc<Tenant>> {
let walredo_mgr = Arc::new(WalRedoManager::from(TestRedoManager));
let tenant = Arc::new(Tenant::new(
TenantState::Loading,
self.conf,
AttachedTenantConf::try_from(LocationConf::attached_single(
TenantConfOpt::from(self.tenant_conf),
self.generation,
&ShardParameters::default(),
))
.unwrap(),
// This is a legacy/test code path: sharding isn't supported here.
ShardIdentity::unsharded(),
Some(walredo_mgr),
self.tenant_shard_id,
Some(self.remote_storage.clone()),
self.deletion_queue.new_client(),
));
match mode {
LoadMode::Local => {
tenant.load_local(ctx).await?;
}
LoadMode::Remote => {
let preload = tenant
.preload(&self.remote_storage, CancellationToken::new())
.await?;
tenant.attach(Some(preload), SpawnMode::Normal, ctx).await?;
}
}
tenant.state.send_replace(TenantState::Active);
for timeline in tenant.timelines.lock().unwrap().values() {
timeline.set_state(TimelineState::Active);
}
Ok(tenant)
}
fn remote_empty(&self) -> bool {
let tenant_path = self.conf.tenant_path(&self.tenant_shard_id);
let remote_tenant_dir = self
@@ -4141,6 +4099,77 @@ pub(crate) mod harness {
}
}
async fn do_try_load(
&self,
ctx: &RequestContext,
mode: LoadMode,
) -> anyhow::Result<Arc<Tenant>> {
let walredo_mgr = Arc::new(WalRedoManager::from(TestRedoManager));
let tenant = Arc::new(Tenant::new(
TenantState::Loading,
self.conf,
AttachedTenantConf::try_from(LocationConf::attached_single(
TenantConfOpt::from(self.tenant_conf),
self.generation,
&ShardParameters::default(),
))
.unwrap(),
// This is a legacy/test code path: sharding isn't supported here.
ShardIdentity::unsharded(),
Some(walredo_mgr),
self.tenant_shard_id,
Some(self.remote_storage.clone()),
self.deletion_queue.new_client(),
));
match mode {
LoadMode::Local => {
tenant
.load_local(ctx)
.instrument(info_span!("try_load", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug()))
.await?;
}
LoadMode::Remote => {
let preload = tenant
.preload(&self.remote_storage, CancellationToken::new())
.instrument(info_span!("try_load_preload", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug()))
.await?;
tenant
.attach(Some(preload), SpawnMode::Normal, ctx)
.instrument(info_span!("try_load", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug()))
.await?;
}
}
tenant.state.send_replace(TenantState::Active);
for timeline in tenant.timelines.lock().unwrap().values() {
timeline.set_state(TimelineState::Active);
}
Ok(tenant)
}
/// For tests that specifically want to exercise the local load path, which does
/// not use remote storage.
pub async fn try_load_local(&self, ctx: &RequestContext) -> anyhow::Result<Arc<Tenant>> {
self.do_try_load(ctx, LoadMode::Local).await
}
/// The 'load' in this function is either a local load or a normal attachment,
pub async fn try_load(&self, ctx: &RequestContext) -> anyhow::Result<Arc<Tenant>> {
// If we have nothing in remote storage, must use load_local instead of attach: attach
// will error out if there are no timelines.
//
// See https://github.com/neondatabase/neon/issues/5456 for how we will eliminate
// this weird state of a Tenant which exists but doesn't have any timelines.
let mode = match self.remote_empty() {
true => LoadMode::Local,
false => LoadMode::Remote,
};
self.do_try_load(ctx, mode).await
}
pub fn timeline_path(&self, timeline_id: &TimelineId) -> Utf8PathBuf {
self.conf.timeline_path(&self.tenant_shard_id, timeline_id)
}
@@ -4161,34 +4190,20 @@ pub(crate) mod harness {
records: Vec<(Lsn, NeonWalRecord)>,
_pg_version: u32,
) -> anyhow::Result<Bytes> {
let records_neon = records.iter().all(|r| apply_neon::can_apply_in_neon(&r.1));
let s = format!(
"redo for {} to get to {}, with {} and {} records",
key,
lsn,
if base_img.is_some() {
"base image"
} else {
"no base image"
},
records.len()
);
println!("{s}");
if records_neon {
// For Neon wal records, we can decode without spawning postgres, so do so.
let base_img = base_img.expect("Neon WAL redo requires base image").1;
let mut page = BytesMut::new();
page.extend_from_slice(&base_img);
for (_record_lsn, record) in records {
apply_neon::apply_in_neon(&record, key, &mut page)?;
}
Ok(page.freeze())
} else {
// We never spawn a postgres walredo process in unit tests: just log what we might have done.
let s = format!(
"redo for {} to get to {}, with {} and {} records",
key,
lsn,
if base_img.is_some() {
"base image"
} else {
"no base image"
},
records.len()
);
println!("{s}");
Ok(TEST_IMG(&s))
}
Ok(TEST_IMG(&s))
}
}
}

View File

@@ -11,9 +11,6 @@
//! len < 128: 0XXXXXXX
//! len >= 128: 1XXXXXXX XXXXXXXX XXXXXXXX XXXXXXXX
//!
use bytes::{BufMut, BytesMut};
use tokio_epoll_uring::{BoundedBuf, Slice};
use crate::context::RequestContext;
use crate::page_cache::PAGE_SZ;
use crate::tenant::block_io::BlockCursor;
@@ -103,8 +100,6 @@ pub struct BlobWriter<const BUFFERED: bool> {
offset: u64,
/// A buffer to save on write calls, only used if BUFFERED=true
buf: Vec<u8>,
/// We do tiny writes for the length headers; they need to be in an owned buffer;
io_buf: Option<BytesMut>,
}
impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
@@ -113,7 +108,6 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
inner,
offset: start_offset,
buf: Vec::with_capacity(Self::CAPACITY),
io_buf: Some(BytesMut::new()),
}
}
@@ -123,31 +117,21 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
const CAPACITY: usize = if BUFFERED { PAGE_SZ } else { 0 };
#[inline(always)]
/// Writes the given buffer directly to the underlying `VirtualFile`.
/// You need to make sure that the internal buffer is empty, otherwise
/// data will be written in wrong order.
#[inline(always)]
async fn write_all_unbuffered<B: BoundedBuf>(
&mut self,
src_buf: B,
) -> (B::Buf, Result<(), Error>) {
let (src_buf, res) = self.inner.write_all(src_buf).await;
let nbytes = match res {
Ok(nbytes) => nbytes,
Err(e) => return (src_buf, Err(e)),
};
self.offset += nbytes as u64;
(src_buf, Ok(()))
async fn write_all_unbuffered(&mut self, src_buf: &[u8]) -> Result<(), Error> {
self.inner.write_all(src_buf).await?;
self.offset += src_buf.len() as u64;
Ok(())
}
#[inline(always)]
/// Flushes the internal buffer to the underlying `VirtualFile`.
pub async fn flush_buffer(&mut self) -> Result<(), Error> {
let buf = std::mem::take(&mut self.buf);
let (mut buf, res) = self.inner.write_all(buf).await;
res?;
buf.clear();
self.buf = buf;
self.inner.write_all(&self.buf).await?;
self.buf.clear();
Ok(())
}
@@ -162,91 +146,62 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
}
/// Internal, possibly buffered, write function
async fn write_all<B: BoundedBuf>(&mut self, src_buf: B) -> (B::Buf, Result<(), Error>) {
async fn write_all(&mut self, mut src_buf: &[u8]) -> Result<(), Error> {
if !BUFFERED {
assert!(self.buf.is_empty());
return self.write_all_unbuffered(src_buf).await;
self.write_all_unbuffered(src_buf).await?;
return Ok(());
}
let remaining = Self::CAPACITY - self.buf.len();
let src_buf_len = src_buf.bytes_init();
if src_buf_len == 0 {
return (Slice::into_inner(src_buf.slice_full()), Ok(()));
}
let mut src_buf = src_buf.slice(0..src_buf_len);
// First try to copy as much as we can into the buffer
if remaining > 0 {
let copied = self.write_into_buffer(&src_buf);
src_buf = src_buf.slice(copied..);
let copied = self.write_into_buffer(src_buf);
src_buf = &src_buf[copied..];
}
// Then, if the buffer is full, flush it out
if self.buf.len() == Self::CAPACITY {
if let Err(e) = self.flush_buffer().await {
return (Slice::into_inner(src_buf), Err(e));
}
self.flush_buffer().await?;
}
// Finally, write the tail of src_buf:
// If it wholly fits into the buffer without
// completely filling it, then put it there.
// If not, write it out directly.
let src_buf = if !src_buf.is_empty() {
if !src_buf.is_empty() {
assert_eq!(self.buf.len(), 0);
if src_buf.len() < Self::CAPACITY {
let copied = self.write_into_buffer(&src_buf);
let copied = self.write_into_buffer(src_buf);
// We just verified above that src_buf fits into our internal buffer.
assert_eq!(copied, src_buf.len());
Slice::into_inner(src_buf)
} else {
let (src_buf, res) = self.write_all_unbuffered(src_buf).await;
if let Err(e) = res {
return (src_buf, Err(e));
}
src_buf
self.write_all_unbuffered(src_buf).await?;
}
} else {
Slice::into_inner(src_buf)
};
(src_buf, Ok(()))
}
Ok(())
}
/// Write a blob of data. Returns the offset that it was written to,
/// which can be used to retrieve the data later.
pub async fn write_blob<B: BoundedBuf>(&mut self, srcbuf: B) -> (B::Buf, Result<u64, Error>) {
pub async fn write_blob(&mut self, srcbuf: &[u8]) -> Result<u64, Error> {
let offset = self.offset;
let len = srcbuf.bytes_init();
let mut io_buf = self.io_buf.take().expect("we always put it back below");
io_buf.clear();
let (io_buf, hdr_res) = async {
if len < 128 {
// Short blob. Write a 1-byte length header
io_buf.put_u8(len as u8);
self.write_all(io_buf).await
} else {
// Write a 4-byte length header
if len > 0x7fff_ffff {
return (
io_buf,
Err(Error::new(
ErrorKind::Other,
format!("blob too large ({} bytes)", len),
)),
);
}
let mut len_buf = (len as u32).to_be_bytes();
len_buf[0] |= 0x80;
io_buf.extend_from_slice(&len_buf[..]);
self.write_all(io_buf).await
if srcbuf.len() < 128 {
// Short blob. Write a 1-byte length header
let len_buf = srcbuf.len() as u8;
self.write_all(&[len_buf]).await?;
} else {
// Write a 4-byte length header
if srcbuf.len() > 0x7fff_ffff {
return Err(Error::new(
ErrorKind::Other,
format!("blob too large ({} bytes)", srcbuf.len()),
));
}
let mut len_buf = ((srcbuf.len()) as u32).to_be_bytes();
len_buf[0] |= 0x80;
self.write_all(&len_buf).await?;
}
.await;
self.io_buf = Some(io_buf);
match hdr_res {
Ok(_) => (),
Err(e) => return (Slice::into_inner(srcbuf.slice(..)), Err(e)),
}
let (srcbuf, res) = self.write_all(srcbuf).await;
(srcbuf, res.map(|_| offset))
self.write_all(srcbuf).await?;
Ok(offset)
}
}
@@ -293,14 +248,12 @@ mod tests {
let file = VirtualFile::create(pathbuf.as_path()).await?;
let mut wtr = BlobWriter::<BUFFERED>::new(file, 0);
for blob in blobs.iter() {
let (_, res) = wtr.write_blob(blob.clone()).await;
let offs = res?;
let offs = wtr.write_blob(blob).await?;
offsets.push(offs);
}
// Write out one page worth of zeros so that we can
// read again with read_blk
let (_, res) = wtr.write_blob(vec![0; PAGE_SZ]).await;
let offs = res?;
let offs = wtr.write_blob(&vec![0; PAGE_SZ]).await?;
println!("Writing final blob at offs={offs}");
wtr.flush_buffer().await?;
}

View File

@@ -6,7 +6,7 @@ use pageserver_api::{models::TenantState, shard::TenantShardId};
use remote_storage::{GenericRemoteStorage, RemotePath};
use tokio::sync::OwnedMutexGuard;
use tokio_util::sync::CancellationToken;
use tracing::{error, instrument, Instrument};
use tracing::{error, instrument, Instrument, Span};
use utils::{backoff, completion, crashsafe, fs_ext, id::TimelineId};
@@ -496,7 +496,11 @@ impl DeleteTenantFlow {
};
Ok(())
}
.instrument(tracing::info_span!(parent: None, "delete_tenant", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())),
.instrument({
let span = tracing::info_span!(parent: None, "delete_tenant", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug());
span.follows_from(Span::current());
span
}),
);
}

View File

@@ -6,7 +6,6 @@ use crate::context::RequestContext;
use crate::page_cache::{self, PAGE_SZ};
use crate::tenant::block_io::{BlockCursor, BlockLease, BlockReader};
use crate::virtual_file::{self, VirtualFile};
use bytes::BytesMut;
use camino::Utf8PathBuf;
use pageserver_api::shard::TenantShardId;
use std::cmp::min;
@@ -27,10 +26,7 @@ pub struct EphemeralFile {
/// An ephemeral file is append-only.
/// We keep the last page, which can still be modified, in [`Self::mutable_tail`].
/// The other pages, which can no longer be modified, are accessed through the page cache.
///
/// None <=> IO is ongoing.
/// Size is fixed to PAGE_SZ at creation time and must not be changed.
mutable_tail: Option<BytesMut>,
mutable_tail: [u8; PAGE_SZ],
}
impl EphemeralFile {
@@ -64,7 +60,7 @@ impl EphemeralFile {
_timeline_id: timeline_id,
file,
len: 0,
mutable_tail: Some(BytesMut::zeroed(PAGE_SZ)),
mutable_tail: [0u8; PAGE_SZ],
})
}
@@ -107,13 +103,7 @@ impl EphemeralFile {
};
} else {
debug_assert_eq!(blknum as u64, self.len / PAGE_SZ as u64);
Ok(BlockLease::EphemeralFileMutableTail(
self.mutable_tail
.as_deref()
.expect("we're not doing IO, it must be Some()")
.try_into()
.expect("we ensure that it's always PAGE_SZ"),
))
Ok(BlockLease::EphemeralFileMutableTail(&self.mutable_tail))
}
}
@@ -145,27 +135,21 @@ impl EphemeralFile {
) -> Result<(), io::Error> {
let mut src_remaining = src;
while !src_remaining.is_empty() {
let dst_remaining = &mut self
.ephemeral_file
.mutable_tail
.as_deref_mut()
.expect("IO is not yet ongoing")[self.off..];
let dst_remaining = &mut self.ephemeral_file.mutable_tail[self.off..];
let n = min(dst_remaining.len(), src_remaining.len());
dst_remaining[..n].copy_from_slice(&src_remaining[..n]);
self.off += n;
src_remaining = &src_remaining[n..];
if self.off == PAGE_SZ {
let mutable_tail = std::mem::take(&mut self.ephemeral_file.mutable_tail)
.expect("IO is not yet ongoing");
let (mutable_tail, res) = self
match self
.ephemeral_file
.file
.write_all_at(mutable_tail, self.blknum as u64 * PAGE_SZ as u64)
.await;
// TODO: If we panic before we can put the mutable_tail back, subsequent calls will fail.
// I.e., the IO isn't retryable if we panic.
self.ephemeral_file.mutable_tail = Some(mutable_tail);
match res {
.write_all_at(
&self.ephemeral_file.mutable_tail,
self.blknum as u64 * PAGE_SZ as u64,
)
.await
{
Ok(_) => {
// Pre-warm the page cache with what we just wrote.
// This isn't necessary for coherency/correctness, but it's how we've always done it.
@@ -185,12 +169,7 @@ impl EphemeralFile {
Ok(page_cache::ReadBufResult::NotFound(mut write_guard)) => {
let buf: &mut [u8] = write_guard.deref_mut();
debug_assert_eq!(buf.len(), PAGE_SZ);
buf.copy_from_slice(
self.ephemeral_file
.mutable_tail
.as_deref()
.expect("IO is not ongoing"),
);
buf.copy_from_slice(&self.ephemeral_file.mutable_tail);
let _ = write_guard.mark_valid();
// pre-warm successful
}
@@ -202,11 +181,7 @@ impl EphemeralFile {
// Zero the buffer for re-use.
// Zeroing is critical for correcntess because the write_blob code below
// and similarly read_blk expect zeroed pages.
self.ephemeral_file
.mutable_tail
.as_deref_mut()
.expect("IO is not ongoing")
.fill(0);
self.ephemeral_file.mutable_tail.fill(0);
// This block is done, move to next one.
self.blknum += 1;
self.off = 0;

View File

@@ -279,7 +279,7 @@ pub async fn save_metadata(
let path = conf.metadata_path(tenant_shard_id, timeline_id);
let temp_path = path_with_suffix_extension(&path, TEMP_FILE_SUFFIX);
let metadata_bytes = data.to_bytes().context("serialize metadata")?;
VirtualFile::crashsafe_overwrite(path, temp_path, metadata_bytes)
VirtualFile::crashsafe_overwrite(&path, &temp_path, &metadata_bytes)
.await
.context("write metadata")?;
Ok(())

View File

@@ -1700,6 +1700,23 @@ impl RemoteTimelineClient {
}
}
}
pub(crate) fn get_layers_metadata(
&self,
layers: Vec<LayerFileName>,
) -> anyhow::Result<Vec<Option<LayerFileMetadata>>> {
let q = self.upload_queue.lock().unwrap();
let q = match &*q {
UploadQueue::Stopped(_) | UploadQueue::Uninitialized => {
anyhow::bail!("queue is in state {}", q.as_str())
}
UploadQueue::Initialized(inner) => inner,
};
let decorated = layers.into_iter().map(|l| q.latest_files.get(&l).cloned());
Ok(decorated.collect())
}
}
pub fn remote_timelines_path(tenant_shard_id: &TenantShardId) -> RemotePath {

View File

@@ -484,9 +484,14 @@ impl<'a> TenantDownloader<'a> {
let temp_path = path_with_suffix_extension(&heatmap_path, TEMP_FILE_SUFFIX);
let context_msg = format!("write tenant {tenant_shard_id} heatmap to {heatmap_path}");
let heatmap_path_bg = heatmap_path.clone();
VirtualFile::crashsafe_overwrite(heatmap_path_bg, temp_path, heatmap_bytes)
.await
.maybe_fatal_err(&context_msg)?;
tokio::task::spawn_blocking(move || {
tokio::runtime::Handle::current().block_on(async move {
VirtualFile::crashsafe_overwrite(&heatmap_path_bg, &temp_path, &heatmap_bytes).await
})
})
.await
.expect("Blocking task is never aborted")
.maybe_fatal_err(&context_msg)?;
tracing::debug!("Wrote local heatmap to {}", heatmap_path);

View File

@@ -257,12 +257,6 @@ impl LayerAccessStats {
ret
}
/// Get the latest access timestamp, falling back to latest residence event, further falling
/// back to `SystemTime::now` for a usable timestamp for eviction.
pub(crate) fn latest_activity_or_now(&self) -> SystemTime {
self.latest_activity().unwrap_or_else(SystemTime::now)
}
/// Get the latest access timestamp, falling back to latest residence event.
///
/// This function can only return `None` if there has not yet been a call to the
@@ -277,7 +271,7 @@ impl LayerAccessStats {
/// that that type can only be produced by inserting into the layer map.
///
/// [`record_residence_event`]: Self::record_residence_event
fn latest_activity(&self) -> Option<SystemTime> {
pub(crate) fn latest_activity(&self) -> Option<SystemTime> {
let locked = self.0.lock().unwrap();
let inner = &locked.for_eviction_policy;
match inner.last_accesses.recent() {

View File

@@ -416,31 +416,27 @@ impl DeltaLayerWriterInner {
/// The values must be appended in key, lsn order.
///
async fn put_value(&mut self, key: Key, lsn: Lsn, val: Value) -> anyhow::Result<()> {
let (_, res) = self
.put_value_bytes(key, lsn, Value::ser(&val)?, val.will_init())
.await;
res
self.put_value_bytes(key, lsn, &Value::ser(&val)?, val.will_init())
.await
}
async fn put_value_bytes(
&mut self,
key: Key,
lsn: Lsn,
val: Vec<u8>,
val: &[u8],
will_init: bool,
) -> (Vec<u8>, anyhow::Result<()>) {
) -> anyhow::Result<()> {
assert!(self.lsn_range.start <= lsn);
let (val, res) = self.blob_writer.write_blob(val).await;
let off = match res {
Ok(off) => off,
Err(e) => return (val, Err(anyhow::anyhow!(e))),
};
let off = self.blob_writer.write_blob(val).await?;
let blob_ref = BlobRef::new(off, will_init);
let delta_key = DeltaKey::from_key_lsn(&key, lsn);
let res = self.tree.append(&delta_key.0, blob_ref.0);
(val, res.map_err(|e| anyhow::anyhow!(e)))
self.tree.append(&delta_key.0, blob_ref.0)?;
Ok(())
}
fn size(&self) -> u64 {
@@ -461,8 +457,7 @@ impl DeltaLayerWriterInner {
file.seek(SeekFrom::Start(index_start_blk as u64 * PAGE_SZ as u64))
.await?;
for buf in block_buf.blocks {
let (_buf, res) = file.write_all(buf).await;
res?;
file.write_all(buf.as_ref()).await?;
}
assert!(self.lsn_range.start < self.lsn_range.end);
// Fill in the summary on blk 0
@@ -477,12 +472,17 @@ impl DeltaLayerWriterInner {
index_root_blk,
};
let mut buf = Vec::with_capacity(PAGE_SZ);
// TODO: could use smallvec here but it's a pain with Slice<T>
let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new();
Summary::ser_into(&summary, &mut buf)?;
if buf.spilled() {
// This is bad as we only have one free block for the summary
warn!(
"Used more than one page size for summary buffer: {}",
buf.len()
);
}
file.seek(SeekFrom::Start(0)).await?;
let (_buf, res) = file.write_all(buf).await;
res?;
file.write_all(&buf).await?;
let metadata = file
.metadata()
@@ -587,9 +587,9 @@ impl DeltaLayerWriter {
&mut self,
key: Key,
lsn: Lsn,
val: Vec<u8>,
val: &[u8],
will_init: bool,
) -> (Vec<u8>, anyhow::Result<()>) {
) -> anyhow::Result<()> {
self.inner
.as_mut()
.unwrap()
@@ -675,12 +675,18 @@ impl DeltaLayer {
let new_summary = rewrite(actual_summary);
let mut buf = Vec::with_capacity(PAGE_SZ);
// TODO: could use smallvec here, but it's a pain with Slice<T>
let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new();
Summary::ser_into(&new_summary, &mut buf).context("serialize")?;
if buf.spilled() {
// The code in DeltaLayerWriterInner just warn!()s for this.
// It should probably error out as well.
return Err(RewriteSummaryError::Other(anyhow::anyhow!(
"Used more than one page size for summary buffer: {}",
buf.len()
)));
}
file.seek(SeekFrom::Start(0)).await?;
let (_buf, res) = file.write_all(buf).await;
res?;
file.write_all(&buf).await?;
Ok(())
}
}

View File

@@ -341,12 +341,18 @@ impl ImageLayer {
let new_summary = rewrite(actual_summary);
let mut buf = Vec::with_capacity(PAGE_SZ);
// TODO: could use smallvec here but it's a pain with Slice<T>
let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new();
Summary::ser_into(&new_summary, &mut buf).context("serialize")?;
if buf.spilled() {
// The code in ImageLayerWriterInner just warn!()s for this.
// It should probably error out as well.
return Err(RewriteSummaryError::Other(anyhow::anyhow!(
"Used more than one page size for summary buffer: {}",
buf.len()
)));
}
file.seek(SeekFrom::Start(0)).await?;
let (_buf, res) = file.write_all(buf).await;
res?;
file.write_all(&buf).await?;
Ok(())
}
}
@@ -522,11 +528,9 @@ impl ImageLayerWriterInner {
///
/// The page versions must be appended in blknum order.
///
async fn put_image(&mut self, key: Key, img: Bytes) -> anyhow::Result<()> {
async fn put_image(&mut self, key: Key, img: &[u8]) -> anyhow::Result<()> {
ensure!(self.key_range.contains(&key));
let (_img, res) = self.blob_writer.write_blob(img).await;
// TODO: re-use the buffer for `img` further upstack
let off = res?;
let off = self.blob_writer.write_blob(img).await?;
let mut keybuf: [u8; KEY_SIZE] = [0u8; KEY_SIZE];
key.write_to_byte_slice(&mut keybuf);
@@ -549,8 +553,7 @@ impl ImageLayerWriterInner {
.await?;
let (index_root_blk, block_buf) = self.tree.finish()?;
for buf in block_buf.blocks {
let (_buf, res) = file.write_all(buf).await;
res?;
file.write_all(buf.as_ref()).await?;
}
// Fill in the summary on blk 0
@@ -565,12 +568,17 @@ impl ImageLayerWriterInner {
index_root_blk,
};
let mut buf = Vec::with_capacity(PAGE_SZ);
// TODO: could use smallvec here but it's a pain with Slice<T>
let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new();
Summary::ser_into(&summary, &mut buf)?;
if buf.spilled() {
// This is bad as we only have one free block for the summary
warn!(
"Used more than one page size for summary buffer: {}",
buf.len()
);
}
file.seek(SeekFrom::Start(0)).await?;
let (_buf, res) = file.write_all(buf).await;
res?;
file.write_all(&buf).await?;
let metadata = file
.metadata()
@@ -651,7 +659,7 @@ impl ImageLayerWriter {
///
/// The page versions must be appended in blknum order.
///
pub async fn put_image(&mut self, key: Key, img: Bytes) -> anyhow::Result<()> {
pub async fn put_image(&mut self, key: Key, img: &[u8]) -> anyhow::Result<()> {
self.inner.as_mut().unwrap().put_image(key, img).await
}

View File

@@ -383,11 +383,9 @@ impl InMemoryLayer {
for (lsn, pos) in vec_map.as_slice() {
cursor.read_blob_into_buf(*pos, &mut buf, &ctx).await?;
let will_init = Value::des(&buf)?.will_init();
let res;
(buf, res) = delta_layer_writer
.put_value_bytes(key, *lsn, buf, will_init)
.await;
res?;
delta_layer_writer
.put_value_bytes(key, *lsn, &buf, will_init)
.await?;
}
}

View File

@@ -1413,6 +1413,10 @@ impl ResidentLayer {
&self.owner.0.path
}
pub(crate) fn access_stats(&self) -> &LayerAccessStats {
self.owner.access_stats()
}
pub(crate) fn metadata(&self) -> LayerFileMetadata {
self.owner.metadata()
}

View File

@@ -12,9 +12,7 @@ use bytes::Bytes;
use camino::{Utf8Path, Utf8PathBuf};
use enumset::EnumSet;
use fail::fail_point;
use futures::stream::StreamExt;
use itertools::Itertools;
use once_cell::sync::Lazy;
use pageserver_api::{
keyspace::{key_range_size, KeySpaceAccum},
models::{
@@ -35,22 +33,17 @@ use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::sync::gate::Gate;
use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet};
use std::ops::{Deref, Range};
use std::pin::pin;
use std::sync::atomic::Ordering as AtomicOrdering;
use std::sync::{Arc, Mutex, RwLock, Weak};
use std::time::{Duration, Instant, SystemTime};
use std::{
array,
collections::{BTreeMap, BinaryHeap, HashMap, HashSet},
sync::atomic::AtomicU64,
};
use std::{
cmp::{max, min, Ordering},
ops::ControlFlow,
};
use crate::pgdatadir_mapping::DirectoryKind;
use crate::tenant::timeline::logical_size::CurrentLogicalSize;
use crate::tenant::{
layer_map::{LayerMap, SearchResult},
@@ -112,7 +105,7 @@ use self::logical_size::LogicalSize;
use self::walreceiver::{WalReceiver, WalReceiverConf};
use super::config::TenantConf;
use super::remote_timeline_client::index::IndexPart;
use super::remote_timeline_client::index::{IndexLayerMetadata, IndexPart};
use super::remote_timeline_client::RemoteTimelineClient;
use super::secondary::heatmap::{HeatMapLayer, HeatMapTimeline};
use super::{debug_assert_current_span_has_tenant_and_timeline_id, AttachedTenantConf};
@@ -264,8 +257,6 @@ pub struct Timeline {
// in `crate::page_service` writes these metrics.
pub(crate) query_metrics: crate::metrics::SmgrQueryTimePerTimeline,
directory_metrics: [AtomicU64; DirectoryKind::KINDS_NUM],
/// Ensures layers aren't frozen by checkpointer between
/// [`Timeline::get_layer_for_write`] and layer reads.
/// Locked automatically by [`TimelineWriter`] and checkpointer.
@@ -798,10 +789,6 @@ impl Timeline {
self.metrics.resident_physical_size_get()
}
pub(crate) fn get_directory_metrics(&self) -> [u64; DirectoryKind::KINDS_NUM] {
array::from_fn(|idx| self.directory_metrics[idx].load(AtomicOrdering::Relaxed))
}
///
/// Wait until WAL has been received and processed up to this LSN.
///
@@ -1471,7 +1458,7 @@ impl Timeline {
generation,
shard_identity,
pg_version,
layers: Default::default(),
layers: Arc::new(tokio::sync::RwLock::new(LayerManager::create())),
wanted_image_layers: Mutex::new(None),
walredo_mgr,
@@ -1508,8 +1495,6 @@ impl Timeline {
&timeline_id,
),
directory_metrics: array::from_fn(|_| AtomicU64::new(0)),
flush_loop_state: Mutex::new(FlushLoopState::NotStarted),
layer_flush_start_tx,
@@ -2278,29 +2263,6 @@ impl Timeline {
}
}
pub(crate) fn update_directory_entries_count(&self, kind: DirectoryKind, count: u64) {
self.directory_metrics[kind.offset()].store(count, AtomicOrdering::Relaxed);
let aux_metric =
self.directory_metrics[DirectoryKind::AuxFiles.offset()].load(AtomicOrdering::Relaxed);
let sum_of_entries = self
.directory_metrics
.iter()
.map(|v| v.load(AtomicOrdering::Relaxed))
.sum();
// Set a high general threshold and a lower threshold for the auxiliary files,
// as we can have large numbers of relations in the db directory.
const SUM_THRESHOLD: u64 = 5000;
const AUX_THRESHOLD: u64 = 1000;
if sum_of_entries >= SUM_THRESHOLD || aux_metric >= AUX_THRESHOLD {
self.metrics
.directory_entries_count_gauge
.set(sum_of_entries);
} else if let Some(metric) = Lazy::get(&self.metrics.directory_entries_count_gauge) {
metric.set(sum_of_entries);
}
}
async fn find_layer(&self, layer_file_name: &str) -> Option<Layer> {
let guard = self.layers.read().await;
for historic_layer in guard.layer_map().iter_historic_layers() {
@@ -2321,28 +2283,45 @@ impl Timeline {
/// should treat this as a cue to simply skip doing any heatmap uploading
/// for this timeline.
pub(crate) async fn generate_heatmap(&self) -> Option<HeatMapTimeline> {
// no point in heatmaps without remote client
let _remote_client = self.remote_client.as_ref()?;
let eviction_info = self.get_local_layers_for_disk_usage_eviction().await;
if !self.is_active() {
return None;
}
let remote_client = match &self.remote_client {
Some(c) => c,
None => return None,
};
let guard = self.layers.read().await;
let layer_file_names = eviction_info
.resident_layers
.iter()
.map(|l| l.layer.get_name())
.collect::<Vec<_>>();
let resident = guard.resident_layers().map(|layer| {
let last_activity_ts = layer.access_stats().latest_activity_or_now();
let decorated = match remote_client.get_layers_metadata(layer_file_names) {
Ok(d) => d,
Err(_) => {
// Getting metadata only fails on Timeline in bad state.
return None;
}
};
HeatMapLayer::new(
layer.layer_desc().filename(),
layer.metadata().into(),
last_activity_ts,
)
let heatmap_layers = std::iter::zip(
eviction_info.resident_layers.into_iter(),
decorated.into_iter(),
)
.filter_map(|(layer, remote_info)| {
remote_info.map(|remote_info| {
HeatMapLayer::new(
layer.layer.get_name(),
IndexLayerMetadata::from(remote_info),
layer.last_activity_ts,
)
})
});
let layers = resident.collect().await;
Some(HeatMapTimeline::new(self.timeline_id, layers))
Some(HeatMapTimeline::new(
self.timeline_id,
heatmap_layers.collect(),
))
}
}
@@ -3349,7 +3328,7 @@ impl Timeline {
}
};
image_layer_writer.put_image(img_key, img).await?;
image_layer_writer.put_image(img_key, &img).await?;
}
}
@@ -4683,24 +4662,41 @@ impl Timeline {
/// Returns non-remote layers for eviction.
pub(crate) async fn get_local_layers_for_disk_usage_eviction(&self) -> DiskUsageEvictionInfo {
let guard = self.layers.read().await;
let layers = guard.layer_map();
let mut max_layer_size: Option<u64> = None;
let mut resident_layers = Vec::new();
let resident_layers = guard
.resident_layers()
.map(|layer| {
let file_size = layer.layer_desc().file_size;
max_layer_size = max_layer_size.map_or(Some(file_size), |m| Some(m.max(file_size)));
for l in layers.iter_historic_layers() {
let file_size = l.file_size();
max_layer_size = max_layer_size.map_or(Some(file_size), |m| Some(m.max(file_size)));
let last_activity_ts = layer.access_stats().latest_activity_or_now();
let l = guard.get_from_desc(&l);
EvictionCandidate {
layer: layer.into(),
last_activity_ts,
relative_last_activity: finite_f32::FiniteF32::ZERO,
let l = match l.keep_resident().await {
Ok(Some(l)) => l,
Ok(None) => continue,
Err(e) => {
// these should not happen, but we cannot make them statically impossible right
// now.
tracing::warn!(layer=%l, "failed to keep the layer resident: {e:#}");
continue;
}
})
.collect()
.await;
};
let last_activity_ts = l.access_stats().latest_activity().unwrap_or_else(|| {
// We only use this fallback if there's an implementation error.
// `latest_activity` already does rate-limited warn!() log.
debug!(layer=%l, "last_activity returns None, using SystemTime::now");
SystemTime::now()
});
resident_layers.push(EvictionCandidate {
layer: l.drop_eviction_guard().into(),
last_activity_ts,
relative_last_activity: finite_f32::FiniteF32::ZERO,
});
}
DiskUsageEvictionInfo {
max_layer_size,

View File

@@ -6,7 +6,7 @@ use std::{
use anyhow::Context;
use pageserver_api::{models::TimelineState, shard::TenantShardId};
use tokio::sync::OwnedMutexGuard;
use tracing::{debug, error, info, instrument, warn, Instrument};
use tracing::{debug, error, info, instrument, warn, Instrument, Span};
use utils::{crashsafe, fs_ext, id::TimelineId};
use crate::{
@@ -541,7 +541,12 @@ impl DeleteTimelineFlow {
};
Ok(())
}
.instrument(tracing::info_span!(parent: None, "delete_timeline", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),timeline_id=%timeline_id)),
.instrument({
let span =
tracing::info_span!(parent: None, "delete_timeline", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),timeline_id=%timeline_id);
span.follows_from(Span::current());
span
}),
);
}

View File

@@ -239,7 +239,12 @@ impl Timeline {
}
};
let last_activity_ts = hist_layer.access_stats().latest_activity_or_now();
let last_activity_ts = hist_layer.access_stats().latest_activity().unwrap_or_else(|| {
// We only use this fallback if there's an implementation error.
// `latest_activity` already does rate-limited warn!() log.
debug!(layer=%hist_layer, "last_activity returns None, using SystemTime::now");
SystemTime::now()
});
let no_activity_for = match now.duration_since(last_activity_ts) {
Ok(d) => d,

View File

@@ -1,5 +1,4 @@
use anyhow::{bail, ensure, Context, Result};
use futures::StreamExt;
use pageserver_api::shard::TenantShardId;
use std::{collections::HashMap, sync::Arc};
use tracing::trace;
@@ -21,13 +20,19 @@ use crate::{
};
/// Provides semantic APIs to manipulate the layer map.
#[derive(Default)]
pub(crate) struct LayerManager {
layer_map: LayerMap,
layer_fmgr: LayerFileManager<Layer>,
}
impl LayerManager {
pub(crate) fn create() -> Self {
Self {
layer_map: LayerMap::default(),
layer_fmgr: LayerFileManager::new(),
}
}
pub(crate) fn get_from_desc(&self, desc: &PersistentLayerDesc) -> Layer {
self.layer_fmgr.get_from_desc(desc)
}
@@ -241,32 +246,6 @@ impl LayerManager {
layer.delete_on_drop();
}
pub(crate) fn resident_layers(&self) -> impl futures::stream::Stream<Item = Layer> + '_ {
// for small layer maps, we most likely have all resident, but for larger more are likely
// to be evicted assuming lots of layers correlated with longer lifespan.
let layers = self
.layer_map()
.iter_historic_layers()
.map(|desc| self.get_from_desc(&desc));
let layers = futures::stream::iter(layers);
layers.filter_map(|layer| async move {
// TODO(#6028): this query does not really need to see the ResidentLayer
match layer.keep_resident().await {
Ok(Some(layer)) => Some(layer.drop_eviction_guard()),
Ok(None) => None,
Err(e) => {
// these should not happen, but we cannot make them statically impossible right
// now.
tracing::warn!(%layer, "failed to keep the layer resident: {e:#}");
None
}
}
})
}
pub(crate) fn contains(&self, layer: &Layer) -> bool {
self.layer_fmgr.contains(layer)
}
@@ -274,12 +253,6 @@ impl LayerManager {
pub(crate) struct LayerFileManager<T>(HashMap<PersistentLayerKey, T>);
impl<T> Default for LayerFileManager<T> {
fn default() -> Self {
Self(HashMap::default())
}
}
impl<T: AsLayerDesc + Clone> LayerFileManager<T> {
fn get_from_desc(&self, desc: &PersistentLayerDesc) -> T {
// The assumption for the `expect()` is that all code maintains the following invariant:
@@ -302,6 +275,10 @@ impl<T: AsLayerDesc + Clone> LayerFileManager<T> {
self.0.contains_key(&layer.layer_desc().key())
}
pub(crate) fn new() -> Self {
Self(HashMap::new())
}
pub(crate) fn remove(&mut self, layer: &T) {
let present = self.0.remove(&layer.layer_desc().key());
if present.is_none() && cfg!(debug_assertions) {

View File

@@ -19,13 +19,14 @@ use once_cell::sync::OnceCell;
use pageserver_api::shard::TenantShardId;
use std::fs::{self, File};
use std::io::{Error, ErrorKind, Seek, SeekFrom};
use tokio_epoll_uring::{BoundedBuf, IoBuf, IoBufMut, Slice};
use tokio_epoll_uring::IoBufMut;
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
use std::os::unix::fs::FileExt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tokio::time::Instant;
use utils::fs_ext;
pub use pageserver_api::models::virtual_file as api;
pub(crate) mod io_engine;
@@ -403,34 +404,46 @@ impl VirtualFile {
Ok(vfile)
}
/// Async version of [`::utils::crashsafe::overwrite`].
/// Writes a file to the specified `final_path` in a crash safe fasion
///
/// # NB:
///
/// Doesn't actually use the [`VirtualFile`] file descriptor cache, but,
/// it did at an earlier time.
/// And it will use this module's [`io_engine`] in the near future, so, leaving it here.
pub async fn crashsafe_overwrite<B: BoundedBuf<Buf = Buf> + Send, Buf: IoBuf + Send>(
final_path: Utf8PathBuf,
tmp_path: Utf8PathBuf,
content: B,
/// The file is first written to the specified tmp_path, and in a second
/// step, the tmp path is renamed to the final path. As renames are
/// atomic, a crash during the write operation will never leave behind a
/// partially written file.
pub async fn crashsafe_overwrite(
final_path: &Utf8Path,
tmp_path: &Utf8Path,
content: &[u8],
) -> std::io::Result<()> {
// TODO: use tokio_epoll_uring if configured as `io_engine`.
// See https://github.com/neondatabase/neon/issues/6663
tokio::task::spawn_blocking(move || {
let slice_storage;
let content_len = content.bytes_init();
let content = if content.bytes_init() > 0 {
slice_storage = Some(content.slice(0..content_len));
slice_storage.as_deref().expect("just set it to Some()")
} else {
&[]
};
utils::crashsafe::overwrite(&final_path, &tmp_path, content)
})
.await
.expect("blocking task is never aborted")
let Some(final_path_parent) = final_path.parent() else {
return Err(std::io::Error::from_raw_os_error(
nix::errno::Errno::EINVAL as i32,
));
};
std::fs::remove_file(tmp_path).or_else(fs_ext::ignore_not_found)?;
let mut file = Self::open_with_options(
tmp_path,
OpenOptions::new()
.write(true)
// Use `create_new` so that, if we race with ourselves or something else,
// we bail out instead of causing damage.
.create_new(true),
)
.await?;
file.write_all(content).await?;
file.sync_all().await?;
drop(file); // before the rename, that's important!
// renames are atomic
std::fs::rename(tmp_path, final_path)?;
// Only open final path parent dirfd now, so that this operation only
// ever holds one VirtualFile fd at a time. That's important because
// the current `find_victim_slot` impl might pick the same slot for both
// VirtualFile., and it eventually does a blocking write lock instead of
// try_lock.
let final_parent_dirfd =
Self::open_with_options(final_path_parent, OpenOptions::new().read(true)).await?;
final_parent_dirfd.sync_all().await?;
Ok(())
}
/// Call File::sync_all() on the underlying File.
@@ -568,69 +581,43 @@ impl VirtualFile {
}
// Copied from https://doc.rust-lang.org/1.72.0/src/std/os/unix/fs.rs.html#219-235
pub async fn write_all_at<B: BoundedBuf>(
&self,
buf: B,
mut offset: u64,
) -> (B::Buf, Result<(), Error>) {
let buf_len = buf.bytes_init();
if buf_len == 0 {
return (Slice::into_inner(buf.slice_full()), Ok(()));
}
let mut buf = buf.slice(0..buf_len);
pub async fn write_all_at(&self, mut buf: &[u8], mut offset: u64) -> Result<(), Error> {
while !buf.is_empty() {
// TODO: push `buf` further down
match self.write_at(&buf, offset).await {
match self.write_at(buf, offset).await {
Ok(0) => {
return (
Slice::into_inner(buf),
Err(Error::new(
std::io::ErrorKind::WriteZero,
"failed to write whole buffer",
)),
);
return Err(Error::new(
std::io::ErrorKind::WriteZero,
"failed to write whole buffer",
));
}
Ok(n) => {
buf = buf.slice(n..);
buf = &buf[n..];
offset += n as u64;
}
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
Err(e) => return (Slice::into_inner(buf), Err(e)),
Err(e) => return Err(e),
}
}
(Slice::into_inner(buf), Ok(()))
Ok(())
}
/// Writes `buf.slice(0..buf.bytes_init())`.
/// Returns the IoBuf that is underlying the BoundedBuf `buf`.
/// I.e., the returned value's `bytes_init()` method returns something different than the `bytes_init()` that was passed in.
/// It's quite brittle and easy to mis-use, so, we return the size in the Ok() variant.
pub async fn write_all<B: BoundedBuf>(&mut self, buf: B) -> (B::Buf, Result<usize, Error>) {
let nbytes = buf.bytes_init();
if nbytes == 0 {
return (Slice::into_inner(buf.slice_full()), Ok(0));
}
let mut buf = buf.slice(0..nbytes);
pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), Error> {
while !buf.is_empty() {
// TODO: push `Slice` further down
match self.write(&buf).await {
match self.write(buf).await {
Ok(0) => {
return (
Slice::into_inner(buf),
Err(Error::new(
std::io::ErrorKind::WriteZero,
"failed to write whole buffer",
)),
);
return Err(Error::new(
std::io::ErrorKind::WriteZero,
"failed to write whole buffer",
));
}
Ok(n) => {
buf = buf.slice(n..);
buf = &buf[n..];
}
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
Err(e) => return (Slice::into_inner(buf), Err(e)),
Err(e) => return Err(e),
}
}
(Slice::into_inner(buf), Ok(nbytes))
Ok(())
}
async fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
@@ -689,6 +676,7 @@ where
F: FnMut(tokio_epoll_uring::Slice<B>, u64) -> Fut,
Fut: std::future::Future<Output = (tokio_epoll_uring::Slice<B>, std::io::Result<usize>)>,
{
use tokio_epoll_uring::BoundedBuf;
let mut buf: tokio_epoll_uring::Slice<B> = buf.slice_full(); // includes all the uninitialized memory
while buf.bytes_total() != 0 {
let res;
@@ -1063,19 +1051,10 @@ mod tests {
MaybeVirtualFile::File(file) => file.read_exact_at(&mut buf, offset).map(|()| buf),
}
}
async fn write_all_at<B: BoundedBuf>(&self, buf: B, offset: u64) -> Result<(), Error> {
async fn write_all_at(&self, buf: &[u8], offset: u64) -> Result<(), Error> {
match self {
MaybeVirtualFile::VirtualFile(file) => {
let (_buf, res) = file.write_all_at(buf, offset).await;
res
}
MaybeVirtualFile::File(file) => {
let buf_len = buf.bytes_init();
if buf_len == 0 {
return Ok(());
}
file.write_all_at(&buf.slice(0..buf_len), offset)
}
MaybeVirtualFile::VirtualFile(file) => file.write_all_at(buf, offset).await,
MaybeVirtualFile::File(file) => file.write_all_at(buf, offset),
}
}
async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Error> {
@@ -1084,19 +1063,10 @@ mod tests {
MaybeVirtualFile::File(file) => file.seek(pos),
}
}
async fn write_all<B: BoundedBuf>(&mut self, buf: B) -> Result<(), Error> {
async fn write_all(&mut self, buf: &[u8]) -> Result<(), Error> {
match self {
MaybeVirtualFile::VirtualFile(file) => {
let (_buf, res) = file.write_all(buf).await;
res.map(|_| ())
}
MaybeVirtualFile::File(file) => {
let buf_len = buf.bytes_init();
if buf_len == 0 {
return Ok(());
}
file.write_all(&buf.slice(0..buf_len))
}
MaybeVirtualFile::VirtualFile(file) => file.write_all(buf).await,
MaybeVirtualFile::File(file) => file.write_all(buf),
}
}
@@ -1171,7 +1141,7 @@ mod tests {
.to_owned(),
)
.await?;
file_a.write_all(b"foobar".to_vec()).await?;
file_a.write_all(b"foobar").await?;
// cannot read from a file opened in write-only mode
let _ = file_a.read_string().await.unwrap_err();
@@ -1180,7 +1150,7 @@ mod tests {
let mut file_a = openfunc(path_a, OpenOptions::new().read(true).to_owned()).await?;
// cannot write to a file opened in read-only mode
let _ = file_a.write_all(b"bar".to_vec()).await.unwrap_err();
let _ = file_a.write_all(b"bar").await.unwrap_err();
// Try simple read
assert_eq!("foobar", file_a.read_string().await?);
@@ -1222,8 +1192,8 @@ mod tests {
.to_owned(),
)
.await?;
file_b.write_all_at(b"BAR".to_vec(), 3).await?;
file_b.write_all_at(b"FOO".to_vec(), 0).await?;
file_b.write_all_at(b"BAR", 3).await?;
file_b.write_all_at(b"FOO", 0).await?;
assert_eq!(file_b.read_string_at(2, 3).await?, "OBA");
@@ -1323,7 +1293,7 @@ mod tests {
let path = testdir.join("myfile");
let tmp_path = testdir.join("myfile.tmp");
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo")
.await
.unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap());
@@ -1332,7 +1302,7 @@ mod tests {
assert!(!tmp_path.exists());
drop(file);
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec())
VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"bar")
.await
.unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap());
@@ -1354,7 +1324,7 @@ mod tests {
std::fs::write(&tmp_path, "some preexisting junk that should be removed").unwrap();
assert!(tmp_path.exists());
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo")
.await
.unwrap();

View File

@@ -346,7 +346,7 @@ impl WalIngest {
let info = decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK;
if info == pg_constants::XLOG_LOGICAL_MESSAGE {
let xlrec = crate::walrecord::XlLogicalMessage::decode(&mut buf);
let xlrec = XlLogicalMessage::decode(&mut buf);
let prefix = std::str::from_utf8(&buf[0..xlrec.prefix_size - 1])?;
let message = &buf[xlrec.prefix_size..xlrec.prefix_size + xlrec.message_size];
if prefix == "neon-test" {

View File

@@ -44,11 +44,6 @@ pub enum NeonWalRecord {
moff: MultiXactOffset,
members: Vec<MultiXactMember>,
},
/// Update the map of AUX files, either writing or dropping an entry
AuxFile {
file_path: String,
content: Option<Bytes>,
},
}
impl NeonWalRecord {

View File

@@ -22,7 +22,7 @@
mod process;
/// Code to apply [`NeonWalRecord`]s.
pub(crate) mod apply_neon;
mod apply_neon;
use crate::config::PageServerConf;
use crate::metrics::{

View File

@@ -1,8 +1,7 @@
use crate::pgdatadir_mapping::AuxFilesDirectory;
use crate::walrecord::NeonWalRecord;
use anyhow::Context;
use byteorder::{ByteOrder, LittleEndian};
use bytes::{BufMut, BytesMut};
use bytes::BytesMut;
use pageserver_api::key::{key_to_rel_block, key_to_slru_block, Key};
use pageserver_api::reltag::SlruKind;
use postgres_ffi::pg_constants;
@@ -13,7 +12,6 @@ use postgres_ffi::v14::nonrelfile_utils::{
};
use postgres_ffi::BLCKSZ;
use tracing::*;
use utils::bin_ser::BeSer;
/// Can this request be served by neon redo functions
/// or we need to pass it to wal-redo postgres process?
@@ -232,72 +230,6 @@ pub(crate) fn apply_in_neon(
LittleEndian::write_u32(&mut page[memberoff..memberoff + 4], member.xid);
}
}
NeonWalRecord::AuxFile { file_path, content } => {
let mut dir = AuxFilesDirectory::des(page)?;
dir.upsert(file_path.clone(), content.clone());
page.clear();
let mut writer = page.writer();
dir.ser_into(&mut writer)?;
}
}
Ok(())
}
#[cfg(test)]
mod test {
use bytes::Bytes;
use pageserver_api::key::AUX_FILES_KEY;
use super::*;
use std::collections::HashMap;
use crate::{pgdatadir_mapping::AuxFilesDirectory, walrecord::NeonWalRecord};
/// Test [`apply_in_neon`]'s handling of NeonWalRecord::AuxFile
#[test]
fn apply_aux_file_deltas() -> anyhow::Result<()> {
let base_dir = AuxFilesDirectory {
files: HashMap::from([
("two".to_string(), Bytes::from_static(b"content0")),
("three".to_string(), Bytes::from_static(b"contentX")),
]),
};
let base_image = AuxFilesDirectory::ser(&base_dir)?;
let deltas = vec![
// Insert
NeonWalRecord::AuxFile {
file_path: "one".to_string(),
content: Some(Bytes::from_static(b"content1")),
},
// Update
NeonWalRecord::AuxFile {
file_path: "two".to_string(),
content: Some(Bytes::from_static(b"content99")),
},
// Delete
NeonWalRecord::AuxFile {
file_path: "three".to_string(),
content: None,
},
];
let file_path = AUX_FILES_KEY;
let mut page = BytesMut::from_iter(base_image);
for record in deltas {
apply_in_neon(&record, file_path, &mut page)?;
}
let reconstructed = AuxFilesDirectory::des(&page)?;
let expect = HashMap::from([
("one".to_string(), Bytes::from_static(b"content1")),
("two".to_string(), Bytes::from_static(b"content99")),
]);
assert_eq!(reconstructed.files, expect);
Ok(())
}
}

View File

@@ -3079,6 +3079,14 @@ neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id)
XLogRecGetBlockTag(record, block_id, &rinfo, &forknum, &blkno);
#endif
/*
* Out of an abundance of caution, we always run redo on shared catalogs,
* regardless of whether the block is stored in shared buffers. See also
* this function's top comment.
*/
if (!OidIsValid(NInfoGetDbOid(rinfo)))
return false;
CopyNRelFileInfoToBufTag(tag, rinfo);
tag.forkNum = forknum;
tag.blockNum = blkno;
@@ -3092,28 +3100,17 @@ neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id)
*/
LWLockAcquire(partitionLock, LW_SHARED);
/*
* Out of an abundance of caution, we always run redo on shared catalogs,
* regardless of whether the block is stored in shared buffers. See also
* this function's top comment.
*/
if (!OidIsValid(NInfoGetDbOid(rinfo)))
{
no_redo_needed = false;
}
else
{
/* Try to find the relevant buffer */
buffer = BufTableLookup(&tag, hash);
/* Try to find the relevant buffer */
buffer = BufTableLookup(&tag, hash);
no_redo_needed = buffer < 0;
no_redo_needed = buffer < 0;
}
/* In both cases st lwlsn past this WAL record */
SetLastWrittenLSNForBlock(end_recptr, rinfo, forknum, blkno);
/*
* we don't have the buffer in memory, update lwLsn past this record, also
* evict page from file cache
* evict page fro file cache
*/
if (no_redo_needed)
lfc_evict(rinfo, forknum, blkno);

View File

@@ -688,7 +688,7 @@ RecvAcceptorGreeting(Safekeeper *sk)
if (!AsyncReadMessage(sk, (AcceptorProposerMessage *) &sk->greetResponse))
return;
wp_log(LOG, "received AcceptorGreeting from safekeeper %s:%s, term=" INT64_FORMAT, sk->host, sk->port, sk->greetResponse.term);
wp_log(LOG, "received AcceptorGreeting from safekeeper %s:%s", sk->host, sk->port);
/* Protocol is all good, move to voting. */
sk->state = SS_VOTING;
@@ -922,7 +922,6 @@ static void
DetermineEpochStartLsn(WalProposer *wp)
{
TermHistory *dth;
int n_ready = 0;
wp->propEpochStartLsn = InvalidXLogRecPtr;
wp->donorEpoch = 0;
@@ -933,8 +932,6 @@ DetermineEpochStartLsn(WalProposer *wp)
{
if (wp->safekeeper[i].state == SS_IDLE)
{
n_ready++;
if (GetEpoch(&wp->safekeeper[i]) > wp->donorEpoch ||
(GetEpoch(&wp->safekeeper[i]) == wp->donorEpoch &&
wp->safekeeper[i].voteResponse.flushLsn > wp->propEpochStartLsn))
@@ -961,16 +958,6 @@ DetermineEpochStartLsn(WalProposer *wp)
}
}
if (n_ready < wp->quorum)
{
/*
* This is a rare case that can be triggered if safekeeper has voted and disconnected.
* In this case, its state will not be SS_IDLE and its vote cannot be used, because
* we clean up `voteResponse` in `ShutdownConnection`.
*/
wp_log(FATAL, "missing majority of votes, collected %d, expected %d, got %d", wp->n_votes, wp->quorum, n_ready);
}
/*
* If propEpochStartLsn is 0, it means flushLsn is 0 everywhere, we are bootstrapping
* and nothing was committed yet. Start streaming then from the basebackup LSN.

View File

@@ -486,8 +486,6 @@ typedef struct walproposer_api
*
* On success, the data is placed in *buf. It is valid until the next call
* to this function.
*
* Returns PG_ASYNC_READ_FAIL on closed connection.
*/
PGAsyncReadResult (*conn_async_read) (Safekeeper *sk, char **buf, int *amount);
@@ -534,13 +532,6 @@ typedef struct walproposer_api
* Returns 0 if timeout is reached, 1 if some event happened. Updates
* events mask to indicate events and sets sk to the safekeeper which has
* an event.
*
* On timeout, events is set to WL_NO_EVENTS. On socket event, events is
* set to WL_SOCKET_READABLE and/or WL_SOCKET_WRITEABLE. When socket is
* closed, events is set to WL_SOCKET_READABLE.
*
* WL_SOCKET_WRITEABLE is usually set only when we need to flush the buffer.
* It can be returned only if caller asked for this event in the last *_event_set call.
*/
int (*wait_event_set) (WalProposer *wp, long timeout, Safekeeper **sk, uint32 *events);

View File

@@ -36,6 +36,9 @@ pub enum AuthErrorImpl {
#[error(transparent)]
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
#[error(transparent)]
WakeCompute(#[from] console::errors::WakeComputeError),
/// SASL protocol errors (includes [SCRAM](crate::scram)).
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
@@ -116,6 +119,7 @@ impl UserFacingError for AuthError {
match self.0.as_ref() {
Link(e) => e.to_string_client(),
GetAuthInfo(e) => e.to_string_client(),
WakeCompute(e) => e.to_string_client(),
Sasl(e) => e.to_string_client(),
AuthFailed(_) => self.to_string(),
BadAuthMethod(_) => self.to_string(),
@@ -135,6 +139,7 @@ impl ReportableError for AuthError {
match self.0.as_ref() {
Link(e) => e.get_error_kind(),
GetAuthInfo(e) => e.get_error_kind(),
WakeCompute(e) => e.get_error_kind(),
Sasl(e) => e.get_error_kind(),
AuthFailed(_) => crate::error::ErrorKind::User,
BadAuthMethod(_) => crate::error::ErrorKind::User,

View File

@@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange;
use crate::cache::Cached;
use crate::console::errors::GetAuthInfoError;
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
use crate::console::{AuthSecret, NodeInfo};
use crate::console::AuthSecret;
use crate::context::RequestMonitoring;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::wake_compute::wake_compute;
use crate::proxy::NeonOptions;
use crate::stream::Stream;
use crate::{
@@ -26,6 +26,7 @@ use crate::{
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use futures::TryFutureExt;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -55,11 +56,11 @@ impl<T> std::ops::Deref for MaybeOwned<'_, T> {
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum BackendType<'a, T, D> {
pub enum BackendType<'a, T> {
/// Cloud API (V2).
Console(MaybeOwned<'a, ConsoleBackend>, T),
/// Authentication via a web browser.
Link(MaybeOwned<'a, url::ApiUrl>, D),
Link(MaybeOwned<'a, url::ApiUrl>),
}
pub trait TestBackend: Send + Sync + 'static {
@@ -70,7 +71,7 @@ pub trait TestBackend: Send + Sync + 'static {
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
}
impl std::fmt::Display for BackendType<'_, (), ()> {
impl std::fmt::Display for BackendType<'_, ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use BackendType::*;
match self {
@@ -85,50 +86,51 @@ impl std::fmt::Display for BackendType<'_, (), ()> {
#[cfg(test)]
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
}
}
}
impl<T, D> BackendType<'_, T, D> {
impl<T> BackendType<'_, T> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
pub fn as_ref(&self) -> BackendType<'_, &T> {
use BackendType::*;
match self {
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
Link(c) => Link(MaybeOwned::Borrowed(c)),
}
}
}
impl<'a, T, D> BackendType<'a, T, D> {
impl<'a, T> BackendType<'a, T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> {
use BackendType::*;
match self {
Console(c, x) => Console(c, f(x)),
Link(c, x) => Link(c, x),
}
}
}
impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
use BackendType::*;
match self {
Console(c, x) => x.map(|x| Console(c, x)),
Link(c, x) => Ok(Link(c, x)),
Link(c) => Link(c),
}
}
}
pub struct ComputeCredentials {
impl<'a, T, E> BackendType<'a, Result<T, E>> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub fn transpose(self) -> Result<BackendType<'a, T>, E> {
use BackendType::*;
match self {
Console(c, x) => x.map(|x| Console(c, x)),
Link(c) => Ok(Link(c)),
}
}
}
pub struct ComputeCredentials<T> {
pub info: ComputeUserInfo,
pub keys: ComputeCredentialKeys,
pub keys: T,
}
#[derive(Debug, Clone)]
@@ -151,6 +153,7 @@ impl ComputeUserInfo {
}
pub enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
}
@@ -185,21 +188,19 @@ async fn auth_quirks(
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info, unauthenticated_password) = match user_info.try_into() {
Err(info) => {
let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
.await?;
ctx.set_endpoint_id(res.info.endpoint.clone());
tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));
let password = match res.keys {
ComputeCredentialKeys::Password(p) => p,
_ => unreachable!("password hack should return a password"),
};
(res.info, Some(password))
(res.info, Some(res.keys))
}
Ok(info) => (info, None),
};
@@ -253,7 +254,7 @@ async fn authenticate_with_secret(
unauthenticated_password: Option<Vec<u8>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
if let Some(password) = unauthenticated_password {
let auth_outcome = validate_password_and_exchange(&password, secret)?;
let keys = match auth_outcome {
@@ -275,22 +276,21 @@ async fn authenticate_with_secret(
// Perform cleartext auth if we're allowed to do that.
// Currently, we use it for websocket connections (latency).
if allow_cleartext {
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
return hacks::authenticate_cleartext(ctx, info, client, secret).await;
return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await;
}
// Finally, proceed with the main auth flow (SCRAM-based).
classic::authenticate(ctx, info, client, config, secret).await
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
}
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<EndpointId> {
use BackendType::*;
match self {
Console(_, user_info) => user_info.endpoint_id.clone(),
Link(_, _) => Some("link".into()),
Link(_) => Some("link".into()),
}
}
@@ -300,7 +300,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
match self {
Console(_, user_info) => &user_info.user,
Link(_, _) => "link",
Link(_) => "link",
}
}
@@ -312,7 +312,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
use BackendType::*;
let res = match self {
@@ -323,17 +323,33 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
"performing authentication using the console"
);
let credentials =
let compute_credentials =
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
BackendType::Console(api, credentials)
let mut num_retries = 0;
let mut node =
wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?;
ctx.set_project(node.aux.clone());
match compute_credentials.keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => node.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
};
(node, BackendType::Console(api, compute_credentials.info))
}
// NOTE: this auth backend doesn't use client credentials.
Link(url, _) => {
Link(url) => {
info!("performing link authentication");
let info = link::authenticate(ctx, &url, client).await?;
let node_info = link::authenticate(ctx, &url, client).await?;
BackendType::Link(url, info)
(
CachedNodeInfo::new_uncached(node_info),
BackendType::Link(url),
)
}
};
@@ -342,7 +358,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
}
}
impl BackendType<'_, ComputeUserInfo, &()> {
impl BackendType<'_, ComputeUserInfo> {
pub async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
@@ -350,7 +366,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
Link(_, _) => Ok(Cached::new_uncached(None)),
Link(_) => Ok(Cached::new_uncached(None)),
}
}
@@ -361,51 +377,21 @@ impl BackendType<'_, ComputeUserInfo, &()> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
async fn wake_compute(
/// When applicable, wake the compute node, gaining its connection info in the process.
/// The link auth flow doesn't support this, so we return [`None`] in that case.
pub async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, info) => Ok(Cached::new_uncached(info.clone())),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
BackendType::Link(_, _) => None,
}
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
BackendType::Link(_, _) => None,
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
Link(_) => Ok(None),
}
}
}

View File

@@ -4,7 +4,7 @@ use crate::{
compute,
config::AuthenticationConfig,
console::AuthSecret,
context::RequestMonitoring,
metrics::LatencyTimer,
sasl,
stream::{PqStream, Stream},
};
@@ -12,12 +12,12 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
pub(super) async fn authenticate(
ctx: &mut RequestMonitoring,
creds: ComputeUserInfo,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
@@ -27,11 +27,13 @@ pub(super) async fn authenticate(
}
AuthSecret::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret, &mut *ctx);
let scram = auth::Scram(&secret);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
async {
// pause the timer while we communicate with the client
let _paused = latency_timer.pause();
flow.begin(scram).await.map_err(|error| {
warn!(?error, "error sending scram acknowledgement");

View File

@@ -4,7 +4,7 @@ use super::{
use crate::{
auth::{self, AuthFlow},
console::AuthSecret,
context::RequestMonitoring,
metrics::LatencyTimer,
sasl,
stream::{self, Stream},
};
@@ -16,16 +16,15 @@ use tracing::{info, warn};
/// These properties are benefical for serverless JS workers, so we
/// use this mechanism for websocket connections.
pub async fn authenticate_cleartext(
ctx: &mut RequestMonitoring,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
latency_timer: &mut LatencyTimer,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
warn!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause();
let _paused = latency_timer.pause();
let auth_outcome = AuthFlow::new(client)
.begin(auth::CleartextPassword(secret))
@@ -48,15 +47,14 @@ pub async fn authenticate_cleartext(
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
/// and passwords are not yet validated (we don't know how to validate them!)
pub async fn password_hack_no_authentication(
ctx: &mut RequestMonitoring,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
) -> auth::Result<ComputeCredentials> {
latency_timer: &mut LatencyTimer,
) -> auth::Result<ComputeCredentials<Vec<u8>>> {
warn!("project not specified, resorting to the password hack auth flow");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause();
let _paused = latency_timer.pause();
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
@@ -73,6 +71,6 @@ pub async fn password_hack_no_authentication(
options: info.options,
endpoint: payload.endpoint,
},
keys: ComputeCredentialKeys::Password(payload.password),
keys: payload.password,
})
}

View File

@@ -61,8 +61,6 @@ pub(super) async fn authenticate(
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
ctx.set_auth_method(crate::context::AuthMethod::Web);
// registering waiter can fail if we get unlucky with rng.
// just try again.
let (psql_session_id, waiter) = loop {

View File

@@ -99,9 +99,6 @@ impl ComputeUserInfoMaybeEndpoint {
// record the values if we have them
ctx.set_application(params.get("application_name").map(SmolStr::from));
ctx.set_user(user.clone());
if let Some(dbname) = params.get("database") {
ctx.set_dbname(dbname.into());
}
// Project name might be passed via PG's command-line options.
let endpoint_option = params

View File

@@ -4,11 +4,9 @@ use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
use crate::{
config::TlsServerEndPoint,
console::AuthSecret,
context::RequestMonitoring,
sasl, scram,
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
@@ -25,7 +23,7 @@ pub trait AuthMethod {
pub struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
pub struct Scram<'a>(pub &'a scram::ServerSecret);
impl AuthMethod for Scram<'_> {
#[inline(always)]
@@ -140,11 +138,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret, ctx) = self.state;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause();
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
@@ -155,15 +148,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
return Err(super::AuthError::bad_auth_method(sasl.method));
}
match sasl.method {
SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => {
ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
}
_ => {}
}
info!("client chooses {}", sasl.method);
let secret = self.state.0;
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(
secret,

View File

@@ -1,8 +1,6 @@
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
@@ -14,7 +12,6 @@ use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::RateLimiterConfig;
use proxy::redis::notifications;
use proxy::redis::publisher::RedisPublisherClient;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;
@@ -25,7 +22,6 @@ use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::info;
@@ -95,6 +91,9 @@ struct ProxyCliArgs {
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// timeout for the control plane requests
#[clap(long, default_value = "70s", value_parser = humantime::parse_duration)]
cplane_timeout: tokio::time::Duration,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
@@ -133,9 +132,6 @@ struct ProxyCliArgs {
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
endpoint_rps_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
#[clap(long, default_value_t = 100)]
initial_limit: usize,
@@ -232,19 +228,6 @@ async fn main() -> anyhow::Result<()> {
let cancellation_token = CancellationToken::new();
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
let cancel_map = CancelMap::default();
let redis_publisher = match &args.redis_notifications {
Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
url,
args.region.clone(),
&config.redis_rps_limit,
)?))),
None => None,
};
let cancellation_handler = Arc::new(CancellationHandler::new(
cancel_map.clone(),
redis_publisher,
));
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
@@ -254,7 +237,6 @@ async fn main() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));
// TODO: rename the argument to something like serverless.
@@ -269,7 +251,6 @@ async fn main() -> anyhow::Result<()> {
serverless_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));
}
@@ -293,12 +274,7 @@ async fn main() -> anyhow::Result<()> {
let cache = api.caches.project_info.clone();
if let Some(url) = args.redis_notifications {
info!("Starting redis notifications listener ({url})");
maintenance_tasks.spawn(notifications::task_main(
url.to_owned(),
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
@@ -395,7 +371,10 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
tokio::spawn(locks.garbage_collect_worker(epoch));
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));
let endpoint = http::Endpoint::new(
url,
http::new_client(rate_limiter_config, args.cplane_timeout),
);
let api = console::provider::neon::Api::new(endpoint, caches, locks);
let api = console::provider::ConsoleBackend::Console(api);
@@ -410,7 +389,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
}
AuthBackend::Link => {
let url = args.uri.parse()?;
auth::BackendType::Link(MaybeOwned::Owned(url), ())
auth::BackendType::Link(MaybeOwned::Owned(url))
}
};
let http_config = HttpConfig {
@@ -430,8 +409,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
let mut redis_rps_limit = args.redis_rps_limit.clone();
RateBucketInfo::validate(&mut redis_rps_limit)?;
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
@@ -443,7 +420,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
require_client_ip: args.require_client_ip,
disable_ip_check_for_http: args.disable_ip_check_for_http,
endpoint_rps_limit,
redis_rps_limit,
handshake_timeout: args.handshake_timeout,
// TODO: add this argument
region: args.region.clone(),

View File

@@ -1,28 +1,16 @@
use async_trait::async_trait;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use uuid::Uuid;
use crate::{
error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
redis::publisher::RedisPublisherClient,
};
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
use crate::error::ReportableError;
/// Enables serving `CancelRequest`s.
///
/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler {
map: CancelMap,
redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
}
#[derive(Default)]
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
#[derive(Debug, Error)]
pub enum CancelError {
@@ -44,43 +32,15 @@ impl ReportableError for CancelError {
}
}
impl CancellationHandler {
pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
Self { map, redis_client }
}
impl CancelMap {
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
) -> Result<(), CancelError> {
let from = "from_client";
pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
if let Some(redis_client) = &self.redis_client {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
info!("publishing cancellation key to Redis");
match redis_client.lock().await.try_publish(key, session_id).await {
Ok(()) => {
info!("cancellation key successfuly published to Redis");
}
Err(e) => {
tracing::error!("failed to publish a message: {e}");
return Err(CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
)));
}
}
}
return Ok(());
};
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
}
@@ -97,7 +57,7 @@ impl CancellationHandler {
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.map.entry(key) {
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => continue,
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
@@ -109,46 +69,18 @@ impl CancellationHandler {
info!("registered new query cancellation key {key}");
Session {
key,
cancellation_handler: self,
cancel_map: self,
}
}
#[cfg(test)]
fn contains(&self, session: &Session) -> bool {
self.map.contains_key(&session.key)
self.0.contains_key(&session.key)
}
#[cfg(test)]
fn is_empty(&self) -> bool {
self.map.is_empty()
}
}
#[async_trait]
pub trait NotificationsCancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
}
#[async_trait]
impl NotificationsCancellationHandler for CancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
let from = "from_redis";
let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
match cancel_closure {
Some(cancel_closure) => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
cancel_closure.try_cancel_query().await
}
None => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
tracing::warn!("query cancellation key not found: {key}");
Ok(())
}
}
self.0.is_empty()
}
}
@@ -183,7 +115,7 @@ pub struct Session {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancellation_handler: Arc<CancellationHandler>,
cancel_map: Arc<CancelMap>,
}
impl Session {
@@ -191,9 +123,7 @@ impl Session {
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancellation_handler
.map
.insert(self.key, Some(cancel_closure));
self.cancel_map.0.insert(self.key, Some(cancel_closure));
self.key
}
@@ -201,7 +131,7 @@ impl Session {
impl Drop for Session {
fn drop(&mut self) {
self.cancellation_handler.map.remove(&self.key);
self.cancel_map.0.remove(&self.key);
info!("dropped query cancellation key {}", &self.key);
}
}
@@ -212,16 +142,13 @@ mod tests {
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler {
map: CancelMap::default(),
redis_client: None,
});
let cancel_map: Arc<CancelMap> = Default::default();
let session = cancellation_handler.clone().get_session();
assert!(cancellation_handler.contains(&session));
let session = cancel_map.clone().get_session();
assert!(cancel_map.contains(&session));
drop(session);
// Check that the session has been dropped.
assert!(cancellation_handler.is_empty());
assert!(cancel_map.is_empty());
Ok(())
}

View File

@@ -1,7 +1,7 @@
use crate::{
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::{errors::WakeComputeError, messages::MetricsAuxInfo},
console::errors::WakeComputeError,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_DB_CONNECTIONS_GAUGE,
@@ -93,7 +93,7 @@ impl ConnCfg {
}
/// Reuse password or auth keys from the other config.
pub fn reuse_password(&mut self, other: Self) {
pub fn reuse_password(&mut self, other: &Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
@@ -253,8 +253,6 @@ pub struct PostgresConnection {
pub params: std::collections::HashMap<String, String>,
/// Query cancellation token.
pub cancel_closure: CancelClosure,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
_guage: IntCounterPairGuard,
}
@@ -265,7 +263,6 @@ impl ConnCfg {
&self,
ctx: &mut RequestMonitoring,
allow_self_signed_compute: bool,
aux: MetricsAuxInfo,
timeout: Duration,
) -> Result<PostgresConnection, ConnectionError> {
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
@@ -300,7 +297,6 @@ impl ConnCfg {
stream,
params,
cancel_closure,
aux,
_guage: NUM_DB_CONNECTIONS_GAUGE
.with_label_values(&[ctx.protocol])
.guard(),

View File

@@ -13,7 +13,7 @@ use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::BackendType<'static, (), ()>,
pub auth_backend: auth::BackendType<'static, ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,
@@ -21,7 +21,6 @@ pub struct ProxyConfig {
pub require_client_ip: bool,
pub disable_ip_check_for_http: bool,
pub endpoint_rps_limit: Vec<RateBucketInfo>,
pub redis_rps_limit: Vec<RateBucketInfo>,
pub region: String,
pub handshake_timeout: Duration,
}

View File

@@ -4,10 +4,7 @@ pub mod neon;
use super::messages::MetricsAuxInfo;
use crate::{
auth::{
backend::{ComputeCredentialKeys, ComputeUserInfo},
IpPattern,
},
auth::{backend::ComputeUserInfo, IpPattern},
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
compute,
config::{CacheOptions, ProjectInfoCacheOptions},
@@ -264,34 +261,6 @@ pub struct NodeInfo {
pub allow_self_signed_compute: bool,
}
impl NodeInfo {
pub async fn connect(
&self,
ctx: &mut RequestMonitoring,
timeout: Duration,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config
.connect(
ctx,
self.allow_self_signed_compute,
self.aux.clone(),
timeout,
)
.await
}
pub fn reuse_settings(&mut self, other: Self) {
self.allow_self_signed_compute = other.allow_self_signed_compute;
self.config.reuse_password(other.config);
}
pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
match keys {
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
};
}
}
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;

View File

@@ -176,7 +176,9 @@ impl super::Api for Api {
_ctx: &mut RequestMonitoring,
_user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
self.do_wake_compute().map_ok(Cached::new_uncached).await
self.do_wake_compute()
.map_ok(CachedNodeInfo::new_uncached)
.await
}
}

View File

@@ -11,7 +11,7 @@ use crate::{
console::messages::MetricsAuxInfo,
error::ErrorKind,
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
BranchId, DbName, EndpointId, ProjectId, RoleName,
BranchId, EndpointId, ProjectId, RoleName,
};
pub mod parquet;
@@ -34,11 +34,9 @@ pub struct RequestMonitoring {
project: Option<ProjectId>,
branch: Option<BranchId>,
endpoint_id: Option<EndpointId>,
dbname: Option<DbName>,
user: Option<RoleName>,
application: Option<SmolStr>,
error_kind: Option<ErrorKind>,
pub(crate) auth_method: Option<AuthMethod>,
success: bool,
// extra
@@ -47,15 +45,6 @@ pub struct RequestMonitoring {
pub latency_timer: LatencyTimer,
}
#[derive(Clone, Debug)]
pub enum AuthMethod {
// aka link aka passwordless
Web,
ScramSha256,
ScramSha256Plus,
Cleartext,
}
impl RequestMonitoring {
pub fn new(
session_id: Uuid,
@@ -73,11 +62,9 @@ impl RequestMonitoring {
project: None,
branch: None,
endpoint_id: None,
dbname: None,
user: None,
application: None,
error_kind: None,
auth_method: None,
success: false,
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
@@ -119,18 +106,10 @@ impl RequestMonitoring {
self.application = app.or_else(|| self.application.clone());
}
pub fn set_dbname(&mut self, dbname: DbName) {
self.dbname = Some(dbname);
}
pub fn set_user(&mut self, user: RoleName) {
self.user = Some(user);
}
pub fn set_auth_method(&mut self, auth_method: AuthMethod) {
self.auth_method = Some(auth_method);
}
pub fn set_error_kind(&mut self, kind: ErrorKind) {
ERROR_BY_KIND
.with_label_values(&[kind.to_metric_label()])

View File

@@ -84,10 +84,8 @@ struct RequestData {
username: Option<String>,
application_name: Option<String>,
endpoint_id: Option<String>,
database: Option<String>,
project: Option<String>,
branch: Option<String>,
auth_method: Option<&'static str>,
error: Option<&'static str>,
/// Success is counted if we form a HTTP response with sql rows inside
/// Or if we make it to proxy_pass
@@ -106,15 +104,8 @@ impl From<RequestMonitoring> for RequestData {
username: value.user.as_deref().map(String::from),
application_name: value.application.as_deref().map(String::from),
endpoint_id: value.endpoint_id.as_deref().map(String::from),
database: value.dbname.as_deref().map(String::from),
project: value.project.as_deref().map(String::from),
branch: value.branch.as_deref().map(String::from),
auth_method: value.auth_method.as_ref().map(|x| match x {
super::AuthMethod::Web => "web",
super::AuthMethod::ScramSha256 => "scram_sha_256",
super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus",
super::AuthMethod::Cleartext => "cleartext",
}),
protocol: value.protocol,
region: value.region,
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
@@ -440,10 +431,8 @@ mod tests {
application_name: Some("test".to_owned()),
username: Some(hex::encode(rng.gen::<[u8; 4]>())),
endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
database: Some(hex::encode(rng.gen::<[u8; 16]>())),
project: Some(hex::encode(rng.gen::<[u8; 16]>())),
branch: Some(hex::encode(rng.gen::<[u8; 16]>())),
auth_method: None,
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
region: "us-east-1",
error: None,
@@ -516,15 +505,15 @@ mod tests {
assert_eq!(
file_stats,
[
(1313727, 3, 6000),
(1313720, 3, 6000),
(1313780, 3, 6000),
(1313737, 3, 6000),
(1313867, 3, 6000),
(1313709, 3, 6000),
(1313501, 3, 6000),
(1313737, 3, 6000),
(438118, 1, 2000)
(1087635, 3, 6000),
(1087288, 3, 6000),
(1087444, 3, 6000),
(1087572, 3, 6000),
(1087468, 3, 6000),
(1087500, 3, 6000),
(1087533, 3, 6000),
(1087566, 3, 6000),
(362671, 1, 2000)
],
);
@@ -554,11 +543,11 @@ mod tests {
assert_eq!(
file_stats,
[
(1219459, 5, 10000),
(1225609, 5, 10000),
(1227403, 5, 10000),
(1226765, 5, 10000),
(1218043, 5, 10000)
(1028637, 5, 10000),
(1031969, 5, 10000),
(1019900, 5, 10000),
(1020365, 5, 10000),
(1025010, 5, 10000)
],
);
@@ -590,11 +579,11 @@ mod tests {
assert_eq!(
file_stats,
[
(1205106, 5, 10000),
(1204837, 5, 10000),
(1205130, 5, 10000),
(1205118, 5, 10000),
(1205373, 5, 10000)
(1210770, 6, 12000),
(1211036, 6, 12000),
(1210990, 6, 12000),
(1210861, 6, 12000),
(202073, 1, 2000)
],
);
@@ -619,15 +608,15 @@ mod tests {
assert_eq!(
file_stats,
[
(1313727, 3, 6000),
(1313720, 3, 6000),
(1313780, 3, 6000),
(1313737, 3, 6000),
(1313867, 3, 6000),
(1313709, 3, 6000),
(1313501, 3, 6000),
(1313737, 3, 6000),
(438118, 1, 2000)
(1087635, 3, 6000),
(1087288, 3, 6000),
(1087444, 3, 6000),
(1087572, 3, 6000),
(1087468, 3, 6000),
(1087500, 3, 6000),
(1087533, 3, 6000),
(1087566, 3, 6000),
(362671, 1, 2000)
],
);
@@ -664,7 +653,7 @@ mod tests {
// files are smaller than the size threshold, but they took too long to fill so were flushed early
assert_eq!(
file_stats,
[(658383, 2, 3001), (658097, 2, 3000), (657893, 2, 2999)],
[(545264, 2, 3001), (545025, 2, 3000), (544857, 2, 2999)],
);
tmpdir.close().unwrap();

View File

@@ -29,7 +29,7 @@ pub trait UserFacingError: ReportableError {
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[derive(Copy, Clone, Debug)]
pub enum ErrorKind {
/// Wrong password, unknown endpoint, protocol violation, etc...
User,
@@ -90,13 +90,3 @@ impl ReportableError for tokio::time::error::Elapsed {
ErrorKind::RateLimit
}
}
impl ReportableError for tokio_postgres::error::Error {
fn get_error_kind(&self) -> ErrorKind {
if self.as_db_error().is_some() {
ErrorKind::Postgres
} else {
ErrorKind::Compute
}
}
}

View File

@@ -19,10 +19,14 @@ use reqwest_middleware::RequestBuilder;
/// This is the preferred way to create new http clients,
/// because it takes care of observability (OpenTelemetry).
/// We deliberately don't want to replace this with a public static.
pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> ClientWithMiddleware {
pub fn new_client(
rate_limiter_config: rate_limiter::RateLimiterConfig,
timeout: Duration,
) -> ClientWithMiddleware {
let client = reqwest::ClientBuilder::new()
.dns_resolver(Arc::new(GaiResolver::default()))
.connection_verbose(true)
.timeout(timeout)
.build()
.expect("Failed to create http client");

View File

@@ -152,15 +152,6 @@ pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy<IntGauge> = Lazy::new(|| {
.unwrap()
});
pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_cancellation_requests_total",
"Number of cancellation requests (per found/not_found).",
&["source", "kind"],
)
.unwrap()
});
#[derive(Clone)]
pub struct LatencyTimer {
// time since the stopwatch was started
@@ -209,9 +200,8 @@ impl LatencyTimer {
pub fn success(&mut self) {
// stop the stopwatch and record the time that we have accumulated
if let Some(start) = self.start.take() {
self.accumulated += start.elapsed();
}
let start = self.start.take().expect("latency timer should be started");
self.accumulated += start.elapsed();
// success
self.outcome = "success";

View File

@@ -2,7 +2,6 @@
mod tests;
pub mod connect_compute;
mod copy_bidirectional;
pub mod handshake;
pub mod passthrough;
pub mod retry;
@@ -10,7 +9,7 @@ pub mod wake_compute;
use crate::{
auth,
cancellation::{self, CancellationHandler},
cancellation::{self, CancelMap},
compute,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
@@ -62,7 +61,6 @@ pub async fn task_main(
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_handler: Arc<CancellationHandler>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -73,6 +71,7 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancel_map = Arc::new(CancelMap::default());
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -80,7 +79,7 @@ pub async fn task_main(
let (socket, peer_addr) = accept_result?;
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancel_map = Arc::clone(&cancel_map);
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let session_span = info_span!(
@@ -113,7 +112,7 @@ pub async fn task_main(
let res = handle_client(
config,
&mut ctx,
cancellation_handler,
cancel_map,
socket,
ClientMode::Tcp,
endpoint_rate_limiter,
@@ -163,14 +162,14 @@ pub enum ClientMode {
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub fn allow_cleartext(&self) -> bool {
fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
match self {
ClientMode::Tcp => config.allow_self_signed_compute,
ClientMode::Websockets { .. } => false,
@@ -227,7 +226,7 @@ impl ReportableError for ClientRequestError {
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
cancellation_handler: Arc<CancellationHandler>,
cancel_map: Arc<CancelMap>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -253,8 +252,8 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id)
return Ok(cancel_map
.cancel_session(cancel_key_data)
.await
.map(|()| None)?)
}
@@ -287,7 +286,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
}
let user = user_info.get_user().to_owned();
let user_info = match user_info
let (mut node_info, user_info) = match user_info
.authenticate(
ctx,
&mut stream,
@@ -306,16 +305,19 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
}
};
node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);
let aux = node_info.aux.clone();
let mut node = connect_to_compute(
ctx,
&TcpMechanism { params: &params },
node_info,
&user_info,
mode.allow_self_signed_compute(config),
)
.or_else(|e| stream.throw_error(e))
.await?;
let session = cancellation_handler.get_session();
let session = cancel_map.get_session();
prepare_client_connection(&node, &session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
@@ -327,11 +329,10 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
compute: node,
aux,
req: _request_gauge,
conn: _client_gauge,
cancel: session,
}))
}

View File

@@ -1,9 +1,8 @@
use crate::{
auth::backend::ComputeCredentialKeys,
auth,
compute::{self, PostgresConnection},
console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo},
console::{self, errors::WakeComputeError},
context::RequestMonitoring,
error::ReportableError,
metrics::NUM_CONNECTION_FAILURES,
proxy::{
retry::{retry_after, ShouldRetry},
@@ -21,7 +20,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
let is_cached = node_info.cached();
if is_cached {
warn!("invalidating stalled compute node info cache entry");
@@ -32,13 +31,13 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
};
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
node_info.invalidate()
node_info.invalidate().config
}
#[async_trait]
pub trait ConnectMechanism {
type Connection;
type ConnectError: ReportableError;
type ConnectError;
type Error: From<Self::ConnectError>;
async fn connect_once(
&self,
@@ -50,16 +49,6 @@ pub trait ConnectMechanism {
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
#[async_trait]
pub trait ComputeConnectBackend {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
}
pub struct TcpMechanism<'a> {
/// KV-dictionary with PostgreSQL connection params.
pub params: &'a StartupMessageParams,
@@ -78,7 +67,11 @@ impl ConnectMechanism for TcpMechanism<'_> {
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
node_info.connect(ctx, timeout).await
let allow_self_signed_compute = node_info.allow_self_signed_compute;
node_info
.config
.connect(ctx, allow_self_signed_compute, timeout)
.await
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
@@ -89,23 +82,16 @@ impl ConnectMechanism for TcpMechanism<'_> {
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
pub async fn connect_to_compute<M: ConnectMechanism>(
ctx: &mut RequestMonitoring,
mechanism: &M,
user_info: &B,
allow_self_signed_compute: bool,
mut node_info: console::CachedNodeInfo,
user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: ShouldRetry + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
let mut num_retries = 0;
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
if let Some(keys) = user_info.get_keys() {
node_info.set_keys(keys);
}
node_info.allow_self_signed_compute = allow_self_signed_compute;
// let mut node_info = credentials.get_node_info(ctx, user_info).await?;
mechanism.update_connect_config(&mut node_info.config);
// try once
@@ -122,31 +108,28 @@ where
error!(error = ?err, "could not connect to compute node");
let node_info =
if err.get_error_kind() == crate::error::ErrorKind::Postgres || !node_info.cached() {
// If the error is Postgres, that means that we managed to connect to the compute node, but there was an error.
// Do not need to retrieve a new node_info, just return the old one.
if !err.should_retry(num_retries) {
return Err(err.into());
}
node_info
} else {
let mut num_retries = 1;
match user_info {
auth::BackendType::Console(api, info) => {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
info!("compute node's state has likely changed; requesting a wake-up");
ctx.latency_timer.cache_miss();
let old_node_info = invalidate_cache(node_info);
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
node_info.reuse_settings(old_node_info);
ctx.latency_timer.cache_miss();
let config = invalidate_cache(node_info);
node_info = wake_compute(&mut num_retries, ctx, api, info).await?;
node_info.config.reuse_password(&config);
mechanism.update_connect_config(&mut node_info.config);
node_info
};
}
// nothing to do?
auth::BackendType::Link(_) => {}
};
// now that we have a new node, try connect to it repeatedly.
// this can error for a few reasons, for instance:
// * DNS connection settings haven't quite propagated yet
info!("wake_compute success. attempting to connect");
num_retries = 1;
loop {
match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)

View File

@@ -1,256 +0,0 @@
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use std::future::poll_fn;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
#[derive(Debug)]
enum TransferState {
Running(CopyBuffer),
ShuttingDown(u64),
Done(u64),
}
fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
r: &mut A,
w: &mut B,
) -> Poll<io::Result<u64>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut r = Pin::new(r);
let mut w = Pin::new(w);
loop {
match state {
TransferState::Running(buf) => {
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
ready!(w.as_mut().poll_shutdown(cx))?;
*state = TransferState::Done(*count);
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
}
}
}
pub(super) async fn copy_bidirectional<A, B>(
a: &mut A,
b: &mut B,
) -> Result<(u64, u64), std::io::Error>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut a_to_b = TransferState::Running(CopyBuffer::new());
let mut b_to_a = TransferState::Running(CopyBuffer::new());
poll_fn(|cx| {
let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
// Early termination checks
if let TransferState::Done(_) = a_to_b {
if let TransferState::Running(buf) = &b_to_a {
// Initiate shutdown
b_to_a = TransferState::ShuttingDown(buf.amt);
b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
}
}
if let TransferState::Done(_) = b_to_a {
if let TransferState::Running(buf) = &a_to_b {
// Initiate shutdown
a_to_b = TransferState::ShuttingDown(buf.amt);
a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
}
}
// It is not a problem if ready! returns early ... (comment remains the same)
let a_to_b = ready!(a_to_b_result);
let b_to_a = ready!(b_to_a_result);
Poll::Ready(Ok((a_to_b, b_to_a)))
})
.await
}
#[derive(Debug)]
pub(super) struct CopyBuffer {
read_done: bool,
need_flush: bool,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
impl CopyBuffer {
pub(super) fn new() -> Self {
Self {
read_done: false,
need_flush: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
}
}
fn poll_fill_buf<R>(
&mut self,
cx: &mut Context<'_>,
reader: Pin<&mut R>,
) -> Poll<io::Result<()>>
where
R: AsyncRead + ?Sized,
{
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
buf.set_filled(me.cap);
let res = reader.poll_read(cx, &mut buf);
if let Poll::Ready(Ok(())) = res {
let filled_len = buf.filled().len();
me.read_done = me.cap == filled_len;
me.cap = filled_len;
}
res
}
fn poll_write_buf<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<usize>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
let me = &mut *self;
match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
Poll::Pending => {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
}
Poll::Pending
}
res => res,
}
}
pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<u64>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
self.pos = 0;
self.cap = 0;
match self.poll_fill_buf(cx, reader.as_mut()) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => {
// Try flushing when the reader has no progress to avoid deadlock
// when the reader depends on buffered writer.
if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))?;
self.need_flush = false;
}
return Poll::Pending;
}
}
}
// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
} else {
self.pos += i;
self.amt += i as u64;
self.need_flush = true;
}
}
// If pos larger than cap, this loop will never stop.
// In particular, user's wrong poll_write implementation returning
// incorrect written length may lead to thread blocking.
debug_assert!(
self.pos <= self.cap,
"writer returned length larger than input slice"
);
// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::AsyncWriteExt;
#[tokio::test]
async fn test_early_termination_a_to_d() {
let (mut a_mock, mut b_mock) = tokio::io::duplex(8); // Create a mock duplex stream
let (mut c_mock, mut d_mock) = tokio::io::duplex(32); // Create a mock duplex stream
// Simulate 'a' finishing while there's still data for 'b'
a_mock.write_all(b"hello").await.unwrap();
a_mock.shutdown().await.unwrap();
d_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
// Assert correct transferred amounts
let (a_to_d_count, d_to_a_count) = result;
assert_eq!(a_to_d_count, 5); // 'hello' was transferred
assert!(d_to_a_count <= 8); // response only partially transferred or not at all
}
#[tokio::test]
async fn test_early_termination_d_to_a() {
let (mut a_mock, mut b_mock) = tokio::io::duplex(32); // Create a mock duplex stream
let (mut c_mock, mut d_mock) = tokio::io::duplex(8); // Create a mock duplex stream
// Simulate 'a' finishing while there's still data for 'b'
d_mock.write_all(b"hello").await.unwrap();
d_mock.shutdown().await.unwrap();
a_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
// Assert correct transferred amounts
let (a_to_d_count, d_to_a_count) = result;
assert_eq!(d_to_a_count, 5); // 'hello' was transferred
assert!(a_to_d_count <= 8); // response only partially transferred or not at all
}
}

View File

@@ -1,5 +1,4 @@
use crate::{
cancellation,
compute::PostgresConnection,
console::messages::MetricsAuxInfo,
metrics::NUM_BYTES_PROXIED_COUNTER,
@@ -46,7 +45,7 @@ pub async fn proxy_pass(
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = crate::proxy::copy_bidirectional::copy_bidirectional(&mut client, &mut compute).await?;
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
Ok(())
}
@@ -58,7 +57,6 @@ pub struct ProxyPassthrough<S> {
pub req: IntCounterPairGuard,
pub conn: IntCounterPairGuard,
pub cancel: cancellation::Session,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {

View File

@@ -2,19 +2,13 @@
mod mitm;
use std::time::Duration;
use super::connect_compute::ConnectMechanism;
use super::retry::ShouldRetry;
use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
};
use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend};
use crate::config::CertResolver;
use crate::console::caches::NodeInfoCache;
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
@@ -150,7 +144,7 @@ impl TestAuth for Scram {
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
let outcome = auth::AuthFlow::new(stream)
.begin(auth::Scram(&self.0, &mut RequestMonitoring::test()))
.begin(auth::Scram(&self.0))
.await?
.authenticate()
.await?;
@@ -375,15 +369,12 @@ enum ConnectAction {
Connect,
Retry,
Fail,
RetryPg,
FailPg,
}
#[derive(Clone)]
struct TestConnectMechanism {
counter: Arc<std::sync::Mutex<usize>>,
sequence: Vec<ConnectAction>,
cache: &'static NodeInfoCache,
}
impl TestConnectMechanism {
@@ -402,12 +393,6 @@ impl TestConnectMechanism {
Self {
counter: Arc::new(std::sync::Mutex::new(0)),
sequence,
cache: Box::leak(Box::new(NodeInfoCache::new(
"test",
1,
Duration::from_secs(100),
false,
))),
}
}
}
@@ -418,13 +403,6 @@ struct TestConnection;
#[derive(Debug)]
struct TestConnectError {
retryable: bool,
kind: crate::error::ErrorKind,
}
impl ReportableError for TestConnectError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
self.kind
}
}
impl std::fmt::Display for TestConnectError {
@@ -458,22 +436,8 @@ impl ConnectMechanism for TestConnectMechanism {
*counter += 1;
match action {
ConnectAction::Connect => Ok(TestConnection),
ConnectAction::Retry => Err(TestConnectError {
retryable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::Fail => Err(TestConnectError {
retryable: false,
kind: ErrorKind::Compute,
}),
ConnectAction::FailPg => Err(TestConnectError {
retryable: false,
kind: ErrorKind::Postgres,
}),
ConnectAction::RetryPg => Err(TestConnectError {
retryable: true,
kind: ErrorKind::Postgres,
}),
ConnectAction::Retry => Err(TestConnectError { retryable: true }),
ConnectAction::Fail => Err(TestConnectError { retryable: false }),
x => panic!("expecting action {:?}, connect is called instead", x),
}
}
@@ -487,7 +451,7 @@ impl TestBackend for TestConnectMechanism {
let action = self.sequence[*counter];
*counter += 1;
match action {
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
ConnectAction::Wake => Ok(helper_create_cached_node_info()),
ConnectAction::WakeFail => {
let err = console::errors::ApiError::Console {
status: http::StatusCode::FORBIDDEN,
@@ -519,41 +483,37 @@ impl TestBackend for TestConnectMechanism {
}
}
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
fn helper_create_cached_node_info() -> CachedNodeInfo {
let node = NodeInfo {
config: compute::ConnCfg::new(),
aux: Default::default(),
allow_self_signed_compute: false,
};
let (_, node) = cache.insert("key".into(), node);
node
CachedNodeInfo::new_uncached(node)
}
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> auth::BackendType<'static, ComputeCredentials, &()> {
) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) {
let cache = helper_create_cached_node_info();
let user_info = auth::BackendType::Console(
MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::Password("password".into()),
ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
);
user_info
(cache, user_info)
}
#[tokio::test]
async fn connect_to_compute_success() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
@@ -561,52 +521,24 @@ async fn connect_to_compute_success() {
#[tokio::test]
async fn connect_to_compute_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
}
#[tokio::test]
async fn connect_to_compute_retry_pg() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, RetryPg, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
.await
.unwrap();
mechanism.verify();
}
#[tokio::test]
async fn connect_to_compute_fail_pg() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, FailPg]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
.await
.unwrap_err();
mechanism.verify();
}
/// Test that we don't retry if the error is not retryable.
#[tokio::test]
async fn connect_to_compute_non_retry_1() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap_err();
mechanism.verify();
@@ -615,12 +547,11 @@ async fn connect_to_compute_non_retry_1() {
/// Even for non-retryable errors, we should retry at least once.
#[tokio::test]
async fn connect_to_compute_non_retry_2() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
@@ -629,16 +560,15 @@ async fn connect_to_compute_non_retry_2() {
/// Retry for at most `NUM_RETRIES_CONNECT` times.
#[tokio::test]
async fn connect_to_compute_non_retry_3() {
let _ = env_logger::try_init();
assert_eq!(NUM_RETRIES_CONNECT, 16);
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![
Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap_err();
mechanism.verify();
@@ -647,12 +577,11 @@ async fn connect_to_compute_non_retry_3() {
/// Should retry wake compute.
#[tokio::test]
async fn wake_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
@@ -661,12 +590,11 @@ async fn wake_retry() {
/// Wake failed with a non-retryable error.
#[tokio::test]
async fn wake_non_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap_err();
mechanism.verify();

View File

@@ -1,4 +1,9 @@
use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo};
use crate::auth::backend::ComputeUserInfo;
use crate::console::{
errors::WakeComputeError,
provider::{CachedNodeInfo, ConsoleBackend},
Api,
};
use crate::context::RequestMonitoring;
use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES};
use crate::proxy::retry::retry_after;
@@ -6,16 +11,17 @@ use hyper::StatusCode;
use std::ops::ControlFlow;
use tracing::{error, warn};
use super::connect_compute::ComputeConnectBackend;
use super::retry::ShouldRetry;
pub async fn wake_compute<B: ComputeConnectBackend>(
/// wake a compute (or retrieve an existing compute session from cache)
pub async fn wake_compute(
num_retries: &mut u32,
ctx: &mut RequestMonitoring,
api: &B,
api: &ConsoleBackend,
info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
loop {
let wake_res = api.wake_compute(ctx).await;
let wake_res = api.wake_compute(ctx, info).await;
match handle_try_wake(wake_res, *num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");

View File

@@ -4,4 +4,4 @@ mod limiter;
pub use aimd::Aimd;
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
pub use limiter::Limiter;
pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};
pub use limiter::{EndpointRateLimiter, RateBucketInfo};

View File

@@ -22,44 +22,6 @@ use super::{
RateLimiterConfig,
};
pub struct RedisRateLimiter {
data: Vec<RateBucket>,
info: &'static [RateBucketInfo],
}
impl RedisRateLimiter {
pub fn new(info: &'static [RateBucketInfo]) -> Self {
Self {
data: vec![
RateBucket {
start: Instant::now(),
count: 0,
};
info.len()
],
info,
}
}
/// Check that number of connections is below `max_rps` rps.
pub fn check(&mut self) -> bool {
let now = Instant::now();
let should_allow_request = self
.data
.iter_mut()
.zip(self.info)
.all(|(bucket, info)| bucket.should_allow_request(info, now));
if should_allow_request {
// only increment the bucket counts if the request will actually be accepted
self.data.iter_mut().for_each(RateBucket::inc);
}
should_allow_request
}
}
// Simple per-endpoint rate limiter.
//
// Check that number of connections to the endpoint is below `max_rps` rps.

View File

@@ -1,2 +1 @@
pub mod notifications;
pub mod publisher;

View File

@@ -1,44 +1,38 @@
use std::{convert::Infallible, sync::Arc};
use futures::StreamExt;
use pq_proto::CancelKeyData;
use redis::aio::PubSub;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use serde::Deserialize;
use crate::{
cache::project_info::ProjectInfoCache,
cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler},
intern::{ProjectIdInt, RoleNameInt},
};
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
struct RedisConsumerClient {
struct ConsoleRedisClient {
client: redis::Client,
}
impl RedisConsumerClient {
impl ConsoleRedisClient {
pub fn new(url: &str) -> anyhow::Result<Self> {
let client = redis::Client::open(url)?;
Ok(Self { client })
}
async fn try_connect(&self) -> anyhow::Result<PubSub> {
let mut conn = self.client.get_async_connection().await?.into_pubsub();
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
conn.subscribe(PROXY_CHANNEL_NAME).await?;
tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
conn.subscribe(CHANNEL_NAME).await?;
Ok(conn)
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
pub(crate) enum Notification {
enum Notification {
#[serde(
rename = "/allowed_ips_updated",
deserialize_with = "deserialize_json_string"
@@ -51,25 +45,16 @@ pub(crate) enum Notification {
deserialize_with = "deserialize_json_string"
)]
PasswordUpdate { password_update: PasswordUpdate },
#[serde(rename = "/cancel_session")]
Cancel(CancelSession),
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedIpsUpdate {
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct PasswordUpdate {
project_id: ProjectIdInt,
role_name: RoleNameInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct CancelSession {
pub region_id: Option<String>,
pub cancel_key_data: CancelKeyData,
pub session_id: Uuid,
}
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
T: for<'de2> serde::Deserialize<'de2>,
@@ -79,88 +64,6 @@ where
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
}
struct MessageHandler<
C: ProjectInfoCache + Send + Sync + 'static,
H: NotificationsCancellationHandler + Send + Sync + 'static,
> {
cache: Arc<C>,
cancellation_handler: Arc<H>,
region_id: String,
}
impl<
C: ProjectInfoCache + Send + Sync + 'static,
H: NotificationsCancellationHandler + Send + Sync + 'static,
> MessageHandler<C, H>
{
pub fn new(cache: Arc<C>, cancellation_handler: Arc<H>, region_id: String) -> Self {
Self {
cache,
cancellation_handler,
region_id,
}
}
pub fn disable_ttl(&self) {
self.cache.disable_ttl();
}
pub fn enable_ttl(&self) {
self.cache.enable_ttl();
}
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
use Notification::*;
let payload: String = msg.get_payload()?;
tracing::debug!(?payload, "received a message payload");
let msg: Notification = match serde_json::from_str(&payload) {
Ok(msg) => msg,
Err(e) => {
tracing::error!("broken message: {e}");
return Ok(());
}
};
tracing::debug!(?msg, "received a message");
match msg {
Cancel(cancel_session) => {
tracing::Span::current().record(
"session_id",
&tracing::field::display(cancel_session.session_id),
);
if let Some(cancel_region) = cancel_session.region_id {
// If the message is not for this region, ignore it.
if cancel_region != self.region_id {
return Ok(());
}
}
// This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
match self
.cancellation_handler
.cancel_session_no_publish(cancel_session.cancel_key_data)
.await
{
Ok(()) => {}
Err(e) => {
tracing::error!("failed to cancel session: {e}");
}
}
}
_ => {
invalidate_cache(self.cache.clone(), msg.clone());
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
let cache = self.cache.clone();
tokio::spawn(async move {
tokio::time::sleep(INVALIDATION_LAG).await;
invalidate_cache(cache, msg);
});
}
}
Ok(())
}
}
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
use Notification::*;
match msg {
@@ -171,33 +74,50 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
password_update.project_id,
password_update.role_name,
),
Cancel(_) => unreachable!("cancel message should be handled separately"),
}
}
#[tracing::instrument(skip(cache))]
fn handle_message<C>(msg: redis::Msg, cache: Arc<C>) -> anyhow::Result<()>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
let payload: String = msg.get_payload()?;
tracing::debug!(?payload, "received a message payload");
let msg: Notification = match serde_json::from_str(&payload) {
Ok(msg) => msg,
Err(e) => {
tracing::error!("broken message: {e}");
return Ok(());
}
};
tracing::debug!(?msg, "received a message");
invalidate_cache(cache.clone(), msg.clone());
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
tokio::spawn(async move {
tokio::time::sleep(INVALIDATION_LAG).await;
invalidate_cache(cache, msg.clone());
});
Ok(())
}
/// Handle console's invalidation messages.
#[tracing::instrument(name = "console_notifications", skip_all)]
pub async fn task_main<C>(
url: String,
cache: Arc<C>,
cancel_map: CancelMap,
region_id: String,
) -> anyhow::Result<Infallible>
pub async fn task_main<C>(url: String, cache: Arc<C>) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
cache.enable_ttl();
let handler = MessageHandler::new(
cache,
Arc::new(CancellationHandler::new(cancel_map, None)),
region_id,
);
loop {
let redis = RedisConsumerClient::new(&url)?;
let redis = ConsoleRedisClient::new(&url)?;
let conn = match redis.try_connect().await {
Ok(conn) => {
handler.disable_ttl();
cache.disable_ttl();
conn
}
Err(e) => {
@@ -210,7 +130,7 @@ where
};
let mut stream = conn.into_on_message();
while let Some(msg) = stream.next().await {
match handler.handle_message(msg).await {
match handle_message(msg, cache.clone()) {
Ok(()) => {}
Err(e) => {
tracing::error!("failed to handle message: {e}, will try to reconnect");
@@ -218,7 +138,7 @@ where
}
}
}
handler.enable_ttl();
cache.enable_ttl();
}
}
@@ -278,33 +198,6 @@ mod tests {
}
);
Ok(())
}
#[test]
fn parse_cancel_session() -> anyhow::Result<()> {
let cancel_key_data = CancelKeyData {
backend_pid: 42,
cancel_key: 41,
};
let uuid = uuid::Uuid::new_v4();
let msg = Notification::Cancel(CancelSession {
cancel_key_data,
region_id: None,
session_id: uuid,
});
let text = serde_json::to_string(&msg)?;
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(msg, result);
let msg = Notification::Cancel(CancelSession {
cancel_key_data,
region_id: Some("region".to_string()),
session_id: uuid,
});
let text = serde_json::to_string(&msg)?;
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(msg, result,);
Ok(())
}
}

View File

@@ -1,80 +0,0 @@
use pq_proto::CancelKeyData;
use redis::AsyncCommands;
use uuid::Uuid;
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};
pub struct RedisPublisherClient {
client: redis::Client,
publisher: Option<redis::aio::Connection>,
region_id: String,
limiter: RedisRateLimiter,
}
impl RedisPublisherClient {
pub fn new(
url: &str,
region_id: String,
info: &'static [RateBucketInfo],
) -> anyhow::Result<Self> {
let client = redis::Client::open(url)?;
Ok(Self {
client,
publisher: None,
region_id,
limiter: RedisRateLimiter::new(info),
})
}
pub async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping cancellation message");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.publish(cancel_key_data, session_id).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::error!("failed to publish a message: {e}");
self.publisher = None;
}
}
tracing::info!("Publisher is disconnected. Reconnectiong...");
self.try_connect().await?;
self.publish(cancel_key_data, session_id).await
}
async fn publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
let conn = self
.publisher
.as_mut()
.ok_or_else(|| anyhow::anyhow!("not connected"))?;
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
region_id: Some(self.region_id.clone()),
cancel_key_data,
session_id,
}))?;
conn.publish(PROXY_CHANNEL_NAME, payload).await?;
Ok(())
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
match self.client.get_async_connection().await {
Ok(conn) => {
self.publisher = Some(conn);
}
Err(e) => {
tracing::error!("failed to connect to redis: {e}");
return Err(e.into());
}
}
Ok(())
}
}

View File

@@ -24,7 +24,7 @@ use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
use crate::{cancellation::CancelMap, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
server::{
@@ -50,7 +50,6 @@ pub async fn task_main(
ws_listener: TcpListener,
cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_handler: Arc<CancellationHandler>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("websocket server has shut down");
@@ -116,7 +115,7 @@ pub async fn task_main(
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
async move {
let peer_addr = match client_addr {
Some(addr) => addr,
@@ -128,9 +127,9 @@ pub async fn task_main(
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
async move {
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
request_handler(
@@ -138,7 +137,7 @@ pub async fn task_main(
config,
backend,
ws_connections,
cancellation_handler,
cancel_map,
session_id,
peer_addr.ip(),
endpoint_rate_limiter,
@@ -206,7 +205,7 @@ async fn request_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandler>,
cancel_map: Arc<CancelMap>,
session_id: uuid::Uuid,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -233,7 +232,7 @@ async fn request_handler(
config,
ctx,
websocket,
cancellation_handler,
cancel_map,
host,
endpoint_rate_limiter,
)

View File

@@ -1,10 +1,10 @@
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use tracing::{field::display, info};
use tracing::info;
use crate::{
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError},
compute,
config::ProxyConfig,
console::{
@@ -15,7 +15,7 @@ use crate::{
proxy::connect_compute::ConnectMechanism,
};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool, APP_NAME};
pub struct PoolingBackend {
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
@@ -27,7 +27,7 @@ impl PoolingBackend {
&self,
ctx: &mut RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<ComputeCredentials, AuthError> {
) -> Result<ComputeCredentialKeys, AuthError> {
let user_info = conn_info.user_info.clone();
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
@@ -49,17 +49,13 @@ impl PoolingBackend {
};
let auth_outcome =
crate::auth::validate_password_and_exchange(&conn_info.password, secret)?;
let res = match auth_outcome {
match auth_outcome {
crate::sasl::Outcome::Success(key) => Ok(key),
crate::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
Err(AuthError::auth_failed(&*conn_info.user_info.user))
}
};
res.map(|key| ComputeCredentials {
info: user_info,
keys: key,
})
}
}
// Wake up the destination if needed. Code here is a bit involved because
@@ -70,7 +66,7 @@ impl PoolingBackend {
&self,
ctx: &mut RequestMonitoring,
conn_info: ConnInfo,
keys: ComputeCredentials,
keys: ComputeCredentialKeys,
force_new: bool,
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
let maybe_client = if !force_new {
@@ -85,9 +81,27 @@ impl PoolingBackend {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.config.auth_backend.as_ref().map(|_| keys);
ctx.set_application(Some(APP_NAME));
let backend = self
.config
.auth_backend
.as_ref()
.map(|_| conn_info.user_info.clone());
let mut node_info = backend
.wake_compute(ctx)
.await?
.ok_or(HttpConnError::NoComputeInfo)?;
match keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => node_info.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys),
};
ctx.set_project(node_info.aux.clone());
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
@@ -95,8 +109,8 @@ impl PoolingBackend {
conn_info,
pool: self.pool.clone(),
},
node_info,
&backend,
false, // do not allow self signed compute for http flow
)
.await
}
@@ -115,6 +129,8 @@ pub enum HttpConnError {
AuthError(#[from] AuthError),
#[error("wake_compute returned error")]
WakeCompute(#[from] WakeComputeError),
#[error("wake_compute returned nothing")]
NoComputeInfo,
}
struct TokioMechanism {

View File

@@ -4,6 +4,7 @@ use metrics::IntCounterPairGuard;
use parking_lot::RwLock;
use rand::Rng;
use smallvec::SmallVec;
use smol_str::SmolStr;
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use std::{
fmt,
@@ -30,6 +31,8 @@ use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;
pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
#[derive(Debug, Clone)]
pub struct ConnInfo {
pub user_info: ComputeUserInfo,
@@ -376,13 +379,12 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
} else {
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
info!("pool: reusing connection '{conn_info}'");
client.session.send(ctx.session_id)?;
tracing::Span::current().record(
"pid",
&tracing::field::display(client.inner.get_process_id()),
);
info!("pool: reusing connection '{conn_info}'");
client.session.send(ctx.session_id)?;
ctx.latency_timer.pool_hit();
ctx.latency_timer.success();
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
@@ -575,6 +577,7 @@ pub struct Client<C: ClientInnerExt> {
}
pub struct Discard<'a, C: ClientInnerExt> {
conn_id: uuid::Uuid,
conn_info: &'a ConnInfo,
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
@@ -600,7 +603,14 @@ impl<C: ClientInnerExt> Client<C> {
span: _,
} = self;
let inner = inner.as_mut().expect("client inner should not be removed");
(&mut inner.inner, Discard { pool, conn_info })
(
&mut inner.inner,
Discard {
pool,
conn_info,
conn_id: inner.conn_id,
},
)
}
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
@@ -615,13 +625,13 @@ impl<C: ClientInnerExt> Discard<'_, C> {
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle")
info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle")
}
}
pub fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
}
}
}

View File

@@ -36,8 +36,6 @@ use crate::error::ReportableError;
use crate::metrics::HTTP_CONTENT_LENGTH;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::DbName;
use crate::RoleName;
use super::backend::PoolingBackend;
@@ -119,9 +117,6 @@ fn get_conn_info(
headers: &HeaderMap,
tls: &TlsConfig,
) -> Result<ConnInfo, ConnInfoError> {
// HTTP only uses cleartext (for now and likely always)
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
let connection_string = headers
.get("Neon-Connection-String")
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
@@ -139,8 +134,7 @@ fn get_conn_info(
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into();
ctx.set_dbname(dbname.clone());
let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?;
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
@@ -180,7 +174,7 @@ fn get_conn_info(
Ok(ConnInfo {
user_info,
dbname,
dbname: dbname.into(),
password: match password {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
@@ -306,14 +300,7 @@ pub async fn handle(
Ok(response)
}
#[instrument(
name = "sql-over-http",
skip_all,
fields(
pid = tracing::field::Empty,
conn_id = tracing::field::Empty
)
)]
#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
async fn handle_inner(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
@@ -367,10 +354,12 @@ async fn handle_inner(
let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
let paused = ctx.latency_timer.pause();
let request_content_length = match request.body().size_hint().upper() {
Some(v) => v,
None => MAX_REQUEST_SIZE + 1,
};
drop(paused);
info!(request_content_length, "request size in bytes");
HTTP_CONTENT_LENGTH.observe(request_content_length as f64);
@@ -386,20 +375,15 @@ async fn handle_inner(
let body = hyper::body::to_bytes(request.into_body())
.await
.map_err(anyhow::Error::from)?;
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
};
let authenticate_and_connect = async {
let keys = backend.authenticate(ctx, &conn_info).await?;
let client = backend
backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
ctx.latency_timer.success();
Ok::<_, HttpConnError>(client)
.await
};
// Run both operations in parallel
@@ -431,7 +415,6 @@ async fn handle_inner(
results
}
Payload::Batch(statements) => {
info!("starting transaction");
let (inner, mut discard) = client.inner();
let mut builder = inner.build_transaction();
if let Some(isolation_level) = txn_isolation_level {
@@ -461,7 +444,6 @@ async fn handle_inner(
.await
{
Ok(results) => {
info!("commit");
let status = transaction.commit().await.map_err(|e| {
// if we cannot commit - for now don't return connection to pool
// TODO: get a query status from the error
@@ -472,7 +454,6 @@ async fn handle_inner(
results
}
Err(err) => {
info!("rollback");
let status = transaction.rollback().await.map_err(|e| {
// if we cannot rollback - for now don't return connection to pool
// TODO: get a query status from the error
@@ -547,10 +528,8 @@ async fn query_to_json<T: GenericClient>(
raw_output: bool,
default_array_mode: bool,
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
info!("executing query");
let query_params = data.params;
let row_stream = client.query_raw_txt(&data.query, query_params).await?;
info!("finished executing query");
// Manually drain the stream into a vector to leave row_stream hanging
// around to get a command tag. Also check that the response is not too
@@ -585,13 +564,6 @@ async fn query_to_json<T: GenericClient>(
}
.and_then(|s| s.parse::<i64>().ok());
info!(
rows = rows.len(),
?ready,
command_tag,
"finished reading rows"
);
let mut fields = vec![];
let mut columns = vec![];

View File

@@ -1,5 +1,5 @@
use crate::{
cancellation::CancellationHandler,
cancellation::CancelMap,
config::ProxyConfig,
context::RequestMonitoring,
error::{io_error, ReportableError},
@@ -133,7 +133,7 @@ pub async fn serve_websocket(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
websocket: HyperWebsocket,
cancellation_handler: Arc<CancellationHandler>,
cancel_map: Arc<CancelMap>,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
@@ -141,7 +141,7 @@ pub async fn serve_websocket(
let res = handle_client(
config,
&mut ctx,
cancellation_handler,
cancel_map,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,

View File

@@ -237,6 +237,7 @@ mod tests {
use std::{
net::TcpListener,
sync::{Arc, Mutex},
time::Duration,
};
use anyhow::Error;
@@ -279,7 +280,7 @@ mod tests {
tokio::spawn(server);
let metrics = Metrics::default();
let client = http::new_client(RateLimiterConfig::default());
let client = http::new_client(RateLimiterConfig::default(), Duration::from_secs(15));
let endpoint = Url::parse(&format!("http://{addr}")).unwrap();
let now = Utc::now();

View File

@@ -61,10 +61,3 @@ tokio-stream.workspace = true
utils.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
walproposer.workspace = true
rand.workspace = true
desim.workspace = true
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["json"] }

View File

@@ -1,155 +0,0 @@
use std::sync::Arc;
use tracing::{info, warn};
use utils::lsn::Lsn;
use crate::walproposer_sim::{
log::{init_logger, init_tracing_logger},
simulation::{generate_network_opts, generate_schedule, Schedule, TestAction, TestConfig},
};
pub mod walproposer_sim;
// Test that simulation supports restarting (crashing) safekeepers.
#[test]
fn crash_safekeeper() {
let clock = init_logger();
let config = TestConfig::new(Some(clock));
let test = config.start(1337);
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced empty safekeepers at 0/0");
let mut wp = test.launch_walproposer(lsn);
// Write some WAL and crash safekeeper 0 without waiting for replication.
test.poll_for_duration(30);
wp.write_tx(3);
test.servers[0].restart();
// Wait some time, so that walproposer can reconnect.
test.poll_for_duration(2000);
}
// Test that walproposer can be crashed (stopped).
#[test]
fn test_simple_restart() {
let clock = init_logger();
let config = TestConfig::new(Some(clock));
let test = config.start(1337);
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced empty safekeepers at 0/0");
let mut wp = test.launch_walproposer(lsn);
test.poll_for_duration(30);
wp.write_tx(3);
test.poll_for_duration(100);
wp.stop();
drop(wp);
let lsn = test.sync_safekeepers().unwrap();
info!("Sucessfully synced safekeepers at {}", lsn);
}
// Test runnning a simple schedule, restarting everything a several times.
#[test]
fn test_simple_schedule() -> anyhow::Result<()> {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
config.network.keepalive_timeout = Some(100);
let test = config.start(1337);
let schedule: Schedule = vec![
(0, TestAction::RestartWalProposer),
(50, TestAction::WriteTx(5)),
(100, TestAction::RestartSafekeeper(0)),
(100, TestAction::WriteTx(5)),
(110, TestAction::RestartSafekeeper(1)),
(110, TestAction::WriteTx(5)),
(120, TestAction::RestartSafekeeper(2)),
(120, TestAction::WriteTx(5)),
(201, TestAction::RestartWalProposer),
(251, TestAction::RestartSafekeeper(0)),
(251, TestAction::RestartSafekeeper(1)),
(251, TestAction::RestartSafekeeper(2)),
(251, TestAction::WriteTx(5)),
(255, TestAction::WriteTx(5)),
(1000, TestAction::WriteTx(5)),
];
test.run_schedule(&schedule)?;
info!("Test finished, stopping all threads");
test.world.deallocate();
Ok(())
}
// Test that simulation can process 10^4 transactions.
#[test]
fn test_many_tx() -> anyhow::Result<()> {
let clock = init_logger();
let config = TestConfig::new(Some(clock));
let test = config.start(1337);
let mut schedule: Schedule = vec![];
for i in 0..100 {
schedule.push((i * 10, TestAction::WriteTx(100)));
}
test.run_schedule(&schedule)?;
info!("Test finished, stopping all threads");
test.world.stop_all();
let events = test.world.take_events();
info!("Events: {:?}", events);
let last_commit_lsn = events
.iter()
.filter_map(|event| {
if event.data.starts_with("commit_lsn;") {
let lsn: u64 = event.data.split(';').nth(1).unwrap().parse().unwrap();
return Some(lsn);
}
None
})
.last()
.unwrap();
let initdb_lsn = 21623024;
let diff = last_commit_lsn - initdb_lsn;
info!("Last commit lsn: {}, diff: {}", last_commit_lsn, diff);
// each tx is at least 8 bytes, it's written a 100 times for in a loop for 100 times
assert!(diff > 100 * 100 * 8);
Ok(())
}
// Checks that we don't have nasty circular dependencies, preventing Arc from deallocating.
// This test doesn't really assert anything, you need to run it manually to check if there
// is any issue.
#[test]
fn test_res_dealloc() -> anyhow::Result<()> {
let clock = init_tracing_logger(true);
let mut config = TestConfig::new(Some(clock));
let seed = 123456;
config.network = generate_network_opts(seed);
let test = config.start(seed);
warn!("Running test with seed {}", seed);
let schedule = generate_schedule(seed);
info!("schedule: {:?}", schedule);
test.run_schedule(&schedule).unwrap();
test.world.stop_all();
let world = test.world.clone();
drop(test);
info!("world strong count: {}", Arc::strong_count(&world));
world.deallocate();
info!("world strong count: {}", Arc::strong_count(&world));
Ok(())
}

Some files were not shown because too many files have changed in this diff Show More