Compare commits

..

5 Commits

Author SHA1 Message Date
Heikki Linnakangas
a77dd0700c Move startup tracing context handling 2024-05-03 09:28:19 +03:00
Em Sharnoff
c9472434c9 compute_ctl: Break up main() into discrete phases
This commit is intentionally designed to have as small a diff as
possible. To that end, the basic idea is that each distinct "chunk" of
the previous main() has been wrapped in its own function, with the
return values from each function being passed directly into the next.

The structure of main() is now visible from its contents:

  1. init()
  2. process_cli()
  3. wait_spec()
  4. start_postgres()
  5. wait_postgres()
  6. cleanup_and_exit()

There's a lot of other work that can / should(?) be done beyond this,
but I figure that's more opinionated, and this should be a solid start.
2024-05-01 12:07:46 -07:00
Em Sharnoff
f9c2945f74 compute_ctl: Non-functional prep changes to reduce diff
A couple lines moved further down in main(), and one case of using
Option<&str> instead of Option<&String>.
2024-05-01 12:06:55 -07:00
Alex Chi Z
5558457c84 chore(pageserver): categorize basebackup errors (#7523)
close https://github.com/neondatabase/neon/issues/7391

## Summary of changes

Categorize basebackup error into two types: server error and client
error. This makes it easier to set up alerts.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2024-05-01 16:31:59 +00:00
Alex Chi Z
26e6ff8ba6 chore(pageserver): concise error message for layer traversal (#7565)
Instead of showing the full path of layer traversal, we now only show
tenant (in tracing context)+timeline+filename.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2024-05-01 11:44:42 -04:00
9 changed files with 670 additions and 493 deletions

17
Cargo.lock generated
View File

@@ -3938,16 +3938,6 @@ dependencies = [
"siphasher",
]
[[package]]
name = "pin-list"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe91484d5a948b56f858ff2b92fd5b20b97d21b11d2d41041db8e5ec12d56c5e"
dependencies = [
"pin-project-lite",
"pinned-aliasable",
]
[[package]]
name = "pin-project"
version = "1.1.0"
@@ -3980,12 +3970,6 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "pinned-aliasable"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d0f9ae89bf0ed03b69ac1f3f7ea2e6e09b4fa5448011df2e67d581c2b850b7b"
[[package]]
name = "pkcs8"
version = "0.9.0"
@@ -4365,7 +4349,6 @@ dependencies = [
"parquet",
"parquet_derive",
"pbkdf2",
"pin-list",
"pin-project-lite",
"postgres-native-tls",
"postgres-protocol",

View File

@@ -51,6 +51,7 @@ use tracing::{error, info};
use url::Url;
use compute_api::responses::ComputeStatus;
use compute_api::spec::ComputeSpec;
use compute_tools::compute::{
forward_termination_signal, ComputeNode, ComputeState, ParsedSpec, PG_PID,
@@ -68,6 +69,29 @@ use compute_tools::spec::*;
const BUILD_TAG_DEFAULT: &str = "latest";
fn main() -> Result<()> {
let (build_tag, clap_args) = init()?;
let (pg_handle, start_pg_result) =
{
// Enter startup tracing context
let _startup_context_guard = startup_context_from_env();
let cli_result = process_cli(&clap_args)?;
let wait_spec_result = wait_spec(build_tag, cli_result)?;
start_postgres(&clap_args, wait_spec_result)?
// Startup is finished, exit the startup tracing context
};
// PostgreSQL is now running, if startup was successful. Wait until it exits.
let wait_pg_result = wait_postgres(pg_handle)?;
cleanup_and_exit(start_pg_result, wait_pg_result)
}
fn init() -> Result<(String, clap::ArgMatches)> {
init_tracing_and_logging(DEFAULT_LOG_LEVEL)?;
let mut signals = Signals::new([SIGINT, SIGTERM, SIGQUIT])?;
@@ -82,35 +106,11 @@ fn main() -> Result<()> {
.to_string();
info!("build_tag: {build_tag}");
let matches = cli().get_matches();
let pgbin_default = String::from("postgres");
let pgbin = matches.get_one::<String>("pgbin").unwrap_or(&pgbin_default);
let ext_remote_storage = matches
.get_one::<String>("remote-ext-config")
// Compatibility hack: if the control plane specified any remote-ext-config
// use the default value for extension storage proxy gateway.
// Remove this once the control plane is updated to pass the gateway URL
.map(|conf| {
if conf.starts_with("http") {
conf.trim_end_matches('/')
} else {
"http://pg-ext-s3-gateway"
}
});
let http_port = *matches
.get_one::<u16>("http-port")
.expect("http-port is required");
let pgdata = matches
.get_one::<String>("pgdata")
.expect("PGDATA path is required");
let connstr = matches
.get_one::<String>("connstr")
.expect("Postgres connection string is required");
let spec_json = matches.get_one::<String>("spec");
let spec_path = matches.get_one::<String>("spec-path");
Ok((build_tag, cli().get_matches()))
}
fn startup_context_from_env() -> Option<opentelemetry::ContextGuard>
{
// Extract OpenTelemetry context for the startup actions from the
// TRACEPARENT and TRACESTATE env variables, and attach it to the current
// tracing context.
@@ -147,7 +147,7 @@ fn main() -> Result<()> {
if let Ok(val) = std::env::var("TRACESTATE") {
startup_tracing_carrier.insert("tracestate".to_string(), val);
}
let startup_context_guard = if !startup_tracing_carrier.is_empty() {
if !startup_tracing_carrier.is_empty() {
use opentelemetry::propagation::TextMapPropagator;
use opentelemetry::sdk::propagation::TraceContextPropagator;
let guard = TraceContextPropagator::new()
@@ -157,8 +157,42 @@ fn main() -> Result<()> {
Some(guard)
} else {
None
};
}
}
fn process_cli(
matches: &clap::ArgMatches,
) -> Result<ProcessCliResult> {
let pgbin_default = "postgres";
let pgbin = matches
.get_one::<String>("pgbin")
.map(|s| s.as_str())
.unwrap_or(pgbin_default);
let ext_remote_storage = matches
.get_one::<String>("remote-ext-config")
// Compatibility hack: if the control plane specified any remote-ext-config
// use the default value for extension storage proxy gateway.
// Remove this once the control plane is updated to pass the gateway URL
.map(|conf| {
if conf.starts_with("http") {
conf.trim_end_matches('/')
} else {
"http://pg-ext-s3-gateway"
}
});
let http_port = *matches
.get_one::<u16>("http-port")
.expect("http-port is required");
let pgdata = matches
.get_one::<String>("pgdata")
.expect("PGDATA path is required");
let connstr = matches
.get_one::<String>("connstr")
.expect("Postgres connection string is required");
let spec_json = matches.get_one::<String>("spec");
let spec_path = matches.get_one::<String>("spec-path");
let compute_id = matches.get_one::<String>("compute-id");
let control_plane_uri = matches.get_one::<String>("control-plane-uri");
@@ -199,6 +233,45 @@ fn main() -> Result<()> {
}
};
let result = ProcessCliResult {
// directly from CLI:
connstr,
pgdata,
pgbin,
ext_remote_storage,
http_port,
// others:
spec,
live_config_allowed,
};
Ok(result)
}
struct ProcessCliResult<'clap> {
connstr: &'clap str,
pgdata: &'clap str,
pgbin: &'clap str,
ext_remote_storage: Option<&'clap str>,
http_port: u16,
/// If a spec was provided via CLI or file, the [`ComputeSpec`]
spec: Option<ComputeSpec>,
live_config_allowed: bool,
}
fn wait_spec(
build_tag: String,
ProcessCliResult {
connstr,
pgdata,
pgbin,
ext_remote_storage,
http_port,
spec,
live_config_allowed,
}: ProcessCliResult,
) -> Result<WaitSpecResult> {
let mut new_state = ComputeState::new();
let spec_set;
@@ -237,8 +310,6 @@ fn main() -> Result<()> {
let _http_handle =
launch_http_server(http_port, &compute).expect("cannot launch http endpoint thread");
let extension_server_port: u16 = http_port;
if !spec_set {
// No spec provided, hang waiting for it.
info!("no compute spec provided, waiting");
@@ -255,6 +326,19 @@ fn main() -> Result<()> {
}
}
Ok(WaitSpecResult { compute, http_port })
}
struct WaitSpecResult {
compute: Arc<ComputeNode>,
// passed through from ProcessCliResult
http_port: u16,
}
fn start_postgres(
matches: &clap::ArgMatches,
WaitSpecResult { compute, http_port }: WaitSpecResult,
) -> Result<(Option<PostgresHandle>, StartPostgresResult)> {
// We got all we need, update the state.
let mut state = compute.state.lock().unwrap();
@@ -281,9 +365,10 @@ fn main() -> Result<()> {
let _monitor_handle = launch_monitor(&compute);
let _configurator_handle = launch_configurator(&compute);
let extension_server_port: u16 = http_port;
// Start Postgres
let mut delay_exit = false;
let mut exit_code = None;
let pg = match compute.start_compute(extension_server_port) {
Ok(pg) => Some(pg),
Err(err) => {
@@ -334,7 +419,7 @@ fn main() -> Result<()> {
// This token is used internally by the monitor to clean up all threads
let token = CancellationToken::new();
let vm_monitor = &rt.as_ref().map(|rt| {
let vm_monitor = rt.as_ref().map(|rt| {
rt.spawn(vm_monitor::start(
Box::leak(Box::new(vm_monitor::Args {
cgroup: cgroup.cloned(),
@@ -347,12 +432,43 @@ fn main() -> Result<()> {
}
}
Ok((
pg,
StartPostgresResult {
delay_exit,
compute,
#[cfg(target_os = "linux")]
rt,
#[cfg(target_os = "linux")]
token,
#[cfg(target_os = "linux")]
vm_monitor,
},
))
}
type PostgresHandle = (std::process::Child, std::thread::JoinHandle<()>);
struct StartPostgresResult {
delay_exit: bool,
// passed through from WaitSpecResult
compute: Arc<ComputeNode>,
#[cfg(target_os = "linux")]
rt: Option<tokio::runtime::Runtime>,
#[cfg(target_os = "linux")]
token: tokio_util::sync::CancellationToken,
#[cfg(target_os = "linux")]
vm_monitor: Option<tokio::task::JoinHandle<Result<()>>>,
}
fn wait_postgres(
pg: Option<PostgresHandle>,
) -> Result<WaitPostgresResult> {
// Wait for the child Postgres process forever. In this state Ctrl+C will
// propagate to Postgres and it will be shut down as well.
let mut exit_code = None;
if let Some((mut pg, logs_handle)) = pg {
// Startup is finished, exit the startup tracing span
drop(startup_context_guard);
let ecode = pg
.wait()
.expect("failed to start waiting on Postgres process");
@@ -367,6 +483,26 @@ fn main() -> Result<()> {
exit_code = ecode.code()
}
Ok(WaitPostgresResult { exit_code })
}
struct WaitPostgresResult {
exit_code: Option<i32>,
}
fn cleanup_and_exit(
StartPostgresResult {
mut delay_exit,
compute,
#[cfg(target_os = "linux")]
vm_monitor,
#[cfg(target_os = "linux")]
token,
#[cfg(target_os = "linux")]
rt,
}: StartPostgresResult,
WaitPostgresResult { exit_code }: WaitPostgresResult,
) -> Result<()> {
// Terminate the vm_monitor so it releases the file watcher on
// /sys/fs/cgroup/neon-postgres.
// Note: the vm-monitor only runs on linux because it requires cgroups.

View File

@@ -10,7 +10,7 @@
//! This module is responsible for creation of such tarball
//! from data stored in object storage.
//!
use anyhow::{anyhow, bail, ensure, Context};
use anyhow::{anyhow, Context};
use bytes::{BufMut, Bytes, BytesMut};
use fail::fail_point;
use pageserver_api::key::{key_to_slru_block, Key};
@@ -38,6 +38,14 @@ use postgres_ffi::PG_TLI;
use postgres_ffi::{BLCKSZ, RELSEG_SIZE, WAL_SEGMENT_SIZE};
use utils::lsn::Lsn;
#[derive(Debug, thiserror::Error)]
pub enum BasebackupError {
#[error("basebackup pageserver error {0:#}")]
Server(#[from] anyhow::Error),
#[error("basebackup client error {0:#}")]
Client(#[source] io::Error),
}
/// Create basebackup with non-rel data in it.
/// Only include relational data if 'full_backup' is true.
///
@@ -53,7 +61,7 @@ pub async fn send_basebackup_tarball<'a, W>(
prev_lsn: Option<Lsn>,
full_backup: bool,
ctx: &'a RequestContext,
) -> anyhow::Result<()>
) -> Result<(), BasebackupError>
where
W: AsyncWrite + Send + Sync + Unpin,
{
@@ -92,8 +100,10 @@ where
// Consolidate the derived and the provided prev_lsn values
let prev_lsn = if let Some(provided_prev_lsn) = prev_lsn {
if backup_prev != Lsn(0) {
ensure!(backup_prev == provided_prev_lsn);
if backup_prev != Lsn(0) && backup_prev != provided_prev_lsn {
return Err(BasebackupError::Server(anyhow!(
"backup_prev {backup_prev} != provided_prev_lsn {provided_prev_lsn}"
)));
}
provided_prev_lsn
} else {
@@ -159,15 +169,26 @@ where
}
}
async fn add_block(&mut self, key: &Key, block: Bytes) -> anyhow::Result<()> {
async fn add_block(&mut self, key: &Key, block: Bytes) -> Result<(), BasebackupError> {
let (kind, segno, _) = key_to_slru_block(*key)?;
match kind {
SlruKind::Clog => {
ensure!(block.len() == BLCKSZ as usize || block.len() == BLCKSZ as usize + 8);
if !(block.len() == BLCKSZ as usize || block.len() == BLCKSZ as usize + 8) {
return Err(BasebackupError::Server(anyhow!(
"invalid SlruKind::Clog record: block.len()={}",
block.len()
)));
}
}
SlruKind::MultiXactMembers | SlruKind::MultiXactOffsets => {
ensure!(block.len() == BLCKSZ as usize);
if block.len() != BLCKSZ as usize {
return Err(BasebackupError::Server(anyhow!(
"invalid {:?} record: block.len()={}",
kind,
block.len()
)));
}
}
}
@@ -194,12 +215,15 @@ where
Ok(())
}
async fn flush(&mut self) -> anyhow::Result<()> {
async fn flush(&mut self) -> Result<(), BasebackupError> {
let nblocks = self.buf.len() / BLCKSZ as usize;
let (kind, segno) = self.current_segment.take().unwrap();
let segname = format!("{}/{:>04X}", kind.to_str(), segno);
let header = new_tar_header(&segname, self.buf.len() as u64)?;
self.ar.append(&header, self.buf.as_slice()).await?;
self.ar
.append(&header, self.buf.as_slice())
.await
.map_err(BasebackupError::Client)?;
self.total_blocks += nblocks;
debug!("Added to basebackup slru {} relsize {}", segname, nblocks);
@@ -209,7 +233,7 @@ where
Ok(())
}
async fn finish(mut self) -> anyhow::Result<()> {
async fn finish(mut self) -> Result<(), BasebackupError> {
let res = if self.current_segment.is_none() || self.buf.is_empty() {
Ok(())
} else {
@@ -226,7 +250,7 @@ impl<'a, W> Basebackup<'a, W>
where
W: AsyncWrite + Send + Sync + Unpin,
{
async fn send_tarball(mut self) -> anyhow::Result<()> {
async fn send_tarball(mut self) -> Result<(), BasebackupError> {
// TODO include checksum
let lazy_slru_download = self.timeline.get_lazy_slru_download() && !self.full_backup;
@@ -262,7 +286,8 @@ where
let slru_partitions = self
.timeline
.get_slru_keyspace(Version::Lsn(self.lsn), self.ctx)
.await?
.await
.map_err(|e| BasebackupError::Server(e.into()))?
.partition(
self.timeline.get_shard_identity(),
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
@@ -271,10 +296,15 @@ where
let mut slru_builder = SlruSegmentsBuilder::new(&mut self.ar);
for part in slru_partitions.parts {
let blocks = self.timeline.get_vectored(part, self.lsn, self.ctx).await?;
let blocks = self
.timeline
.get_vectored(part, self.lsn, self.ctx)
.await
.map_err(|e| BasebackupError::Server(e.into()))?;
for (key, block) in blocks {
slru_builder.add_block(&key, block?).await?;
let block = block.map_err(|e| BasebackupError::Server(e.into()))?;
slru_builder.add_block(&key, block).await?;
}
}
slru_builder.finish().await?;
@@ -282,8 +312,11 @@ where
let mut min_restart_lsn: Lsn = Lsn::MAX;
// Create tablespace directories
for ((spcnode, dbnode), has_relmap_file) in
self.timeline.list_dbdirs(self.lsn, self.ctx).await?
for ((spcnode, dbnode), has_relmap_file) in self
.timeline
.list_dbdirs(self.lsn, self.ctx)
.await
.map_err(|e| BasebackupError::Server(e.into()))?
{
self.add_dbdir(spcnode, dbnode, has_relmap_file).await?;
@@ -292,7 +325,8 @@ where
let rels = self
.timeline
.list_rels(spcnode, dbnode, Version::Lsn(self.lsn), self.ctx)
.await?;
.await
.map_err(|e| BasebackupError::Server(e.into()))?;
for &rel in rels.iter() {
// Send init fork as main fork to provide well formed empty
// contents of UNLOGGED relations. Postgres copies it in
@@ -315,7 +349,12 @@ where
}
}
for (path, content) in self.timeline.list_aux_files(self.lsn, self.ctx).await? {
for (path, content) in self
.timeline
.list_aux_files(self.lsn, self.ctx)
.await
.map_err(|e| BasebackupError::Server(e.into()))?
{
if path.starts_with("pg_replslot") {
let offs = pg_constants::REPL_SLOT_ON_DISK_OFFSETOF_RESTART_LSN;
let restart_lsn = Lsn(u64::from_le_bytes(
@@ -346,34 +385,41 @@ where
for xid in self
.timeline
.list_twophase_files(self.lsn, self.ctx)
.await?
.await
.map_err(|e| BasebackupError::Server(e.into()))?
{
self.add_twophase_file(xid).await?;
}
fail_point!("basebackup-before-control-file", |_| {
bail!("failpoint basebackup-before-control-file")
Err(BasebackupError::Server(anyhow!(
"failpoint basebackup-before-control-file"
)))
});
// Generate pg_control and bootstrap WAL segment.
self.add_pgcontrol_file().await?;
self.ar.finish().await?;
self.ar.finish().await.map_err(BasebackupError::Client)?;
debug!("all tarred up!");
Ok(())
}
/// Add contents of relfilenode `src`, naming it as `dst`.
async fn add_rel(&mut self, src: RelTag, dst: RelTag) -> anyhow::Result<()> {
async fn add_rel(&mut self, src: RelTag, dst: RelTag) -> Result<(), BasebackupError> {
let nblocks = self
.timeline
.get_rel_size(src, Version::Lsn(self.lsn), self.ctx)
.await?;
.await
.map_err(|e| BasebackupError::Server(e.into()))?;
// If the relation is empty, create an empty file
if nblocks == 0 {
let file_name = dst.to_segfile_name(0);
let header = new_tar_header(&file_name, 0)?;
self.ar.append(&header, &mut io::empty()).await?;
self.ar
.append(&header, &mut io::empty())
.await
.map_err(BasebackupError::Client)?;
return Ok(());
}
@@ -388,13 +434,17 @@ where
let img = self
.timeline
.get_rel_page_at_lsn(src, blknum, Version::Lsn(self.lsn), self.ctx)
.await?;
.await
.map_err(|e| BasebackupError::Server(e.into()))?;
segment_data.extend_from_slice(&img[..]);
}
let file_name = dst.to_segfile_name(seg as u32);
let header = new_tar_header(&file_name, segment_data.len() as u64)?;
self.ar.append(&header, segment_data.as_slice()).await?;
self.ar
.append(&header, segment_data.as_slice())
.await
.map_err(BasebackupError::Client)?;
seg += 1;
startblk = endblk;
@@ -414,20 +464,22 @@ where
spcnode: u32,
dbnode: u32,
has_relmap_file: bool,
) -> anyhow::Result<()> {
) -> Result<(), BasebackupError> {
let relmap_img = if has_relmap_file {
let img = self
.timeline
.get_relmap_file(spcnode, dbnode, Version::Lsn(self.lsn), self.ctx)
.await?;
.await
.map_err(|e| BasebackupError::Server(e.into()))?;
ensure!(
img.len()
== dispatch_pgversion!(
self.timeline.pg_version,
pgv::bindings::SIZEOF_RELMAPFILE
)
);
if img.len()
!= dispatch_pgversion!(self.timeline.pg_version, pgv::bindings::SIZEOF_RELMAPFILE)
{
return Err(BasebackupError::Server(anyhow!(
"img.len() != SIZE_OF_RELMAPFILE, img.len()={}",
img.len(),
)));
}
Some(img)
} else {
@@ -440,14 +492,20 @@ where
ver => format!("{ver}\x0A"),
};
let header = new_tar_header("PG_VERSION", pg_version_str.len() as u64)?;
self.ar.append(&header, pg_version_str.as_bytes()).await?;
self.ar
.append(&header, pg_version_str.as_bytes())
.await
.map_err(BasebackupError::Client)?;
info!("timeline.pg_version {}", self.timeline.pg_version);
if let Some(img) = relmap_img {
// filenode map for global tablespace
let header = new_tar_header("global/pg_filenode.map", img.len() as u64)?;
self.ar.append(&header, &img[..]).await?;
self.ar
.append(&header, &img[..])
.await
.map_err(BasebackupError::Client)?;
} else {
warn!("global/pg_filenode.map is missing");
}
@@ -466,18 +524,26 @@ where
&& self
.timeline
.list_rels(spcnode, dbnode, Version::Lsn(self.lsn), self.ctx)
.await?
.await
.map_err(|e| BasebackupError::Server(e.into()))?
.is_empty()
{
return Ok(());
}
// User defined tablespaces are not supported
ensure!(spcnode == DEFAULTTABLESPACE_OID);
if spcnode != DEFAULTTABLESPACE_OID {
return Err(BasebackupError::Server(anyhow!(
"spcnode != DEFAULTTABLESPACE_OID, spcnode={spcnode}"
)));
}
// Append dir path for each database
let path = format!("base/{}", dbnode);
let header = new_tar_header_dir(&path)?;
self.ar.append(&header, &mut io::empty()).await?;
self.ar
.append(&header, &mut io::empty())
.await
.map_err(BasebackupError::Client)?;
if let Some(img) = relmap_img {
let dst_path = format!("base/{}/PG_VERSION", dbnode);
@@ -487,11 +553,17 @@ where
ver => format!("{ver}\x0A"),
};
let header = new_tar_header(&dst_path, pg_version_str.len() as u64)?;
self.ar.append(&header, pg_version_str.as_bytes()).await?;
self.ar
.append(&header, pg_version_str.as_bytes())
.await
.map_err(BasebackupError::Client)?;
let relmap_path = format!("base/{}/pg_filenode.map", dbnode);
let header = new_tar_header(&relmap_path, img.len() as u64)?;
self.ar.append(&header, &img[..]).await?;
self.ar
.append(&header, &img[..])
.await
.map_err(BasebackupError::Client)?;
}
};
Ok(())
@@ -500,11 +572,12 @@ where
//
// Extract twophase state files
//
async fn add_twophase_file(&mut self, xid: TransactionId) -> anyhow::Result<()> {
async fn add_twophase_file(&mut self, xid: TransactionId) -> Result<(), BasebackupError> {
let img = self
.timeline
.get_twophase_file(xid, self.lsn, self.ctx)
.await?;
.await
.map_err(|e| BasebackupError::Server(e.into()))?;
let mut buf = BytesMut::new();
buf.extend_from_slice(&img[..]);
@@ -512,7 +585,10 @@ where
buf.put_u32_le(crc);
let path = format!("pg_twophase/{:>08X}", xid);
let header = new_tar_header(&path, buf.len() as u64)?;
self.ar.append(&header, &buf[..]).await?;
self.ar
.append(&header, &buf[..])
.await
.map_err(BasebackupError::Client)?;
Ok(())
}
@@ -521,24 +597,28 @@ where
// Add generated pg_control file and bootstrap WAL segment.
// Also send zenith.signal file with extra bootstrap data.
//
async fn add_pgcontrol_file(&mut self) -> anyhow::Result<()> {
async fn add_pgcontrol_file(&mut self) -> Result<(), BasebackupError> {
// add zenith.signal file
let mut zenith_signal = String::new();
if self.prev_record_lsn == Lsn(0) {
if self.lsn == self.timeline.get_ancestor_lsn() {
write!(zenith_signal, "PREV LSN: none")?;
write!(zenith_signal, "PREV LSN: none")
.map_err(|e| BasebackupError::Server(e.into()))?;
} else {
write!(zenith_signal, "PREV LSN: invalid")?;
write!(zenith_signal, "PREV LSN: invalid")
.map_err(|e| BasebackupError::Server(e.into()))?;
}
} else {
write!(zenith_signal, "PREV LSN: {}", self.prev_record_lsn)?;
write!(zenith_signal, "PREV LSN: {}", self.prev_record_lsn)
.map_err(|e| BasebackupError::Server(e.into()))?;
}
self.ar
.append(
&new_tar_header("zenith.signal", zenith_signal.len() as u64)?,
zenith_signal.as_bytes(),
)
.await?;
.await
.map_err(BasebackupError::Client)?;
let checkpoint_bytes = self
.timeline
@@ -560,7 +640,10 @@ where
//send pg_control
let header = new_tar_header("global/pg_control", pg_control_bytes.len() as u64)?;
self.ar.append(&header, &pg_control_bytes[..]).await?;
self.ar
.append(&header, &pg_control_bytes[..])
.await
.map_err(BasebackupError::Client)?;
//send wal segment
let segno = self.lsn.segment_number(WAL_SEGMENT_SIZE);
@@ -575,8 +658,16 @@ where
self.lsn,
)
.map_err(|e| anyhow!(e).context("Failed generating wal segment"))?;
ensure!(wal_seg.len() == WAL_SEGMENT_SIZE);
self.ar.append(&header, &wal_seg[..]).await?;
if wal_seg.len() != WAL_SEGMENT_SIZE {
return Err(BasebackupError::Server(anyhow!(
"wal_seg.len() != WAL_SEGMENT_SIZE, wal_seg.len()={}",
wal_seg.len()
)));
}
self.ar
.append(&header, &wal_seg[..])
.await
.map_err(BasebackupError::Client)?;
Ok(())
}
}

View File

@@ -48,6 +48,7 @@ use utils::{
use crate::auth::check_permission;
use crate::basebackup;
use crate::basebackup::BasebackupError;
use crate::config::PageServerConf;
use crate::context::{DownloadBehavior, RequestContext};
use crate::import_datadir::import_wal_from_tar;
@@ -1236,6 +1237,13 @@ impl PageServerHandler {
where
IO: AsyncRead + AsyncWrite + Send + Sync + Unpin,
{
fn map_basebackup_error(err: BasebackupError) -> QueryError {
match err {
BasebackupError::Client(e) => QueryError::Disconnected(ConnectionError::Io(e)),
BasebackupError::Server(e) => QueryError::Other(e),
}
}
let started = std::time::Instant::now();
// check that the timeline exists
@@ -1261,7 +1269,8 @@ impl PageServerHandler {
let lsn_awaited_after = started.elapsed();
// switch client to COPYOUT
pgb.write_message_noflush(&BeMessage::CopyOutResponse)?;
pgb.write_message_noflush(&BeMessage::CopyOutResponse)
.map_err(QueryError::Disconnected)?;
self.flush_cancellable(pgb, &timeline.cancel).await?;
// Send a tarball of the latest layer on the timeline. Compress if not
@@ -1276,7 +1285,8 @@ impl PageServerHandler {
full_backup,
ctx,
)
.await?;
.await
.map_err(map_basebackup_error)?;
} else {
let mut writer = pgb.copyout_writer();
if gzip {
@@ -1297,9 +1307,13 @@ impl PageServerHandler {
full_backup,
ctx,
)
.await?;
.await
.map_err(map_basebackup_error)?;
// shutdown the encoder to ensure the gzip footer is written
encoder.shutdown().await?;
encoder
.shutdown()
.await
.map_err(|e| QueryError::Disconnected(ConnectionError::Io(e)))?;
} else {
basebackup::send_basebackup_tarball(
&mut writer,
@@ -1309,11 +1323,13 @@ impl PageServerHandler {
full_backup,
ctx,
)
.await?;
.await
.map_err(map_basebackup_error)?;
}
}
pgb.write_message_noflush(&BeMessage::CopyDone)?;
pgb.write_message_noflush(&BeMessage::CopyDone)
.map_err(QueryError::Disconnected)?;
self.flush_cancellable(pgb, &timeline.cancel).await?;
let basebackup_after = started

View File

@@ -401,8 +401,8 @@ impl Layer {
&self.0.path
}
pub(crate) fn local_path_str(&self) -> &Arc<str> {
&self.0.path_str
pub(crate) fn debug_str(&self) -> &Arc<str> {
&self.0.debug_str
}
pub(crate) fn metadata(&self) -> LayerFileMetadata {
@@ -527,8 +527,8 @@ struct LayerInner {
/// Full path to the file; unclear if this should exist anymore.
path: Utf8PathBuf,
/// String representation of the full path, used for traversal id.
path_str: Arc<str>,
/// String representation of the layer, used for traversal id.
debug_str: Arc<str>,
desc: PersistentLayerDesc,
@@ -735,7 +735,7 @@ impl LayerInner {
LayerInner {
conf,
path_str: path.to_string().into(),
debug_str: { format!("timelines/{}/{}", timeline.timeline_id, desc.filename()).into() },
path,
desc,
timeline: Arc::downgrade(timeline),

View File

@@ -2948,7 +2948,7 @@ trait TraversalLayerExt {
impl TraversalLayerExt for Layer {
fn traversal_id(&self) -> TraversalId {
Arc::clone(self.local_path_str())
Arc::clone(self.debug_str())
}
}

View File

@@ -100,7 +100,6 @@ postgres-protocol.workspace = true
redis.workspace = true
workspace_hack.workspace = true
pin-list = { version = "0.1.0", features = ["std"] }
[dev-dependencies]
camino-tempfile.workspace = true

View File

@@ -16,7 +16,7 @@ use crate::{
proxy::connect_compute::ConnectMechanism,
};
use super::conn_pool::{poll_tokio_client, Client, ConnInfo, GlobalConnPool};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
pub struct PoolingBackend {
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
@@ -184,10 +184,10 @@ impl ConnectMechanism for TokioMechanism {
drop(pause);
tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
Ok(poll_tokio_client(
Ok(poll_client(
self.pool.clone(),
ctx,
&self.conn_info,
self.conn_info.clone(),
client,
connection,
self.conn_id,

View File

@@ -1,12 +1,9 @@
use dashmap::DashMap;
use futures::Future;
use futures::{future::poll_fn, Future};
use parking_lot::RwLock;
use pin_list::{InitializedNode, Node};
use pin_project_lite::pin_project;
use rand::Rng;
use smallvec::SmallVec;
use std::pin::Pin;
use std::{collections::HashMap, sync::Arc, time::Duration};
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use std::{
fmt,
task::{ready, Poll},
@@ -15,19 +12,19 @@ use std::{
ops::Deref,
sync::atomic::{self, AtomicUsize},
};
use tokio::sync::mpsc::error::TrySendError;
use tokio::time::Sleep;
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_util::sync::CancellationToken;
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics, NumDbConnectionsGuard};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{
auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
};
use tracing::{debug, error, warn};
use tracing::{debug, error, warn, Span};
use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;
@@ -86,11 +83,7 @@ pub struct EndpointConnPool<C: ClientInnerExt> {
}
impl<C: ClientInnerExt> EndpointConnPool<C> {
fn get_conn_entry(
&mut self,
db_user: (DbName, RoleName),
session_id: uuid::Uuid,
) -> Option<ConnPoolEntry<C>> {
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
let Self {
pools,
total_conns,
@@ -98,15 +91,11 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
..
} = self;
pools.get_mut(&db_user).and_then(|pool_entries| {
pool_entries.get_conn_entry(total_conns, global_connections_count, session_id)
pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
})
}
fn remove_client<'a>(
&mut self,
db_user: (DbName, RoleName),
node: Pin<&'a mut InitializedNode<'a, ConnTypes<C>>>,
) -> bool {
fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
let Self {
pools,
total_conns,
@@ -114,39 +103,41 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
..
} = self;
if let Some(pool) = pools.get_mut(&db_user) {
if node.unlink(&mut pool.conns).is_ok() {
global_connections_count.fetch_sub(1, atomic::Ordering::Relaxed);
let old_len = pool.conns.len();
pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
let new_len = pool.conns.len();
let removed = old_len - new_len;
if removed > 0 {
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(1);
*total_conns -= 1;
true
} else {
false
.dec_by(removed as i64);
}
*total_conns -= removed;
removed > 0
} else {
false
}
}
fn put(
pool: &RwLock<Self>,
node: Pin<&mut Node<ConnTypes<C>>>,
db_user: &(DbName, RoleName),
client: ClientInner<C>,
) -> bool {
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
let conn_id = client.conn_id;
if client.is_closed() {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
return;
}
let global_max_conn = pool.read().global_pool_size_max_conns;
if pool
.read()
.global_connections_count
.load(atomic::Ordering::Relaxed)
>= global_max_conn
{
let pool = pool.read();
if pool
.global_connections_count
.load(atomic::Ordering::Relaxed)
>= pool.global_pool_size_max_conns
{
info!("pool: throwing away connection because pool is full");
return false;
}
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
return;
}
// return connection to the pool
@@ -156,19 +147,14 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
let mut pool = pool.write();
if pool.total_conns < pool.max_conns {
let pool_entries = pool.pools.entry(db_user.clone()).or_default();
pool_entries.conns.cursor_front_mut().insert_after(
node,
ConnPoolEntry {
conn: client,
_last_access: std::time::Instant::now(),
},
(),
);
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
pool_entries.conns.push(ConnPoolEntry {
conn: client,
_last_access: std::time::Instant::now(),
});
returned = true;
per_db_size = pool_entries.len;
per_db_size = pool_entries.conns.len();
pool.total_conns += 1;
pool.global_connections_count
@@ -185,12 +171,10 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
// do logging outside of the mutex
if returned {
info!("pool: returning connection back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
} else {
info!("pool: throwing away connection because pool is full, total_conns={total_conns}");
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
}
returned
}
}
@@ -209,37 +193,45 @@ impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
}
pub struct DbUserConnPool<C: ClientInnerExt> {
conns: pin_list::PinList<ConnTypes<C>>,
len: usize,
conns: Vec<ConnPoolEntry<C>>,
}
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
fn default() -> Self {
Self {
conns: pin_list::PinList::new(pin_list::id::Checked::new()),
len: 0,
}
Self { conns: Vec::new() }
}
}
impl<C: ClientInnerExt> DbUserConnPool<C> {
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
let old_len = self.conns.len();
self.conns.retain(|conn| !conn.conn.is_closed());
let new_len = self.conns.len();
let removed = old_len - new_len;
*conns -= removed;
removed
}
fn get_conn_entry(
&mut self,
conns: &mut usize,
global_connections_count: &AtomicUsize,
session_id: uuid::Uuid,
global_connections_count: Arc<AtomicUsize>,
) -> Option<ConnPoolEntry<C>> {
let conn = self
.conns
.cursor_front_mut()
.remove_current(session_id)
.ok()?;
*conns -= 1;
global_connections_count.fetch_sub(1, atomic::Ordering::Relaxed);
Metrics::get().proxy.http_pool_opened_connections.dec_by(1);
Some(conn)
let mut removed = self.clear_closed_clients(conns);
let conn = self.conns.pop();
if conn.is_some() {
*conns -= 1;
removed += 1;
}
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
conn
}
}
@@ -331,11 +323,19 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool { total_conns, .. } = pool.get_mut();
let EndpointConnPool {
pools, total_conns, ..
} = pool.get_mut();
// ensure that closed clients are removed
pools.iter_mut().for_each(|(_, db_pool)| {
clients_removed += db_pool.clear_closed_clients(total_conns);
});
// we only remove this pool if it has no active connections
if *total_conns == 0 {
@@ -351,6 +351,19 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
@@ -375,25 +388,32 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
if let Some(entry) = endpoint_pool
.write()
.get_conn_entry(conn_info.db_and_user(), ctx.session_id)
.get_conn_entry(conn_info.db_and_user())
{
client = Some(entry.conn)
}
let endpoint_pool = Arc::downgrade(&endpoint_pool);
// ok return cached connection if found and establish a new one otherwise
if let Some(client) = client {
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
tracing::Span::current().record(
"pid",
&tracing::field::display(client.inner.get_process_id()),
);
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.latency_timer.success();
return Ok(Some(Client::new(client)));
if client.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
} else {
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
tracing::Span::current().record(
"pid",
&tracing::field::display(client.inner.get_process_id()),
);
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
client.session.send(ctx.session_id)?;
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.latency_timer.success();
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
}
}
Ok(None)
}
@@ -443,252 +463,154 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
}
}
type ConnTypes<C> = dyn pin_list::Types<
Id = pin_list::id::Checked,
Protected = ConnPoolEntry<C>,
// session ID
Removed = uuid::Uuid,
Unprotected = (),
>;
pub fn poll_tokio_client(
global_pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
ctx: &mut RequestMonitoring,
conn_info: &ConnInfo,
client: tokio_postgres::Client,
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<tokio_postgres::Client> {
let connection = std::future::poll_fn(move |cx| {
loop {
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!("notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(
pid = notif.process_id(),
channel = notif.channel(),
"notification received"
);
}
Some(Ok(_)) => {
warn!("unknown message");
}
Some(Err(e)) => {
error!("connection error: {}", e);
break;
}
None => {
info!("connection closed");
break;
}
}
}
Poll::Ready(())
});
poll_client(
global_pool,
ctx,
conn_info,
client,
connection,
conn_id,
aux,
)
}
pub fn poll_client<C: ClientInnerExt, I: Future<Output = ()> + Send + 'static>(
pub fn poll_client<C: ClientInnerExt>(
global_pool: Arc<GlobalConnPool<C>>,
ctx: &mut RequestMonitoring,
conn_info: &ConnInfo,
conn_info: ConnInfo,
client: C,
connection: I,
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<C> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol);
let session_id = ctx.session_id;
let mut session_id = ctx.session_id;
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info;
let session_span = info_span!(parent: span.clone(), "", %session_id);
session_span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, "new connection");
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
let pool = match conn_info.endpoint_cache_key() {
Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
None => Weak::new(),
};
let pool_clone = pool.clone();
let pool = conn_info
.endpoint_cache_key()
.map(|endpoint| global_pool.get_or_create_endpoint_pool(&endpoint));
let db_user = conn_info.db_and_user();
let idle = global_pool.get_idle_timeout();
let cancel = CancellationToken::new();
let cancelled = cancel.clone().cancelled_owned();
let (send_client, recv_client) = tokio::sync::mpsc::channel(1);
let db_conn = DbConnection {
idle_timeout: None,
idle,
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let mut idle_timeout = pin!(tokio::time::sleep(idle));
let mut cancelled = pin!(cancelled);
node: Node::<ConnTypes<C>>::new(),
recv_client,
db_user: conn_info.db_and_user(),
pool,
session_span,
conn_gauge,
connection,
};
tokio::spawn(db_conn.instrument(span));
let inner = ClientInner {
inner: client,
pool: send_client,
aux,
conn_id,
};
Client::new(inner)
}
pin_project! {
struct DbConnection<C: ClientInnerExt, Inner> {
// Used to close the current conn if it's idle
#[pin]
idle_timeout: Option<Sleep>,
idle: tokio::time::Duration,
// Used to add/remove conn from the conn pool
#[pin]
node: Node<ConnTypes<C>>,
recv_client: tokio::sync::mpsc::Receiver<ClientInner<C>>,
db_user: (DbName, RoleName),
pool: Option<Arc<RwLock<EndpointConnPool<C>>>>,
// Used for reporting the current session the conn is attached to
session_span: tracing::Span,
// Static connection state
conn_gauge: NumDbConnectionsGuard<'static>,
#[pin]
connection: Inner,
}
impl<C: ClientInnerExt, I> PinnedDrop for DbConnection<C, I> {
fn drop(this: Pin<&mut Self>) {
let mut this = this.project();
let Some(init) = this.node.as_mut().initialized_mut() else { return };
let pool = this.pool.as_ref().expect("pool must be set if the node is initialsed in the pool");
if pool.write().remove_client(this.db_user.clone(), init) {
info!("closed connection removed");
poll_fn(move |cx| {
if cancelled.as_mut().poll(cx).is_ready() {
info!("connection dropped");
return Poll::Ready(())
}
}
}
}
impl<C: ClientInnerExt, I: Future<Output = ()>> Future for DbConnection<C, I> {
type Output = ();
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
// Update the session span.
// If the node is initialised, then it is either
// 1. Waiting in the idle pool
// 2. Just removed from the idle pool and this is our first wake up.
//
// In the event of 1, nothing happens. (should not have many wakeups while idle)
// In the event of 2, we remove the session_id that was left in it's place.
if let Some(init) = this.node.as_mut().initialized_mut() {
// node is initiated via EndpointConnPool::put.
// this is only called in the if statement below.
// this can only occur if pool is set (and pool is never removed).
// when this occurs, it guarantees that the DbUserConnPool is created (it is never removed).
let pool = this
.pool
.as_ref()
.expect("node cannot be init without pool");
let mut pool_lock = pool.write();
let db_pool = pool_lock
.pools
.get(this.db_user)
.expect("node cannot be init without pool");
match init.take_removed(&db_pool.conns) {
Ok((session_id, _)) => {
*this.session_span = info_span!("", %session_id);
let _span = this.session_span.enter();
info!("changed session");
// this connection is no longer idle
this.idle_timeout.set(None);
match rx.has_changed() {
Ok(true) => {
session_id = *rx.borrow_and_update();
info!(%session_id, "changed session");
idle_timeout.as_mut().reset(Instant::now() + idle);
}
Err(init) => {
let idle = this
.idle_timeout
.as_mut()
.as_pin_mut()
.expect("timer must be set if node is init");
Err(_) => {
info!("connection dropped");
return Poll::Ready(())
}
_ => {}
}
if idle.poll(cx).is_ready() {
info!("connection idle");
// remove client from pool - should close the connection if it's idle.
// does nothing if the client is currently checked-out and in-use
if pool_lock.remove_client(this.db_user.clone(), init) {
info!("closed connection removed");
}
// 5 minute idle connection timeout
if idle_timeout.as_mut().poll(cx).is_ready() {
idle_timeout.as_mut().reset(Instant::now() + idle);
info!("connection idle");
if let Some(pool) = pool.clone().upgrade() {
// remove client from pool - should close the connection if it's idle.
// does nothing if the client is currently checked-out and in-use
if pool.write().remove_client(db_user.clone(), conn_id) {
info!("idle connection removed");
}
}
}
}
let _span = this.session_span.enter();
loop {
let message = ready!(connection.poll_message(cx));
// The client has been returned. We will insert it into the linked list for this database.
if let Poll::Ready(client) = this.recv_client.poll_recv(cx) {
// if the send_client is dropped, then the client is dropped
let Some(client) = client else {
info!("connection dropped");
return Poll::Ready(());
};
// if there's no pool, then this client will be closed.
let Some(pool) = &this.pool else {
info!("connection dropped");
return Poll::Ready(());
};
if !EndpointConnPool::put(pool, this.node.as_mut(), this.db_user, client) {
return Poll::Ready(());
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session_id, "unknown message");
}
Some(Err(e)) => {
error!(%session_id, "connection error: {}", e);
break
}
None => {
info!("connection closed");
break
}
}
}
// this connection is now idle
this.idle_timeout.set(Some(tokio::time::sleep(*this.idle)));
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;
this.connection.poll(cx)
}
.instrument(span));
let inner = ClientInner {
inner: client,
session: tx,
cancel,
aux,
conn_id,
};
Client::new(inner, conn_info, pool_clone)
}
struct ClientInner<C: ClientInnerExt> {
inner: C,
pool: tokio::sync::mpsc::Sender<ClientInner<C>>,
session: tokio::sync::watch::Sender<uuid::Uuid>,
cancel: CancellationToken,
aux: MetricsAuxInfo,
conn_id: uuid::Uuid,
}
impl<C: ClientInnerExt> Drop for ClientInner<C> {
fn drop(&mut self) {
// on client drop, tell the conn to shut down
self.cancel.cancel();
}
}
pub trait ClientInnerExt: Sync + Send + 'static {
fn is_closed(&self) -> bool;
fn get_process_id(&self) -> i32;
}
impl ClientInnerExt for tokio_postgres::Client {
fn is_closed(&self) -> bool {
self.is_closed()
}
fn get_process_id(&self) -> i32 {
self.get_process_id()
}
}
impl<C: ClientInnerExt> ClientInner<C> {
pub fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}
impl<C: ClientInnerExt> Client<C> {
pub fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
@@ -700,42 +622,54 @@ impl<C: ClientInnerExt> Client<C> {
}
pub struct Client<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInner<C>>,
discarded: bool,
conn_info: ConnInfo,
pool: Weak<RwLock<EndpointConnPool<C>>>,
}
pub struct Discard<'a> {
conn_id: uuid::Uuid,
discarded: &'a mut bool,
pub struct Discard<'a, C: ClientInnerExt> {
conn_info: &'a ConnInfo,
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
impl<C: ClientInnerExt> Client<C> {
pub(self) fn new(inner: ClientInner<C>) -> Self {
pub(self) fn new(
inner: ClientInner<C>,
conn_info: ConnInfo,
pool: Weak<RwLock<EndpointConnPool<C>>>,
) -> Self {
Self {
inner: Some(inner),
discarded: false,
span: Span::current(),
conn_info,
pool,
}
}
pub fn inner(&mut self) -> (&mut C, Discard<'_>) {
let Self { inner, discarded } = self;
pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
let Self {
inner,
pool,
conn_info,
span: _,
} = self;
let inner = inner.as_mut().expect("client inner should not be removed");
let conn_id = inner.conn_id;
(&mut inner.inner, Discard { discarded, conn_id })
(&mut inner.inner, Discard { pool, conn_info })
}
}
impl Discard<'_> {
impl<C: ClientInnerExt> Discard<'_, C> {
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_id = &self.conn_id;
if status != ReadyForQueryStatus::Idle && !*self.discarded {
*self.discarded = true;
info!(%conn_id, "pool: throwing away connection because connection is not idle")
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle")
}
}
pub fn discard(&mut self) {
let conn_id = &self.conn_id;
*self.discarded = true;
info!(%conn_id, "pool: throwing away connection because connection is potentially in a broken state")
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
}
}
}
@@ -751,68 +685,73 @@ impl<C: ClientInnerExt> Deref for Client<C> {
}
}
impl<C: ClientInnerExt> Drop for Client<C> {
fn drop(&mut self) {
impl<C: ClientInnerExt> Client<C> {
fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
.inner
.take()
.expect("client inner should not be removed");
if self.discarded {
return;
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
let current_span = self.span.clone();
// return connection to the pool
return Some(move || {
let _span = current_span.enter();
EndpointConnPool::put(&conn_pool, &conn_info, client);
});
}
None
}
}
let conn_id = client.conn_id;
let tx = client.pool.clone();
match tx.try_send(client) {
Ok(_) => {}
Err(TrySendError::Closed(_)) => {
info!(%conn_id, "pool: throwing away connection because connection is closed");
}
Err(TrySendError::Full(_)) => {
error!("client channel should not be full")
}
impl<C: ClientInnerExt> Drop for Client<C> {
fn drop(&mut self) {
if let Some(drop) = self.do_drop() {
tokio::task::spawn_blocking(drop);
}
}
}
#[cfg(test)]
mod tests {
use tokio::task::yield_now;
use tokio_util::sync::CancellationToken;
use std::{mem, sync::atomic::AtomicBool};
use crate::{BranchId, EndpointId, ProjectId};
use super::*;
struct MockClient;
struct MockClient(Arc<AtomicBool>);
impl MockClient {
fn new(is_closed: bool) -> Self {
MockClient(Arc::new(is_closed.into()))
}
}
impl ClientInnerExt for MockClient {
fn is_closed(&self) -> bool {
self.0.load(atomic::Ordering::Relaxed)
}
fn get_process_id(&self) -> i32 {
0
}
}
fn create_inner(
global_pool: Arc<GlobalConnPool<MockClient>>,
conn_info: &ConnInfo,
) -> (Client<MockClient>, CancellationToken) {
let cancelled = CancellationToken::new();
let client = poll_client(
global_pool,
&mut RequestMonitoring::test(),
conn_info,
MockClient,
cancelled.clone().cancelled_owned(),
uuid::Uuid::new_v4(),
MetricsAuxInfo {
fn create_inner() -> ClientInner<MockClient> {
create_inner_with(MockClient::new(false))
}
fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
ClientInner {
inner: client,
session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
cancel: CancellationToken::new(),
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
branch_id: (&BranchId::from("branch")).into(),
cold_start_info: crate::console::messages::ColdStartInfo::Warm,
},
);
(client, cancelled)
conn_id: uuid::Uuid::new_v4(),
}
}
#[tokio::test]
@@ -839,41 +778,51 @@ mod tests {
dbname: "dbname".into(),
password: "password".as_bytes().into(),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
);
{
let (mut client, _) = create_inner(pool.clone(), &conn_info);
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
assert_eq!(0, pool.get_global_connections_count());
client.inner().1.discard();
drop(client);
yield_now().await;
// Discard should not add the connection from the pool.
assert_eq!(0, pool.get_global_connections_count());
}
{
let (client, _) = create_inner(pool.clone(), &conn_info);
drop(client);
yield_now().await;
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
client.do_drop().unwrap()();
mem::forget(client); // drop the client
assert_eq!(1, pool.get_global_connections_count());
}
{
let (client, cancel) = create_inner(pool.clone(), &conn_info);
cancel.cancel();
drop(client);
yield_now().await;
// The closed client shouldn't be added to the pool.
let mut closed_client = Client::new(
create_inner_with(MockClient::new(true)),
conn_info.clone(),
ep_pool.clone(),
);
closed_client.do_drop().unwrap()();
mem::forget(closed_client); // drop the client
// The closed client shouldn't be added to the pool.
assert_eq!(1, pool.get_global_connections_count());
}
let cancel = {
let (client, cancel) = create_inner(pool.clone(), &conn_info);
drop(client);
yield_now().await;
let is_closed: Arc<AtomicBool> = Arc::new(false.into());
{
let mut client = Client::new(
create_inner_with(MockClient(is_closed.clone())),
conn_info.clone(),
ep_pool.clone(),
);
client.do_drop().unwrap()();
mem::forget(client); // drop the client
// The client should be added to the pool.
assert_eq!(2, pool.get_global_connections_count());
cancel
};
}
{
let client = create_inner(pool.clone(), &conn_info);
drop(client);
yield_now().await;
let mut client = Client::new(create_inner(), conn_info, ep_pool);
client.do_drop().unwrap()();
mem::forget(client); // drop the client
// The client shouldn't be added to the pool. Because the ep-pool is full.
assert_eq!(2, pool.get_global_connections_count());
}
@@ -887,22 +836,25 @@ mod tests {
dbname: "dbname".into(),
password: "password".as_bytes().into(),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
);
{
let client = create_inner(pool.clone(), &conn_info);
drop(client);
yield_now().await;
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
client.do_drop().unwrap()();
mem::forget(client); // drop the client
assert_eq!(3, pool.get_global_connections_count());
}
{
let client = create_inner(pool.clone(), &conn_info);
drop(client);
yield_now().await;
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
client.do_drop().unwrap()();
mem::forget(client); // drop the client
// The client shouldn't be added to the pool. Because the global pool is full.
assert_eq!(3, pool.get_global_connections_count());
}
cancel.cancel();
yield_now().await;
is_closed.store(true, atomic::Ordering::Relaxed);
// Do gc for all shards.
pool.gc(0);
pool.gc(1);