Compare commits

..

5 Commits

Author SHA1 Message Date
Conrad Ludgate
4345fdf07b completely rewrite array text parsing based on spec 2025-05-26 12:35:39 +01:00
Conrad Ludgate
19461604ae move array initialisation out of pg_array_parse 2025-05-26 11:01:57 +01:00
Conrad Ludgate
ef9a5785b0 move value write inside pg_text_to_json 2025-05-26 10:56:28 +01:00
Conrad Ludgate
b339daed9b manage vec/map manually, and handle nulls separately 2025-05-26 10:51:22 +01:00
Conrad Ludgate
85233a85a6 no longer contruct column types vec 2025-05-26 10:35:46 +01:00
80 changed files with 1195 additions and 3464 deletions

15
Cargo.lock generated
View File

@@ -4330,7 +4330,6 @@ dependencies = [
"postgres_connection",
"postgres_ffi",
"postgres_initdb",
"posthog_client_lite",
"pprof",
"pq_proto",
"procfs",
@@ -4459,15 +4458,9 @@ dependencies = [
name = "pageserver_page_api"
version = "0.1.0"
dependencies = [
"bytes",
"pageserver_api",
"postgres_ffi",
"prost 0.13.5",
"smallvec",
"thiserror 1.0.69",
"tonic 0.13.1",
"tonic-build",
"utils",
"workspace_hack",
]
@@ -4908,16 +4901,11 @@ name = "posthog_client_lite"
version = "0.1.0"
dependencies = [
"anyhow",
"arc-swap",
"reqwest",
"serde",
"serde_json",
"sha2",
"thiserror 1.0.69",
"tokio",
"tokio-util",
"tracing",
"tracing-utils",
"workspace_hack",
]
@@ -8581,8 +8569,10 @@ dependencies = [
"fail",
"form_urlencoded",
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-task",
"futures-util",
"generic-array",
"getrandom 0.2.11",
@@ -8612,6 +8602,7 @@ dependencies = [
"once_cell",
"p256 0.13.2",
"parquet",
"percent-encoding",
"prettyplease",
"proc-macro2",
"prost 0.13.5",

View File

@@ -247,7 +247,6 @@ azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rus
## Local libraries
compute_api = { version = "0.1", path = "./libs/compute_api/" }
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
desim = { version = "0.1", path = "./libs/desim" }
endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" }
http-utils = { version = "0.1", path = "./libs/http-utils/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
@@ -260,19 +259,19 @@ postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" }
postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" }
postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" }
postgres_initdb = { path = "./libs/postgres_initdb" }
posthog_client_lite = { version = "0.1", path = "./libs/posthog_client_lite" }
pq_proto = { version = "0.1", path = "./libs/pq_proto/" }
remote_storage = { version = "0.1", path = "./libs/remote_storage/" }
safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" }
safekeeper_client = { path = "./safekeeper/client" }
desim = { version = "0.1", path = "./libs/desim" }
storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy.
storage_controller_client = { path = "./storage_controller/client" }
tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" }
tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
utils = { version = "0.1", path = "./libs/utils/" }
vm_monitor = { version = "0.1", path = "./libs/vm_monitor/" }
wal_decoder = { version = "0.1", path = "./libs/wal_decoder" }
walproposer = { version = "0.1", path = "./libs/walproposer/" }
wal_decoder = { version = "0.1", path = "./libs/wal_decoder" }
## Common library dependency
workspace_hack = { version = "0.1", path = "./workspace_hack/" }

View File

@@ -1847,7 +1847,7 @@ COPY docker-compose/ext-src/ /ext-src/
COPY --from=pg-build /postgres /postgres
#COPY --from=postgis-src /ext-src/ /ext-src/
COPY --from=plv8-src /ext-src/ /ext-src/
COPY --from=h3-pg-src /ext-src/h3-pg-src /ext-src/h3-pg-src
#COPY --from=h3-pg-src /ext-src/ /ext-src/
COPY --from=postgresql-unit-src /ext-src/ /ext-src/
COPY --from=pgvector-src /ext-src/ /ext-src/
COPY --from=pgjwt-src /ext-src/ /ext-src/

View File

@@ -136,10 +136,6 @@ struct Cli {
requires = "compute-id"
)]
pub control_plane_uri: Option<String>,
/// Interval in seconds for collecting installed extensions statistics
#[arg(long, default_value = "3600")]
pub installed_extensions_collection_interval: u64,
}
fn main() -> Result<()> {
@@ -183,7 +179,6 @@ fn main() -> Result<()> {
cgroup: cli.cgroup,
#[cfg(target_os = "linux")]
vm_monitor_addr: cli.vm_monitor_addr,
installed_extensions_collection_interval: cli.installed_extensions_collection_interval,
},
config,
)?;

View File

@@ -97,9 +97,6 @@ pub struct ComputeNodeParams {
/// the address of extension storage proxy gateway
pub remote_ext_base_url: Option<String>,
/// Interval for installed extensions collection
pub installed_extensions_collection_interval: u64,
}
/// Compute node info shared across several `compute_ctl` threads.
@@ -698,18 +695,25 @@ impl ComputeNode {
let log_directory_path = Path::new(&self.params.pgdata).join("log");
let log_directory_path = log_directory_path.to_string_lossy().to_string();
// Add project_id,endpoint_id to identify the logs.
// Add project_id,endpoint_id tag to identify the logs.
//
// These ids are passed from cplane,
let endpoint_id = pspec.spec.endpoint_id.as_deref().unwrap_or("");
let project_id = pspec.spec.project_id.as_deref().unwrap_or("");
// for backwards compatibility (old computes that don't have them),
// we set them to None.
// TODO: Clean up this code when all computes have them.
let tag: Option<String> = match (
pspec.spec.project_id.as_deref(),
pspec.spec.endpoint_id.as_deref(),
) {
(Some(project_id), Some(endpoint_id)) => {
Some(format!("{project_id}/{endpoint_id}"))
}
(Some(project_id), None) => Some(format!("{project_id}/None")),
(None, Some(endpoint_id)) => Some(format!("None,{endpoint_id}")),
(None, None) => None,
};
configure_audit_rsyslog(
log_directory_path.clone(),
endpoint_id,
project_id,
&remote_endpoint,
)?;
configure_audit_rsyslog(log_directory_path.clone(), tag, &remote_endpoint)?;
// Launch a background task to clean up the audit logs
launch_pgaudit_gc(log_directory_path);
@@ -745,7 +749,17 @@ impl ComputeNode {
let conf = self.get_tokio_conn_conf(None);
tokio::task::spawn(async {
let _ = installed_extensions(conf).await;
let res = get_installed_extensions(conf).await;
match res {
Ok(extensions) => {
info!(
"[NEON_EXT_STAT] {}",
serde_json::to_string(&extensions)
.expect("failed to serialize extensions list")
);
}
Err(err) => error!("could not get installed extensions: {err:?}"),
}
});
}
@@ -775,9 +789,6 @@ impl ComputeNode {
// Log metrics so that we can search for slow operations in logs
info!(?metrics, postmaster_pid = %postmaster_pid, "compute start finished");
// Spawn the extension stats background task
self.spawn_extension_stats_task();
if pspec.spec.prewarm_lfc_on_startup {
self.prewarm_lfc();
}
@@ -2188,41 +2199,6 @@ LIMIT 100",
info!("Pageserver config changed");
}
}
pub fn spawn_extension_stats_task(&self) {
let conf = self.tokio_conn_conf.clone();
let installed_extensions_collection_interval =
self.params.installed_extensions_collection_interval;
tokio::spawn(async move {
// An initial sleep is added to ensure that two collections don't happen at the same time.
// The first collection happens during compute startup.
tokio::time::sleep(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
))
.await;
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
));
loop {
interval.tick().await;
let _ = installed_extensions(conf.clone()).await;
}
});
}
}
/// Fetch the installed extensions using `conf` and log the outcome.
///
/// On success the list is logged at info level as JSON prefixed with `[NEON_EXT_STAT]`.
/// On failure the error is logged at error level. The function itself always returns `Ok(())`.
///
/// # Panics
///
/// Panics if the extension list cannot be serialized to JSON.
pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> {
    match get_installed_extensions(conf).await {
        Err(err) => error!("could not get installed extensions: {err:?}"),
        Ok(extensions) => info!(
            "[NEON_EXT_STAT] {}",
            serde_json::to_string(&extensions).expect("failed to serialize extensions list")
        ),
    }
    Ok(())
}
pub fn forward_termination_signal() {

View File

@@ -2,24 +2,10 @@
module(load="imfile")
# Input configuration for log files in the specified directory
# The messages can be multiline. The start of the message is a timestamp
# in "%Y-%m-%d %H:%M:%S.%3N GMT" (so timezone hardcoded).
# Replace log_directory with the directory containing the log files
input(type="imfile" File="{log_directory}/*.log"
Tag="pgaudit_log" Severity="info" Facility="local5"
startmsg.regex="^[[:digit:]]{{4}}-[[:digit:]]{{2}}-[[:digit:]]{{2}} [[:digit:]]{{2}}:[[:digit:]]{{2}}:[[:digit:]]{{2}}.[[:digit:]]{{3}} GMT,")
# Replace {log_directory} with the directory containing the log files
input(type="imfile" File="{log_directory}/*.log" Tag="{tag}" Severity="info" Facility="local0")
# the directory to store rsyslog state files
global(workDirectory="/var/log/rsyslog")
# Construct json, endpoint_id and project_id as additional metadata
set $.json_log!endpoint_id = "{endpoint_id}";
set $.json_log!project_id = "{project_id}";
set $.json_log!msg = $msg;
# Template suitable for rfc5424 syslog format
template(name="PgAuditLog" type="string"
string="<%PRI%>1 %TIMESTAMP:::date-rfc3339% %HOSTNAME% - - - - %$.json_log%")
# Forward to remote syslog receiver (@@<hostname>:<port>;format
local5.info @@{remote_endpoint};PgAuditLog
# Forward logs to remote syslog server
*.* @@{remote_endpoint}

View File

@@ -84,15 +84,13 @@ fn restart_rsyslog() -> Result<()> {
pub fn configure_audit_rsyslog(
log_directory: String,
endpoint_id: &str,
project_id: &str,
tag: Option<String>,
remote_endpoint: &str,
) -> Result<()> {
let config_content: String = format!(
include_str!("config_template/compute_audit_rsyslog_template.conf"),
log_directory = log_directory,
endpoint_id = endpoint_id,
project_id = project_id,
tag = tag.unwrap_or("".to_string()),
remote_endpoint = remote_endpoint
);

View File

@@ -1279,7 +1279,6 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re
mode: pageserver_api::models::TimelineCreateRequestMode::Branch {
ancestor_timeline_id,
ancestor_start_lsn: start_lsn,
read_only: false,
pg_version: None,
},
};

View File

@@ -20,7 +20,7 @@ first_path="$(ldconfig --verbose 2>/dev/null \
| grep --invert-match ^$'\t' \
| cut --delimiter=: --fields=1 \
| head --lines=1)"
test "$first_path" == '/usr/local/lib'
test "$first_path" == '/usr/local/lib' || true # Remove the || true in a follow-up PR. Needed for backwards compat.
echo "Waiting pageserver become ready."
while ! nc -z pageserver 6400; do

View File

@@ -1,16 +0,0 @@
#!/usr/bin/env bash
# Runs the pg_regress suites for the h3_postgis and h3 extensions against a
# freshly created `contrib_regression` database.
# -e: abort on the first failing command; -x: trace every command.
set -ex
# Work relative to the directory that contains this script.
cd "$(dirname "${0}")"
# Locate pg_regress relative to the PGXS makefile reported by pg_config.
PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress
# Start from an empty database.
dropdb --if-exists contrib_regression
createdb contrib_regression
# --- h3_postgis suite: needs postgis, postgis_raster, h3 and h3_postgis installed ---
cd h3_postgis/test
psql -d contrib_regression -c "CREATE EXTENSION postgis" -c "CREATE EXTENSION postgis_raster" -c "CREATE EXTENSION h3" -c "CREATE EXTENSION h3_postgis"
# Turn every sql/<name>.sql into a bare test name <name>.
TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g')
# ${TESTS} is intentionally unquoted so each test name becomes its own argument.
${PG_REGRESS} --use-existing --dbname contrib_regression ${TESTS}
# --- h3 suite: recreate the database so only h3 is installed ---
cd ../../h3/test
TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g')
dropdb --if-exists contrib_regression
createdb contrib_regression
psql -d contrib_regression -c "CREATE EXTENSION h3"
${PG_REGRESS} --use-existing --dbname contrib_regression ${TESTS}

View File

@@ -1,7 +0,0 @@
#!/bin/sh
# Runs the h3 pg_regress suite against the already existing `contrib_regression`
# database (no database or extension is created here).
# -e: abort on the first failing command; -x: trace every command.
set -ex
# Work relative to the directory that contains this script. ${0} is quoted so a
# path containing whitespace is not split before it reaches dirname.
cd "$(dirname "${0}")"
# Locate pg_regress relative to the PGXS makefile reported by pg_config.
PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress
cd h3/test
# Turn every sql/<name>.sql into a bare test name <name>.
TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g')
# ${TESTS} is intentionally unquoted so each test name becomes its own argument.
${PG_REGRESS} --use-existing --inputdir=./ --bindir='/usr/local/pgsql/bin' --dbname=contrib_regression ${TESTS}

View File

@@ -1,6 +0,0 @@
#!/bin/sh
# Runs `make installcheck` in this script's directory when a Makefile is present;
# does nothing (and succeeds) otherwise.
# -e: abort on the first failing command; -x: trace every command.
set -ex
# Work relative to the directory that contains this script.
cd "$(dirname "${0}")"
if [ -f Makefile ]; then
make installcheck
fi

View File

@@ -1,9 +0,0 @@
#!/bin/sh
# Recreates the `contrib_regression` database and runs every test under sql/ in
# this script's directory through pg_regress. Exits successfully without doing
# anything when the directory has no Makefile.
# -e: abort on the first failing command; -x: trace every command.
set -ex
# Work relative to the directory that contains this script. ${0} is quoted so a
# path containing whitespace is not split before it reaches dirname.
cd "$(dirname "${0}")"
[ -f Makefile ] || exit 0
# Use the documented option spelling; the abbreviated `--if-exist` only works
# where getopt_long accepts unique prefixes.
dropdb --if-exists contrib_regression
createdb contrib_regression
# Locate pg_regress relative to the PGXS makefile reported by pg_config.
PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress
# Turn every sql/<name>.sql into a bare test name <name>.
TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g')
# ${TESTS} is intentionally unquoted so each test name becomes its own argument.
${PG_REGRESS} --use-existing --inputdir=./ --bindir='/usr/local/pgsql/bin' --dbname=contrib_regression ${TESTS}

View File

@@ -82,8 +82,7 @@ EXTENSIONS='[
{"extname": "pg_ivm", "extdir": "pg_ivm-src"},
{"extname": "pgjwt", "extdir": "pgjwt-src"},
{"extname": "pgtap", "extdir": "pgtap-src"},
{"extname": "pg_repack", "extdir": "pg_repack-src"},
{"extname": "h3", "extdir": "h3-pg-src"}
{"extname": "pg_repack", "extdir": "pg_repack-src"}
]'
EXTNAMES=$(echo ${EXTENSIONS} | jq -r '.[].extname' | paste -sd ' ' -)
COMPUTE_TAG=${NEW_COMPUTE_TAG} docker compose --profile test-extensions up --quiet-pull --build -d

View File

@@ -45,21 +45,6 @@ pub struct NodeMetadata {
pub other: HashMap<String, serde_json::Value>,
}
/// PostHog integration config.
///
/// Holds the project identifier, the two API keys and the two API base URLs.
///
/// NOTE(review): this type derives `Debug` and `Serialize` while carrying
/// `server_api_key`; presumably formatting or serializing the config would expose
/// the key in plain text — confirm and redact if that is a concern.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PostHogConfig {
/// PostHog project ID
pub project_id: String,
/// Server-side (private) API key
pub server_api_key: String,
/// Client-side (public) API key
pub client_api_key: String,
/// Private API URL
pub private_api_url: String,
/// Public API URL
pub public_api_url: String,
}
/// `pageserver.toml`
///
/// We use serde derive with `#[serde(default)]` to generate a deserializer
@@ -201,8 +186,6 @@ pub struct ConfigToml {
pub tracing: Option<Tracing>,
pub enable_tls_page_service_api: bool,
pub dev_mode: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub posthog_config: Option<PostHogConfig>,
pub timeline_import_config: TimelineImportConfig,
#[serde(skip_serializing_if = "Option::is_none")]
pub basebackup_cache_config: Option<BasebackupCacheConfig>,
@@ -718,7 +701,6 @@ impl Default for ConfigToml {
import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(),
},
basebackup_cache_config: None,
posthog_config: None,
}
}
}

View File

@@ -354,9 +354,6 @@ pub struct ShardImportProgressV1 {
pub completed: usize,
/// Hash of the plan
pub import_plan_hash: u64,
/// Soft limit for the job size
/// This needs to remain constant throughout the import
pub job_soft_size_limit: usize,
}
impl ShardImportStatus {
@@ -405,8 +402,6 @@ pub enum TimelineCreateRequestMode {
// using a flattened enum, so, it was an accepted field, and
// we continue to accept it by having it here.
pg_version: Option<u32>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
read_only: bool,
},
ImportPgdata {
import_pgdata: TimelineCreateRequestModeImportPgdata,

View File

@@ -6,14 +6,9 @@ license.workspace = true
[dependencies]
anyhow.workspace = true
arc-swap.workspace = true
reqwest.workspace = true
serde_json.workspace = true
serde.workspace = true
serde_json.workspace = true
sha2.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] }
tokio-util.workspace = true
tracing-utils.workspace = true
tracing.workspace = true
workspace_hack.workspace = true
thiserror.workspace = true

View File

@@ -1,59 +0,0 @@
//! A background loop that fetches feature flags from PostHog and updates the feature store.
use std::{sync::Arc, time::Duration};
use arc_swap::ArcSwap;
use tokio_util::sync::CancellationToken;
use crate::{FeatureStore, PostHogClient, PostHogClientConfig};
/// A background loop that fetches feature flags from PostHog and updates the feature store.
///
/// The loop is started with `spawn`; readers get the most recently stored snapshot through
/// `feature_store`.
pub struct FeatureResolverBackgroundLoop {
/// Client used to fetch the local-evaluation flag definitions.
posthog_client: PostHogClient,
/// Current flag snapshot; replaced wholesale after every successful fetch.
feature_store: ArcSwap<FeatureStore>,
/// Cancelling this token stops the spawned loop.
cancel: CancellationToken,
}
impl FeatureResolverBackgroundLoop {
    /// Build a resolver with an empty feature store.
    ///
    /// `config` is handed to [`PostHogClient::new`]; `shutdown_pageserver` is the token whose
    /// cancellation stops the loop started by [`Self::spawn`].
    pub fn new(config: PostHogClientConfig, shutdown_pageserver: CancellationToken) -> Self {
        let posthog_client = PostHogClient::new(config);
        let feature_store = ArcSwap::new(Arc::new(FeatureStore::new()));
        Self {
            posthog_client,
            feature_store,
            cancel: shutdown_pageserver,
        }
    }

    /// Start the refresh loop on `handle`.
    ///
    /// Each iteration waits for either the next tick of a `refresh_period` interval (missed ticks
    /// are skipped) or cancellation. On a tick the flags are fetched; a failed fetch is logged as
    /// a warning and the previous snapshot is kept, a successful one replaces the snapshot.
    pub fn spawn(self: Arc<Self>, handle: &tokio::runtime::Handle, refresh_period: Duration) {
        let shutdown = self.cancel.clone();
        let resolver = self;
        handle.spawn(async move {
            tracing::info!("Starting PostHog feature resolver");
            let mut ticker = tokio::time::interval(refresh_period);
            ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            loop {
                tokio::select! {
                    _ = ticker.tick() => {}
                    _ = shutdown.cancelled() => break
                }
                match resolver
                    .posthog_client
                    .get_feature_flags_local_evaluation()
                    .await
                {
                    Ok(resp) => {
                        let refreshed = FeatureStore::new_with_flags(resp.flags);
                        resolver.feature_store.store(Arc::new(refreshed));
                    }
                    Err(e) => {
                        tracing::warn!("Cannot get feature flags: {}", e);
                    }
                }
            }
            tracing::info!("PostHog feature resolver stopped");
        });
    }

    /// Return the feature store snapshot that is current at the time of the call.
    pub fn feature_store(&self) -> Arc<FeatureStore> {
        self.feature_store.load_full()
    }
}

View File

@@ -1,9 +1,5 @@
//! A lite version of the PostHog client that only supports local evaluation of feature flags.
mod background_loop;
pub use background_loop::FeatureResolverBackgroundLoop;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
@@ -24,7 +20,8 @@ pub enum PostHogEvaluationError {
#[derive(Deserialize)]
pub struct LocalEvaluationResponse {
pub flags: Vec<LocalEvaluationFlag>,
#[allow(dead_code)]
flags: Vec<LocalEvaluationFlag>,
}
#[derive(Deserialize)]
@@ -37,7 +34,7 @@ pub struct LocalEvaluationFlag {
#[derive(Deserialize)]
pub struct LocalEvaluationFlagFilters {
groups: Vec<LocalEvaluationFlagFilterGroup>,
multivariate: Option<LocalEvaluationFlagMultivariate>,
multivariate: LocalEvaluationFlagMultivariate,
}
#[derive(Deserialize)]
@@ -97,12 +94,6 @@ impl FeatureStore {
}
}
/// Create a store via [`Self::new`] and populate it with `flags` through [`Self::set_flags`].
pub fn new_with_flags(flags: Vec<LocalEvaluationFlag>) -> Self {
    let mut populated = Self::new();
    populated.set_flags(flags);
    populated
}
pub fn set_flags(&mut self, flags: Vec<LocalEvaluationFlag>) {
self.flags.clear();
for flag in flags {
@@ -254,7 +245,7 @@ impl FeatureStore {
}
}
/// Evaluate a multivariate feature flag. Returns an error if the flag is not available or if there are errors
/// Evaluate a multivariate feature flag. Returns `None` if the flag is not available or if there are errors
/// during the evaluation.
///
/// The parsing logic is as follows:
@@ -272,15 +263,10 @@ impl FeatureStore {
/// Example: we have a multivariate flag with 3 groups of the configured global rollout percentage: A (10%), B (20%), C (70%).
/// There is a single group with a condition that has a rollout percentage of 10% and it does not have a variant override.
/// Then, we will have 1% of the users evaluated to A, 2% to B, and 7% to C.
///
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
/// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be
/// propagated beyond where the feature flag gets resolved.
pub fn evaluate_multivariate(
&self,
flag_key: &str,
user_id: &str,
properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
) -> Result<String, PostHogEvaluationError> {
let hash_on_global_rollout_percentage =
Self::consistent_hash(user_id, flag_key, "multivariate");
@@ -290,39 +276,10 @@ impl FeatureStore {
flag_key,
hash_on_global_rollout_percentage,
hash_on_group_rollout_percentage,
properties,
&HashMap::new(),
)
}
/// Evaluate a boolean feature flag for `user_id`.
///
/// A consistent hash is computed from `user_id`, `flag_key` and the salt `"boolean"`, and the
/// actual evaluation is delegated to [`Self::evaluate_boolean_inner`] with that hash and the
/// supplied `properties`.
///
/// Returns `Ok(())` if the feature flag evaluates to true. In the future, it will return a
/// payload. Every other outcome (flag unknown, flag inactive, no condition group matched, ...)
/// is reported as an error.
///
/// Error handling: the caller should inspect the error and decide the behavior when a feature
/// flag cannot be evaluated (i.e., default to false if it cannot be resolved). The error should
/// *not* be propagated beyond where the feature flag gets resolved.
pub fn evaluate_boolean(
    &self,
    flag_key: &str,
    user_id: &str,
    properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
) -> Result<(), PostHogEvaluationError> {
    self.evaluate_boolean_inner(
        flag_key,
        Self::consistent_hash(user_id, flag_key, "boolean"),
        properties,
    )
}
/// Evaluate a multivariate feature flag. Note that we directly take the mapped user ID
/// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests
/// and avoid duplicate computations.
@@ -349,11 +306,6 @@ impl FeatureStore {
flag_key
)));
}
let Some(ref multivariate) = flag_config.filters.multivariate else {
return Err(PostHogEvaluationError::Internal(format!(
"No multivariate available, should use evaluate_boolean?: {flag_key}"
)));
};
// TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog
// Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it
// does not matter.
@@ -362,7 +314,7 @@ impl FeatureStore {
GroupEvaluationResult::MatchedAndOverride(variant) => return Ok(variant),
GroupEvaluationResult::MatchedAndEvaluate => {
let mut percentage = 0;
for variant in &multivariate.variants {
for variant in &flag_config.filters.multivariate.variants {
percentage += variant.rollout_percentage;
if self
.evaluate_percentage(hash_on_global_rollout_percentage, percentage)
@@ -390,77 +342,6 @@ impl FeatureStore {
)))
}
}
/// Evaluate a boolean feature flag. Note that we directly take the mapped user ID
/// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests
/// and avoid duplicate computations.
///
/// The flag's filter groups are evaluated in order with `hash_on_global_rollout_percentage`
/// and `properties`; the first group that matches decides the result.
///
/// # Errors
///
/// * `NotAvailable` if the flag is not in the local evaluation spec or is not active.
/// * `Internal` if the flag carries a multivariate config (use `evaluate_multivariate`
///   instead) or if a matched group carries a variant override.
/// * `NoConditionGroupMatched` if no group matched.
/// * Any error returned by the group evaluation itself is propagated unchanged.
pub(crate) fn evaluate_boolean_inner(
    &self,
    flag_key: &str,
    hash_on_global_rollout_percentage: f64,
    properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
) -> Result<(), PostHogEvaluationError> {
    let Some(flag_config) = self.flags.get(flag_key) else {
        // The feature flag is not available yet
        return Err(PostHogEvaluationError::NotAvailable(format!(
            "Not found in the local evaluation spec: {}",
            flag_key
        )));
    };
    if !flag_config.active {
        return Err(PostHogEvaluationError::NotAvailable(format!(
            "The feature flag is not active: {}",
            flag_key
        )));
    }
    if flag_config.filters.multivariate.is_some() {
        return Err(PostHogEvaluationError::Internal(format!(
            "This looks like a multivariate flag, should use evaluate_multivariate?: {flag_key}"
        )));
    }
    // TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog
    // Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it
    // does not matter.
    for group in &flag_config.filters.groups {
        let outcome = self.evaluate_group(group, hash_on_global_rollout_percentage, properties)?;
        match outcome {
            GroupEvaluationResult::Unmatched => continue,
            GroupEvaluationResult::MatchedAndEvaluate => return Ok(()),
            GroupEvaluationResult::MatchedAndOverride(_) => {
                return Err(PostHogEvaluationError::Internal(format!(
                    "Boolean flag cannot have overrides: {}",
                    flag_key
                )));
            }
        }
    }
    // If no group is matched, the feature is not available, and up to the caller to decide what to do.
    Err(PostHogEvaluationError::NoConditionGroupMatched)
}
}
/// Settings consumed by `PostHogClient`: the API keys, the project ID and the base URLs
/// used when building request URLs and authenticating.
pub struct PostHogClientConfig {
/// The server API key.
pub server_api_key: String,
/// The client API key.
pub client_api_key: String,
/// The project ID.
pub project_id: String,
/// The private API URL.
pub private_api_url: String,
/// The public API URL.
pub public_api_url: String,
}
/// A lite PostHog client.
@@ -479,16 +360,37 @@ pub struct PostHogClientConfig {
/// want to report the feature flag usage back to PostHog. The current plan is to use PostHog only as an UI to
/// configure feature flags so it is very likely that the client API will not be used.
pub struct PostHogClient {
/// The config.
config: PostHogClientConfig,
/// The server API key.
server_api_key: String,
/// The client API key.
client_api_key: String,
/// The project ID.
project_id: String,
/// The private API URL.
private_api_url: String,
/// The public API URL.
public_api_url: String,
/// The HTTP client.
client: reqwest::Client,
}
impl PostHogClient {
pub fn new(config: PostHogClientConfig) -> Self {
pub fn new(
server_api_key: String,
client_api_key: String,
project_id: String,
private_api_url: String,
public_api_url: String,
) -> Self {
let client = reqwest::Client::new();
Self { config, client }
Self {
server_api_key,
client_api_key,
project_id,
private_api_url,
public_api_url,
client,
}
}
pub fn new_with_us_region(
@@ -496,13 +398,13 @@ impl PostHogClient {
client_api_key: String,
project_id: String,
) -> Self {
Self::new(PostHogClientConfig {
Self::new(
server_api_key,
client_api_key,
project_id,
private_api_url: "https://us.posthog.com".to_string(),
public_api_url: "https://us.i.posthog.com".to_string(),
})
"https://us.posthog.com".to_string(),
"https://us.i.posthog.com".to_string(),
)
}
/// Fetch the feature flag specs from the server.
@@ -520,12 +422,12 @@ impl PostHogClient {
// with bearer token of self.server_api_key
let url = format!(
"{}/api/projects/{}/feature_flags/local_evaluation",
self.config.private_api_url, self.config.project_id
self.private_api_url, self.project_id
);
let response = self
.client
.get(url)
.bearer_auth(&self.config.server_api_key)
.bearer_auth(&self.server_api_key)
.send()
.await?;
let body = response.text().await?;
@@ -544,11 +446,11 @@ impl PostHogClient {
) -> anyhow::Result<()> {
// PUBLIC_URL/capture/
// with bearer token of self.client_api_key
let url = format!("{}/capture/", self.config.public_api_url);
let url = format!("{}/capture/", self.public_api_url);
self.client
.post(url)
.body(serde_json::to_string(&json!({
"api_key": self.config.client_api_key,
"api_key": self.client_api_key,
"distinct_id": distinct_id,
"event": event,
"properties": properties,
@@ -565,162 +467,95 @@ mod tests {
fn data() -> &'static str {
r#"{
"flags": [
{
"id": 141807,
"team_id": 152860,
"name": "",
"key": "image-compaction-boundary",
"filters": {
"groups": [
{
"variant": null,
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
}
"flags": [
{
"id": 132794,
"team_id": 152860,
"name": "",
"key": "gc-compaction",
"filters": {
"groups": [
{
"variant": "enabled-stage-2",
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
},
{
"key": "pageserver_remote_size",
"type": "person",
"value": "10000000",
"operator": "lt"
}
],
"rollout_percentage": 50
},
{
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
},
{
"key": "pageserver_remote_size",
"type": "person",
"value": "10000000",
"operator": "lt"
}
],
"rollout_percentage": 80
}
],
"payloads": {},
"multivariate": {
"variants": [
{
"key": "disabled",
"name": "",
"rollout_percentage": 90
},
{
"key": "enabled-stage-1",
"name": "",
"rollout_percentage": 10
},
{
"key": "enabled-stage-2",
"name": "",
"rollout_percentage": 0
},
{
"key": "enabled-stage-3",
"name": "",
"rollout_percentage": 0
},
{
"key": "enabled",
"name": "",
"rollout_percentage": 0
}
]
}
},
"deleted": false,
"active": true,
"ensure_experience_continuity": false,
"has_encrypted_payloads": false,
"version": 6
}
],
"rollout_percentage": 40
},
{
"variant": null,
"properties": [],
"rollout_percentage": 10
}
],
"payloads": {},
"multivariate": null
},
"deleted": false,
"active": true,
"ensure_experience_continuity": false,
"has_encrypted_payloads": false,
"version": 1
},
{
"id": 135586,
"team_id": 152860,
"name": "",
"key": "boolean-flag",
"filters": {
"groups": [
{
"variant": null,
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
}
],
"rollout_percentage": 47
}
],
"payloads": {},
"multivariate": null
},
"deleted": false,
"active": true,
"ensure_experience_continuity": false,
"has_encrypted_payloads": false,
"version": 1
},
{
"id": 132794,
"team_id": 152860,
"name": "",
"key": "gc-compaction",
"filters": {
"groups": [
{
"variant": "enabled-stage-2",
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
},
{
"key": "pageserver_remote_size",
"type": "person",
"value": "10000000",
"operator": "lt"
}
],
"rollout_percentage": 50
},
{
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
},
{
"key": "pageserver_remote_size",
"type": "person",
"value": "10000000",
"operator": "lt"
}
],
"rollout_percentage": 80
}
],
"payloads": {},
"multivariate": {
"variants": [
{
"key": "disabled",
"name": "",
"rollout_percentage": 90
},
{
"key": "enabled-stage-1",
"name": "",
"rollout_percentage": 10
},
{
"key": "enabled-stage-2",
"name": "",
"rollout_percentage": 0
},
{
"key": "enabled-stage-3",
"name": "",
"rollout_percentage": 0
},
{
"key": "enabled",
"name": "",
"rollout_percentage": 0
}
]
}
},
"deleted": false,
"active": true,
"ensure_experience_continuity": false,
"has_encrypted_payloads": false,
"version": 7
}
],
"group_type_mapping": {},
"cohorts": {}
}"#
"group_type_mapping": {},
"cohorts": {}
}"#
}
#[test]
@@ -796,125 +631,4 @@ mod tests {
Err(PostHogEvaluationError::NoConditionGroupMatched)
),);
}
#[test]
fn evaluate_boolean_1() {
    // The `boolean-flag` feature flag has a single group: `plan_type` must be exactly `free`,
    // with a rollout percentage of 47.
    let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
    let mut store = FeatureStore::new();
    store.set_flags(response.flags);

    // No properties supplied: the group's condition cannot be evaluated.
    let outcome = store.evaluate_boolean_inner("boolean-flag", 1.00, &HashMap::new());
    assert!(matches!(
        outcome,
        Err(PostHogEvaluationError::NotAvailable(_))
    ));

    // A paid plan does not satisfy the only group, so no group matches.
    let properties_unmatched = HashMap::from([
        (
            "plan_type".to_string(),
            PostHogFlagFilterPropertyValue::String("paid".to_string()),
        ),
        (
            "pageserver_remote_size".to_string(),
            PostHogFlagFilterPropertyValue::Number(1000.0),
        ),
    ]);
    let outcome = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties_unmatched);
    assert!(matches!(
        outcome,
        Err(PostHogEvaluationError::NoConditionGroupMatched)
    ));

    let properties = HashMap::from([
        (
            "plan_type".to_string(),
            PostHogFlagFilterPropertyValue::String("free".to_string()),
        ),
        (
            "pageserver_remote_size".to_string(),
            PostHogFlagFilterPropertyValue::Number(1000.0),
        ),
    ]);

    // The properties match and 0.10 falls inside the group's 47% rollout, so the flag is on.
    let outcome = store.evaluate_boolean_inner("boolean-flag", 0.10, &properties);
    assert!(outcome.is_ok());

    // It matches the group conditions but not the group rollout percentage.
    let outcome = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties);
    assert!(matches!(
        outcome,
        Err(PostHogEvaluationError::NoConditionGroupMatched)
    ));
}
#[test]
fn evaluate_boolean_2() {
    // The `image-compaction-boundary` feature flag has two groups: one that requires
    // `plan_type` to be exactly `free` (40% rollout) and one without conditions (10% rollout).
    let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
    let mut store = FeatureStore::new();
    store.set_flags(response.flags);

    // No properties supplied: the first group's condition cannot be evaluated.
    let outcome =
        store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &HashMap::new());
    assert!(matches!(
        outcome,
        Err(PostHogEvaluationError::NotAvailable(_))
    ));

    let properties_unmatched = HashMap::from([
        (
            "plan_type".to_string(),
            PostHogFlagFilterPropertyValue::String("paid".to_string()),
        ),
        (
            "pageserver_remote_size".to_string(),
            PostHogFlagFilterPropertyValue::Number(1000.0),
        ),
    ]);

    // A paid plan fails the filtered group, and 1.00 is outside the 10% rollout of the
    // unconditional group, so nothing matches.
    let outcome =
        store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties_unmatched);
    assert!(matches!(
        outcome,
        Err(PostHogEvaluationError::NoConditionGroupMatched)
    ));

    // Same properties, but 0.05 is inside the unconditional group's 10% rollout.
    let outcome =
        store.evaluate_boolean_inner("image-compaction-boundary", 0.05, &properties_unmatched);
    assert!(outcome.is_ok());

    let properties = HashMap::from([
        (
            "plan_type".to_string(),
            PostHogFlagFilterPropertyValue::String("free".to_string()),
        ),
        (
            "pageserver_remote_size".to_string(),
            PostHogFlagFilterPropertyValue::Number(1000.0),
        ),
    ]);

    // The properties match the first group and 0.30 is inside its 40% rollout.
    let outcome = store.evaluate_boolean_inner("image-compaction-boundary", 0.30, &properties);
    assert!(outcome.is_ok());

    // It matches the first group's conditions, but 1.00 is outside both rollout percentages.
    let outcome = store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties);
    assert!(matches!(
        outcome,
        Err(PostHogEvaluationError::NoConditionGroupMatched)
    ));

    // 0.09 is inside both the 40% and the 10% rollout; presumably the first (filtered) group
    // already matches here, so this does not isolate the unconditional group.
    let outcome = store.evaluate_boolean_inner("image-compaction-boundary", 0.09, &properties);
    assert!(outcome.is_ok());
}
}

View File

@@ -1,7 +1,6 @@
#![allow(clippy::todo)]
use std::ffi::CString;
use std::str::FromStr;
use postgres_ffi::WAL_SEGMENT_SIZE;
use utils::id::TenantTimelineId;
@@ -174,8 +173,6 @@ pub struct Config {
pub ttid: TenantTimelineId,
/// List of safekeepers in format `host:port`
pub safekeepers_list: Vec<String>,
/// libpq connection info options
pub safekeeper_conninfo_options: String,
/// Safekeeper reconnect timeout in milliseconds
pub safekeeper_reconnect_timeout: i32,
/// Safekeeper connection timeout in milliseconds
@@ -205,9 +202,6 @@ impl Wrapper {
.into_bytes_with_nul();
assert!(safekeepers_list_vec.len() == safekeepers_list_vec.capacity());
let safekeepers_list = safekeepers_list_vec.as_mut_ptr() as *mut std::ffi::c_char;
let safekeeper_conninfo_options = CString::from_str(&config.safekeeper_conninfo_options)
.unwrap()
.into_raw();
let callback_data = Box::into_raw(Box::new(api)) as *mut ::std::os::raw::c_void;
@@ -215,7 +209,6 @@ impl Wrapper {
neon_tenant,
neon_timeline,
safekeepers_list,
safekeeper_conninfo_options,
safekeeper_reconnect_timeout: config.safekeeper_reconnect_timeout,
safekeeper_connection_timeout: config.safekeeper_connection_timeout,
wal_segment_size: WAL_SEGMENT_SIZE as i32, // default 16MB
@@ -583,7 +576,6 @@ mod tests {
let config = crate::walproposer::Config {
ttid,
safekeepers_list: vec!["localhost:5000".to_string()],
safekeeper_conninfo_options: String::new(),
safekeeper_reconnect_timeout: 1000,
safekeeper_connection_timeout: 10000,
sync_safekeepers: true,

View File

@@ -17,69 +17,51 @@ anyhow.workspace = true
arc-swap.workspace = true
async-compression.workspace = true
async-stream.workspace = true
bincode.workspace = true
bit_field.workspace = true
bincode.workspace = true
byteorder.workspace = true
bytes.workspace = true
camino-tempfile.workspace = true
camino.workspace = true
camino-tempfile.workspace = true
chrono = { workspace = true, features = ["serde"] }
clap = { workspace = true, features = ["string"] }
consumption_metrics.workspace = true
crc32c.workspace = true
either.workspace = true
enum-map.workspace = true
enumset = { workspace = true, features = ["serde"]}
fail.workspace = true
futures.workspace = true
hashlink.workspace = true
hex.workspace = true
http-utils.workspace = true
humantime-serde.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
hyper0.workspace = true
itertools.workspace = true
jsonwebtoken.workspace = true
md5.workspace = true
metrics.workspace = true
nix.workspace = true
num_cpus.workspace = true # hack to get the number of worker threads tokio uses
# hack to get the number of worker threads tokio uses
num_cpus.workspace = true
num-traits.workspace = true
once_cell.workspace = true
pageserver_api.workspace = true
pageserver_client.workspace = true # for ResponseErrorMessageExt TODO refactor that
pageserver_compaction.workspace = true
pageserver_page_api.workspace = true
pem.workspace = true
pin-project-lite.workspace = true
postgres_backend.workspace = true
postgres_connection.workspace = true
postgres_ffi.workspace = true
postgres_initdb.workspace = true
postgres-protocol.workspace = true
postgres-types.workspace = true
posthog_client_lite.workspace = true
postgres_initdb.workspace = true
pprof.workspace = true
pq_proto.workspace = true
rand.workspace = true
range-set-blaze = { version = "0.1.16", features = ["alloc"] }
regex.workspace = true
remote_storage.workspace = true
reqwest.workspace = true
rpds.workspace = true
rustls.workspace = true
scopeguard.workspace = true
send-future.workspace = true
serde.workspace = true
serde_json = { workspace = true, features = ["raw_value"] }
serde_path_to_error.workspace = true
serde_with.workspace = true
serde.workspace = true
smallvec.workspace = true
storage_broker.workspace = true
strum_macros.workspace = true
strum.workspace = true
sysinfo.workspace = true
tenant_size_model.workspace = true
tokio-tar.workspace = true
thiserror.workspace = true
tikv-jemallocator.workspace = true
tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] }
@@ -88,7 +70,6 @@ tokio-io-timeout.workspace = true
tokio-postgres.workspace = true
tokio-rustls.workspace = true
tokio-stream.workspace = true
tokio-tar.workspace = true
tokio-util.workspace = true
toml_edit = { workspace = true, features = [ "serde" ] }
tonic.workspace = true
@@ -96,10 +77,29 @@ tonic-reflection.workspace = true
tracing.workspace = true
tracing-utils.workspace = true
url.workspace = true
utils.workspace = true
wal_decoder.workspace = true
walkdir.workspace = true
metrics.workspace = true
pageserver_api.workspace = true
pageserver_client.workspace = true # for ResponseErrorMessageExt TODO refactor that
pageserver_compaction.workspace = true
pem.workspace = true
postgres_connection.workspace = true
postgres_ffi.workspace = true
pq_proto.workspace = true
remote_storage.workspace = true
storage_broker.workspace = true
tenant_size_model.workspace = true
http-utils.workspace = true
utils.workspace = true
workspace_hack.workspace = true
reqwest.workspace = true
rpds.workspace = true
enum-map.workspace = true
enumset = { workspace = true, features = ["serde"]}
strum.workspace = true
strum_macros.workspace = true
wal_decoder.workspace = true
smallvec.workspace = true
twox-hash.workspace = true
[target.'cfg(target_os = "linux")'.dependencies]

View File

@@ -5,14 +5,8 @@ edition.workspace = true
license.workspace = true
[dependencies]
bytes.workspace = true
pageserver_api.workspace = true
postgres_ffi.workspace = true
prost.workspace = true
smallvec.workspace = true
thiserror.workspace = true
tonic.workspace = true
utils.workspace = true
workspace_hack.workspace = true
[build-dependencies]

View File

@@ -54,9 +54,9 @@ service PageService {
// RPCs use regular unary requests, since they are not as frequent and
// performance-critical, and this simplifies implementation.
//
// NB: a gRPC status response (e.g. errors) will terminate the stream. The
// stream may be shared by multiple Postgres backends, so we avoid this by
// sending them as GetPageResponse.status_code instead.
// NB: a status response (e.g. errors) will terminate the stream. The stream
// may be shared by e.g. multiple Postgres backends, so we should avoid this.
// Most errors are therefore sent as GetPageResponse.status instead.
rpc GetPages (stream GetPageRequest) returns (stream GetPageResponse);
// Returns the size of a relation, as # of blocks.
@@ -159,8 +159,8 @@ message GetPageRequest {
// A GetPageRequest class. Primarily intended for observability, but may also be
// used for prioritization in the future.
enum GetPageClass {
// Unknown class. For backwards compatibility: used when an older client version sends a class
// that a newer server version has removed.
// Unknown class. For forwards compatibility: used when the client sends a
// class that the server doesn't know about.
GET_PAGE_CLASS_UNKNOWN = 0;
// A normal request. This is the default.
GET_PAGE_CLASS_NORMAL = 1;
@@ -180,37 +180,31 @@ message GetPageResponse {
// The original request's ID.
uint64 request_id = 1;
// The response status code.
GetPageStatusCode status_code = 2;
GetPageStatus status = 2;
// A string describing the status, if any.
string reason = 3;
// The 8KB page images, in the same order as the request. Empty if status_code != OK.
// The 8KB page images, in the same order as the request. Empty if status != OK.
repeated bytes page_image = 4;
}
// A GetPageResponse status code.
//
// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream
// (potentially shared by many backends), and a gRPC status response would terminate the stream so
// we send GetPageResponse messages with these codes instead.
enum GetPageStatusCode {
// Unknown status. For forwards compatibility: used when an older client version receives a new
// status code from a newer server version.
GET_PAGE_STATUS_CODE_UNKNOWN = 0;
// A GetPageResponse status code. Since we use a bidirectional stream, we don't
// want to send errors as gRPC statuses, since this would terminate the stream.
enum GetPageStatus {
// Unknown status. For forwards compatibility: used when the server sends a
// status code that the client doesn't know about.
GET_PAGE_STATUS_UNKNOWN = 0;
// The request was successful.
GET_PAGE_STATUS_CODE_OK = 1;
GET_PAGE_STATUS_OK = 1;
// The page did not exist. The tenant/timeline/shard has already been
// validated during stream setup.
GET_PAGE_STATUS_CODE_NOT_FOUND = 2;
GET_PAGE_STATUS_NOT_FOUND = 2;
// The request was invalid.
GET_PAGE_STATUS_CODE_INVALID_REQUEST = 3;
// The request failed due to an internal server error.
GET_PAGE_STATUS_CODE_INTERNAL_ERROR = 4;
GET_PAGE_STATUS_INVALID = 3;
// The tenant is rate limited. Slow down and retry later.
GET_PAGE_STATUS_CODE_SLOW_DOWN = 5;
// NB: shutdown errors are emitted as a gRPC Unavailable status.
//
// TODO: consider adding a GET_PAGE_STATUS_CODE_LAYER_DOWNLOAD in the case of a layer download.
// This could free up the server task to process other requests while the download is in progress.
GET_PAGE_STATUS_SLOW_DOWN = 4;
// TODO: consider adding a GET_PAGE_STATUS_LAYER_DOWNLOAD in the case of a
// layer download. This could free up the server task to process other
// requests while the layer download is in progress.
}
// Fetches the size of a relation at a given LSN, as # of blocks. Only valid on

View File

@@ -17,7 +17,3 @@ pub mod proto {
pub use page_service_client::PageServiceClient;
pub use page_service_server::{PageService, PageServiceServer};
}
mod model;
pub use model::*;

View File

@@ -1,595 +0,0 @@
//! Structs representing the canonical page service API.
//!
//! These mirror the autogenerated Protobuf types. The differences are:
//!
//! - Types that are in fact required by the API are not Options. The protobuf "required"
//! attribute is deprecated and 'prost' marks a lot of members as optional because of that.
//! (See <https://github.com/tokio-rs/prost/issues/800> for a gripe on this)
//!
//! - Use more precise datatypes, e.g. Lsn and uints shorter than 32 bits.
//!
//! - Validate protocol invariants, via try_from() and try_into().
use bytes::Bytes;
use postgres_ffi::Oid;
use smallvec::SmallVec;
// TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid
// pulling in all of their other crate dependencies when building the client.
use utils::lsn::Lsn;
use crate::proto;
/// A protocol error. Typically returned via try_from() or try_into().
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
#[error("field '{0}' has invalid value '{1}'")]
Invalid(&'static str, String),
#[error("required field '{0}' is missing")]
Missing(&'static str),
}
impl ProtocolError {
/// Helper to generate a new ProtocolError::Invalid for the given field and value.
pub fn invalid(field: &'static str, value: impl std::fmt::Debug) -> Self {
Self::Invalid(field, format!("{value:?}"))
}
}
impl From<ProtocolError> for tonic::Status {
fn from(err: ProtocolError) -> Self {
tonic::Status::invalid_argument(format!("{err}"))
}
}
/// The LSN that a request reads at.
#[derive(Clone, Copy, Debug)]
pub struct ReadLsn {
    /// The LSN the request reads at.
    pub request_lsn: Lsn,
    /// When set, the caller promises that the page has not been modified since this LSN. It
    /// must be smaller than or equal to request_lsn. With it, the Pageserver can serve an old
    /// page without waiting for the request LSN to arrive. Valid for all request types.
    ///
    /// Making a request where the page was in fact modified between request_lsn and
    /// not_modified_since_lsn is undefined behaviour: the Pageserver might detect it and return
    /// an error, or it might return either the old or the new page version. Setting
    /// not_modified_since_lsn equal to request_lsn is always safe, but can cause unnecessary
    /// waiting.
    pub not_modified_since_lsn: Option<Lsn>,
}

impl ReadLsn {
    /// Checks the ReadLsn invariants: request_lsn must not be `Lsn::INVALID`, and
    /// not_modified_since_lsn (if set) must not exceed request_lsn.
    pub fn validate(&self) -> Result<(), ProtocolError> {
        let Self {
            request_lsn,
            not_modified_since_lsn,
        } = *self;
        if request_lsn == Lsn::INVALID {
            return Err(ProtocolError::invalid("request_lsn", request_lsn));
        }
        match not_modified_since_lsn {
            Some(since) if since > request_lsn => Err(ProtocolError::invalid(
                "not_modified_since_lsn",
                not_modified_since_lsn,
            )),
            _ => Ok(()),
        }
    }
}
impl TryFrom<proto::ReadLsn> for ReadLsn {
    type Error = ProtocolError;

    /// Converts the Protobuf message, treating a zero not_modified_since_lsn as "not given",
    /// and validates the result.
    fn try_from(pb: proto::ReadLsn) -> Result<Self, Self::Error> {
        let not_modified_since_lsn = if pb.not_modified_since_lsn == 0 {
            None
        } else {
            Some(Lsn(pb.not_modified_since_lsn))
        };
        let parsed = Self {
            request_lsn: Lsn(pb.request_lsn),
            not_modified_since_lsn,
        };
        parsed.validate().map(|()| parsed)
    }
}
impl TryFrom<ReadLsn> for proto::ReadLsn {
    type Error = ProtocolError;

    /// Validates the ReadLsn and converts it to its Protobuf form. A missing
    /// not_modified_since_lsn is encoded as the default LSN.
    fn try_from(read_lsn: ReadLsn) -> Result<Self, Self::Error> {
        read_lsn.validate()?;
        let since = read_lsn.not_modified_since_lsn.unwrap_or_default();
        Ok(Self {
            request_lsn: read_lsn.request_lsn.0,
            not_modified_since_lsn: since.0,
        })
    }
}
// RelTag is re-exported from pageserver_api::reltag, where it is defined.
pub type RelTag = pageserver_api::reltag::RelTag;

impl TryFrom<proto::RelTag> for RelTag {
    type Error = ProtocolError;

    /// Converts the Protobuf relation tag, rejecting a fork number that does not fit the
    /// target type.
    fn try_from(pb: proto::RelTag) -> Result<Self, Self::Error> {
        let proto::RelTag {
            spc_oid,
            db_oid,
            rel_number,
            fork_number,
        } = pb;
        let forknum = fork_number
            .try_into()
            .map_err(|_| ProtocolError::invalid("fork_number", fork_number))?;
        Ok(Self {
            spcnode: spc_oid,
            dbnode: db_oid,
            relnode: rel_number,
            forknum,
        })
    }
}
impl From<RelTag> for proto::RelTag {
fn from(rel_tag: RelTag) -> Self {
Self {
spc_oid: rel_tag.spcnode,
db_oid: rel_tag.dbnode,
rel_number: rel_tag.relnode,
fork_number: rel_tag.forknum as u32,
}
}
}
/// Asks whether a relation exists at the given LSN. Only valid on shard 0, other shards error.
#[derive(Clone, Copy, Debug)]
pub struct CheckRelExistsRequest {
    pub read_lsn: ReadLsn,
    pub rel: RelTag,
}

impl TryFrom<proto::CheckRelExistsRequest> for CheckRelExistsRequest {
    type Error = ProtocolError;

    /// Converts the Protobuf request; both `read_lsn` and `rel` are required.
    fn try_from(pb: proto::CheckRelExistsRequest) -> Result<Self, Self::Error> {
        let read_lsn: ReadLsn = pb
            .read_lsn
            .ok_or(ProtocolError::Missing("read_lsn"))?
            .try_into()?;
        let rel: RelTag = pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?;
        Ok(Self { read_lsn, rel })
    }
}
pub type CheckRelExistsResponse = bool;

impl From<proto::CheckRelExistsResponse> for CheckRelExistsResponse {
    /// Extracts the `exists` flag from the Protobuf response.
    fn from(pb: proto::CheckRelExistsResponse) -> Self {
        let proto::CheckRelExistsResponse { exists } = pb;
        exists
    }
}

impl From<CheckRelExistsResponse> for proto::CheckRelExistsResponse {
    /// Wraps the flag in a Protobuf response.
    fn from(flag: CheckRelExistsResponse) -> Self {
        Self { exists: flag }
    }
}
/// A request for a base backup at a given LSN.
#[derive(Clone, Copy, Debug)]
pub struct GetBaseBackupRequest {
    /// The LSN at which to fetch the base backup.
    pub read_lsn: ReadLsn,
    /// If true, logical replication slots will not be created.
    pub replica: bool,
}

impl TryFrom<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
    type Error = ProtocolError;

    /// Converts the Protobuf request; `read_lsn` is required.
    fn try_from(pb: proto::GetBaseBackupRequest) -> Result<Self, Self::Error> {
        let read_lsn: ReadLsn = pb
            .read_lsn
            .ok_or(ProtocolError::Missing("read_lsn"))?
            .try_into()?;
        Ok(Self {
            read_lsn,
            replica: pb.replica,
        })
    }
}

impl TryFrom<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
    type Error = ProtocolError;

    /// Converts the request to its Protobuf form, validating the read LSN on the way.
    fn try_from(request: GetBaseBackupRequest) -> Result<Self, Self::Error> {
        let GetBaseBackupRequest { read_lsn, replica } = request;
        let read_lsn = proto::ReadLsn::try_from(read_lsn)?;
        Ok(Self {
            read_lsn: Some(read_lsn),
            replica,
        })
    }
}
pub type GetBaseBackupResponseChunk = Bytes;

impl TryFrom<proto::GetBaseBackupResponseChunk> for GetBaseBackupResponseChunk {
    type Error = ProtocolError;

    /// Unwraps the chunk bytes; an empty chunk is reported as a missing field.
    fn try_from(pb: proto::GetBaseBackupResponseChunk) -> Result<Self, Self::Error> {
        let proto::GetBaseBackupResponseChunk { chunk } = pb;
        if chunk.is_empty() {
            Err(ProtocolError::Missing("chunk"))
        } else {
            Ok(chunk)
        }
    }
}

impl TryFrom<GetBaseBackupResponseChunk> for proto::GetBaseBackupResponseChunk {
    type Error = ProtocolError;

    /// Wraps the chunk bytes; an empty chunk is reported as a missing field.
    fn try_from(chunk: GetBaseBackupResponseChunk) -> Result<Self, Self::Error> {
        if chunk.is_empty() {
            Err(ProtocolError::Missing("chunk"))
        } else {
            Ok(Self { chunk })
        }
    }
}
/// A request for the size of a database, as # of bytes. Only valid on shard 0, other shards
/// will error.
#[derive(Clone, Copy, Debug)]
pub struct GetDbSizeRequest {
    pub read_lsn: ReadLsn,
    pub db_oid: Oid,
}

impl TryFrom<proto::GetDbSizeRequest> for GetDbSizeRequest {
    type Error = ProtocolError;

    /// Converts the Protobuf request; `read_lsn` is required.
    fn try_from(pb: proto::GetDbSizeRequest) -> Result<Self, Self::Error> {
        let read_lsn: ReadLsn = pb
            .read_lsn
            .ok_or(ProtocolError::Missing("read_lsn"))?
            .try_into()?;
        Ok(Self {
            read_lsn,
            db_oid: pb.db_oid,
        })
    }
}

impl TryFrom<GetDbSizeRequest> for proto::GetDbSizeRequest {
    type Error = ProtocolError;

    /// Converts the request to its Protobuf form, validating the read LSN on the way.
    fn try_from(request: GetDbSizeRequest) -> Result<Self, Self::Error> {
        let GetDbSizeRequest { read_lsn, db_oid } = request;
        let read_lsn = proto::ReadLsn::try_from(read_lsn)?;
        Ok(Self {
            read_lsn: Some(read_lsn),
            db_oid,
        })
    }
}
pub type GetDbSizeResponse = u64;

impl From<proto::GetDbSizeResponse> for GetDbSizeResponse {
    /// Extracts the byte count from the Protobuf response.
    fn from(pb: proto::GetDbSizeResponse) -> Self {
        let proto::GetDbSizeResponse { num_bytes } = pb;
        num_bytes
    }
}

impl From<GetDbSizeResponse> for proto::GetDbSizeResponse {
    /// Wraps the byte count in a Protobuf response.
    fn from(size: GetDbSizeResponse) -> Self {
        Self { num_bytes: size }
    }
}
/// A request for one or more pages.
#[derive(Clone, Debug)]
pub struct GetPageRequest {
    /// A request ID, echoed back in the response. Should be unique among the in-flight
    /// requests on the stream.
    pub request_id: RequestID,
    /// The request class.
    pub request_class: GetPageClass,
    /// The LSN to read at.
    pub read_lsn: ReadLsn,
    /// The relation to read from.
    pub rel: RelTag,
    /// Page numbers to read. Must belong to the remote shard.
    ///
    /// The Pageserver executes multiple pages as a single batch, which amortizes and
    /// parallelizes layer access costs. This may increase the latency of any individual
    /// request, but improves the overall latency and throughput of the batch as a whole.
    pub block_numbers: SmallVec<[u32; 1]>,
}

impl TryFrom<proto::GetPageRequest> for GetPageRequest {
    type Error = ProtocolError;

    /// Converts the Protobuf request. At least one block number, `read_lsn` and `rel` are
    /// required.
    fn try_from(pb: proto::GetPageRequest) -> Result<Self, Self::Error> {
        let proto::GetPageRequest {
            request_id,
            request_class,
            read_lsn,
            rel,
            block_number,
        } = pb;
        if block_number.is_empty() {
            return Err(ProtocolError::Missing("block_number"));
        }
        let read_lsn: ReadLsn = read_lsn
            .ok_or(ProtocolError::Missing("read_lsn"))?
            .try_into()?;
        let rel: RelTag = rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?;
        Ok(Self {
            request_id,
            request_class: request_class.into(),
            read_lsn,
            rel,
            block_numbers: block_number.into(),
        })
    }
}

impl TryFrom<GetPageRequest> for proto::GetPageRequest {
    type Error = ProtocolError;

    /// Converts the request to its Protobuf form. At least one block number is required, and
    /// the read LSN is validated on the way.
    fn try_from(request: GetPageRequest) -> Result<Self, Self::Error> {
        let GetPageRequest {
            request_id,
            request_class,
            read_lsn,
            rel,
            block_numbers,
        } = request;
        if block_numbers.is_empty() {
            return Err(ProtocolError::Missing("block_number"));
        }
        let read_lsn = proto::ReadLsn::try_from(read_lsn)?;
        Ok(Self {
            request_id,
            request_class: request_class.into(),
            read_lsn: Some(read_lsn),
            rel: Some(rel.into()),
            block_number: block_numbers.into_vec(),
        })
    }
}
/// A GetPage request ID.
pub type RequestID = u64;
/// A GetPage request class.
#[derive(Clone, Copy, Debug)]
pub enum GetPageClass {
/// Unknown class. For backwards compatibility: used when an older client version sends a class
/// that a newer server version has removed.
Unknown,
/// A normal request. This is the default.
Normal,
/// A prefetch request. NB: can only be classified on pg < 18.
Prefetch,
/// A background request (e.g. vacuum).
Background,
}
impl From<proto::GetPageClass> for GetPageClass {
fn from(pb: proto::GetPageClass) -> Self {
match pb {
proto::GetPageClass::Unknown => Self::Unknown,
proto::GetPageClass::Normal => Self::Normal,
proto::GetPageClass::Prefetch => Self::Prefetch,
proto::GetPageClass::Background => Self::Background,
}
}
}
impl From<i32> for GetPageClass {
fn from(class: i32) -> Self {
proto::GetPageClass::try_from(class)
.unwrap_or(proto::GetPageClass::Unknown)
.into()
}
}
impl From<GetPageClass> for proto::GetPageClass {
fn from(class: GetPageClass) -> Self {
match class {
GetPageClass::Unknown => Self::Unknown,
GetPageClass::Normal => Self::Normal,
GetPageClass::Prefetch => Self::Prefetch,
GetPageClass::Background => Self::Background,
}
}
}
impl From<GetPageClass> for i32 {
fn from(class: GetPageClass) -> Self {
proto::GetPageClass::from(class).into()
}
}
/// A GetPage response.
///
/// A batch response will contain all of the requested pages. We could eagerly emit individual pages
/// as soon as they are ready, but on a readv() Postgres holds buffer pool locks on all pages in the
/// batch and we'll only return once the entire batch is ready, so no one can make use of the
/// individual pages.
#[derive(Clone, Debug)]
pub struct GetPageResponse {
/// The original request's ID.
pub request_id: RequestID,
/// The response status code.
pub status_code: GetPageStatusCode,
/// A string describing the status, if any.
pub reason: Option<String>,
/// The 8KB page images, in the same order as the request. Empty if status != OK.
pub page_images: SmallVec<[Bytes; 1]>,
}
impl From<proto::GetPageResponse> for GetPageResponse {
fn from(pb: proto::GetPageResponse) -> Self {
Self {
request_id: pb.request_id,
status_code: pb.status_code.into(),
reason: Some(pb.reason).filter(|r| !r.is_empty()),
page_images: pb.page_image.into(),
}
}
}
impl From<GetPageResponse> for proto::GetPageResponse {
fn from(response: GetPageResponse) -> Self {
Self {
request_id: response.request_id,
status_code: response.status_code.into(),
reason: response.reason.unwrap_or_default(),
page_image: response.page_images.into_vec(),
}
}
}
/// A GetPage response status code.
///
/// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream
/// (potentially shared by many backends), and a gRPC status response would terminate the stream so
/// we send GetPageResponse messages with these codes instead.
#[derive(Clone, Copy, Debug)]
pub enum GetPageStatusCode {
/// Unknown status. For forwards compatibility: used when an older client version receives a new
/// status code from a newer server version.
Unknown,
/// The request was successful.
Ok,
/// The page did not exist. The tenant/timeline/shard has already been validated during stream
/// setup.
NotFound,
/// The request was invalid.
InvalidRequest,
/// The request failed due to an internal server error.
InternalError,
/// The tenant is rate limited. Slow down and retry later.
SlowDown,
}
impl From<proto::GetPageStatusCode> for GetPageStatusCode {
fn from(pb: proto::GetPageStatusCode) -> Self {
match pb {
proto::GetPageStatusCode::Unknown => Self::Unknown,
proto::GetPageStatusCode::Ok => Self::Ok,
proto::GetPageStatusCode::NotFound => Self::NotFound,
proto::GetPageStatusCode::InvalidRequest => Self::InvalidRequest,
proto::GetPageStatusCode::InternalError => Self::InternalError,
proto::GetPageStatusCode::SlowDown => Self::SlowDown,
}
}
}
impl From<i32> for GetPageStatusCode {
fn from(status_code: i32) -> Self {
proto::GetPageStatusCode::try_from(status_code)
.unwrap_or(proto::GetPageStatusCode::Unknown)
.into()
}
}
impl From<GetPageStatusCode> for proto::GetPageStatusCode {
fn from(status_code: GetPageStatusCode) -> Self {
match status_code {
GetPageStatusCode::Unknown => Self::Unknown,
GetPageStatusCode::Ok => Self::Ok,
GetPageStatusCode::NotFound => Self::NotFound,
GetPageStatusCode::InvalidRequest => Self::InvalidRequest,
GetPageStatusCode::InternalError => Self::InternalError,
GetPageStatusCode::SlowDown => Self::SlowDown,
}
}
}
impl From<GetPageStatusCode> for i32 {
fn from(status_code: GetPageStatusCode) -> Self {
proto::GetPageStatusCode::from(status_code).into()
}
}
/// Fetches the size of a relation at a given LSN, as # of blocks. Only valid on shard 0, other
/// shards will error.
// Derives match the sibling request types (e.g. CheckRelExistsRequest), which hold the same
// field types and are Clone + Copy + Debug.
#[derive(Clone, Copy, Debug)]
pub struct GetRelSizeRequest {
    /// The LSN to read at.
    pub read_lsn: ReadLsn,
    /// The relation whose size is requested.
    pub rel: RelTag,
}
impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
type Error = ProtocolError;
fn try_from(proto: proto::GetRelSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: proto
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: proto.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
})
}
}
impl TryFrom<GetRelSizeRequest> for proto::GetRelSizeRequest {
type Error = ProtocolError;
fn try_from(request: GetRelSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
rel: Some(request.rel.into()),
})
}
}
pub type GetRelSizeResponse = u32;

impl From<proto::GetRelSizeResponse> for GetRelSizeResponse {
    /// Extracts the block count from the Protobuf response.
    fn from(pb: proto::GetRelSizeResponse) -> Self {
        let proto::GetRelSizeResponse { num_blocks } = pb;
        num_blocks
    }
}

impl From<GetRelSizeResponse> for proto::GetRelSizeResponse {
    /// Wraps the block count in a Protobuf response.
    fn from(size: GetRelSizeResponse) -> Self {
        Self { num_blocks: size }
    }
}
/// Requests an SLRU segment. Only valid on shard 0, other shards will error.
// Derives match the sibling request types in this module.
// NOTE(review): this requires SlruKind (pageserver_api::reltag) to be Clone + Copy + Debug;
// confirm against its definition.
#[derive(Clone, Copy, Debug)]
pub struct GetSlruSegmentRequest {
    /// The LSN to read at.
    pub read_lsn: ReadLsn,
    /// The kind of SLRU to read from.
    pub kind: SlruKind,
    /// The requested segment (presumably the segment number; confirm against the server side).
    pub segno: u32,
}
impl TryFrom<proto::GetSlruSegmentRequest> for GetSlruSegmentRequest {
type Error = ProtocolError;
fn try_from(pb: proto::GetSlruSegmentRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: pb
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
kind: u8::try_from(pb.kind)
.ok()
.and_then(SlruKind::from_repr)
.ok_or_else(|| ProtocolError::invalid("slru_kind", pb.kind))?,
segno: pb.segno,
})
}
}
impl TryFrom<GetSlruSegmentRequest> for proto::GetSlruSegmentRequest {
type Error = ProtocolError;
fn try_from(request: GetSlruSegmentRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
kind: request.kind as u32,
segno: request.segno,
})
}
}
pub type GetSlruSegmentResponse = Bytes;

impl TryFrom<proto::GetSlruSegmentResponse> for GetSlruSegmentResponse {
    type Error = ProtocolError;

    /// Unwraps the segment bytes; an empty segment is reported as a missing field.
    fn try_from(pb: proto::GetSlruSegmentResponse) -> Result<Self, Self::Error> {
        let proto::GetSlruSegmentResponse { segment } = pb;
        if segment.is_empty() {
            Err(ProtocolError::Missing("segment"))
        } else {
            Ok(segment)
        }
    }
}

impl TryFrom<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
    type Error = ProtocolError;

    /// Wraps the segment bytes; an empty segment is reported as a missing field.
    fn try_from(segment: GetSlruSegmentResponse) -> Result<Self, Self::Error> {
        if segment.is_empty() {
            Err(ProtocolError::Missing("segment"))
        } else {
            Ok(Self { segment })
        }
    }
}

// SlruKind is re-exported from pageserver_api::reltag, where it is defined.
pub type SlruKind = pageserver_api::reltag::SlruKind;

View File

@@ -21,7 +21,6 @@ use pageserver::config::{PageServerConf, PageserverIdentity, ignored_fields};
use pageserver::controller_upcall_client::StorageControllerUpcallClient;
use pageserver::deletion_queue::DeletionQueue;
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use pageserver::feature_resolver::FeatureResolver;
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
use pageserver::task_mgr::{
BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME,
@@ -523,12 +522,6 @@ fn start_pageserver(
// Set up remote storage client
let remote_storage = BACKGROUND_RUNTIME.block_on(create_remote_storage_client(conf))?;
let feature_resolver = create_feature_resolver(
conf,
shutdown_pageserver.clone(),
BACKGROUND_RUNTIME.handle(),
)?;
// Set up deletion queue
let (deletion_queue, deletion_workers) = DeletionQueue::new(
remote_storage.clone(),
@@ -582,7 +575,6 @@ fn start_pageserver(
deletion_queue_client,
l0_flush_global_state,
basebackup_prepare_sender,
feature_resolver,
},
order,
shutdown_pageserver.clone(),
@@ -857,14 +849,6 @@ fn start_pageserver(
})
}
/// Creates the pageserver's `FeatureResolver` by forwarding all arguments to
/// `FeatureResolver::spawn` and returning its result unchanged.
fn create_feature_resolver(
conf: &'static PageServerConf,
shutdown_pageserver: CancellationToken,
handle: &tokio::runtime::Handle,
) -> anyhow::Result<FeatureResolver> {
FeatureResolver::spawn(conf, shutdown_pageserver, handle)
}
async fn create_remote_storage_client(
conf: &'static PageServerConf,
) -> anyhow::Result<GenericRemoteStorage> {

View File

@@ -14,7 +14,7 @@ use std::time::Duration;
use anyhow::{Context, bail, ensure};
use camino::{Utf8Path, Utf8PathBuf};
use once_cell::sync::OnceCell;
use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig};
use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes};
use pageserver_api::models::ImageCompressionAlgorithm;
use pageserver_api::shard::TenantShardId;
use pem::Pem;
@@ -238,9 +238,6 @@ pub struct PageServerConf {
/// This is insecure and should only be used in development environments.
pub dev_mode: bool,
/// PostHog integration config.
pub posthog_config: Option<PostHogConfig>,
pub timeline_import_config: pageserver_api::config::TimelineImportConfig,
pub basebackup_cache_config: Option<pageserver_api::config::BasebackupCacheConfig>,
@@ -424,7 +421,6 @@ impl PageServerConf {
tracing,
enable_tls_page_service_api,
dev_mode,
posthog_config,
timeline_import_config,
basebackup_cache_config,
} = config_toml;
@@ -540,7 +536,6 @@ impl PageServerConf {
}
None => Vec::new(),
},
posthog_config,
};
// ------------------------------------------------------------

View File

@@ -837,30 +837,7 @@ async fn collect_eviction_candidates(
continue;
}
let info = tl.get_local_layers_for_disk_usage_eviction().await;
debug!(
tenant_id=%tl.tenant_shard_id.tenant_id,
shard_id=%tl.tenant_shard_id.shard_slug(),
timeline_id=%tl.timeline_id,
"timeline resident layers count: {}", info.resident_layers.len()
);
tenant_candidates.extend(info.resident_layers.into_iter());
max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0));
if cancel.is_cancelled() {
return Ok(EvictionCandidates::Cancelled);
}
}
// Also consider layers of timelines being imported for eviction
for tl in tenant.list_importing_timelines() {
let info = tl.timeline.get_local_layers_for_disk_usage_eviction().await;
debug!(
tenant_id=%tl.timeline.tenant_shard_id.tenant_id,
shard_id=%tl.timeline.tenant_shard_id.shard_slug(),
timeline_id=%tl.timeline.timeline_id,
"timeline resident layers count: {}", info.resident_layers.len()
);
debug!(tenant_id=%tl.tenant_shard_id.tenant_id, shard_id=%tl.tenant_shard_id.shard_slug(), timeline_id=%tl.timeline_id, "timeline resident layers count: {}", info.resident_layers.len());
tenant_candidates.extend(info.resident_layers.into_iter());
max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0));

View File

@@ -1,94 +0,0 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use posthog_client_lite::{
FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
};
use tokio_util::sync::CancellationToken;
use utils::id::TenantId;
use crate::config::PageServerConf;
/// Handle used to evaluate feature flags. Cloning it shares the same underlying
/// `FeatureResolverBackgroundLoop` through the `Arc`.
#[derive(Clone)]
pub struct FeatureResolver {
// `None` when no background loop was set up; evaluation then reports the flag as not available.
inner: Option<Arc<FeatureResolverBackgroundLoop>>,
}
impl FeatureResolver {
/// Creates a resolver with no background loop attached (`inner` is `None`).
pub fn new_disabled() -> Self {
Self { inner: None }
}
pub fn spawn(
conf: &PageServerConf,
shutdown_pageserver: CancellationToken,
handle: &tokio::runtime::Handle,
) -> anyhow::Result<Self> {
// DO NOT block in this function: make it return as fast as possible to avoid startup delays.
if let Some(posthog_config) = &conf.posthog_config {
let inner = FeatureResolverBackgroundLoop::new(
PostHogClientConfig {
server_api_key: posthog_config.server_api_key.clone(),
client_api_key: posthog_config.client_api_key.clone(),
project_id: posthog_config.project_id.clone(),
private_api_url: posthog_config.private_api_url.clone(),
public_api_url: posthog_config.public_api_url.clone(),
},
shutdown_pageserver,
);
let inner = Arc::new(inner);
// TODO: make this configurable
inner.clone().spawn(handle, Duration::from_secs(60));
Ok(FeatureResolver { inner: Some(inner) })
} else {
Ok(FeatureResolver { inner: None })
}
}
/// Evaluate a multivariate feature flag. Currently, we do not support any properties.
///
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
/// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be
/// propagated beyond where the feature flag gets resolved.
pub fn evaluate_multivariate(
&self,
flag_key: &str,
tenant_id: TenantId,
) -> Result<String, PostHogEvaluationError> {
if let Some(inner) = &self.inner {
inner.feature_store().evaluate_multivariate(
flag_key,
&tenant_id.to_string(),
&HashMap::new(),
)
} else {
Err(PostHogEvaluationError::NotAvailable(
"PostHog integration is not enabled".to_string(),
))
}
}
/// Evaluate a boolean feature flag. Currently, we do not support any properties.
///
/// Returns `Ok(())` if the flag is evaluated to true, otherwise returns an error.
///
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
/// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be
/// propagated beyond where the feature flag gets resolved.
pub fn evaluate_boolean(
&self,
flag_key: &str,
tenant_id: TenantId,
) -> Result<(), PostHogEvaluationError> {
if let Some(inner) = &self.inner {
inner.feature_store().evaluate_boolean(
flag_key,
&tenant_id.to_string(),
&HashMap::new(),
)
} else {
Err(PostHogEvaluationError::NotAvailable(
"PostHog integration is not enabled".to_string(),
))
}
}
}

View File

@@ -353,33 +353,6 @@ paths:
"200":
description: OK
/v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/mark_invisible:
parameters:
- name: tenant_shard_id
in: path
required: true
schema:
type: string
- name: timeline_id
in: path
required: true
schema:
type: string
format: hex
put:
requestBody:
content:
application/json:
schema:
type: object
properties:
is_visible:
type: boolean
default: false
responses:
"200":
description: OK
/v1/tenant/{tenant_shard_id}/location_config:
parameters:
- name: tenant_shard_id
@@ -653,8 +626,6 @@ paths:
format: hex
pg_version:
type: integer
read_only:
type: boolean
existing_initdb_timeline_id:
type: string
format: hex

View File

@@ -370,18 +370,6 @@ impl From<crate::tenant::secondary::SecondaryTenantError> for ApiError {
}
}
impl From<crate::tenant::FinalizeTimelineImportError> for ApiError {
fn from(err: crate::tenant::FinalizeTimelineImportError) -> ApiError {
use crate::tenant::FinalizeTimelineImportError::*;
match err {
ImportTaskStillRunning => {
ApiError::ResourceUnavailable("Import task still running".into())
}
ShuttingDown => ApiError::ShuttingDown,
}
}
}
// Helper function to construct a TimelineInfo struct for a timeline
async fn build_timeline_info(
timeline: &Arc<Timeline>,
@@ -584,7 +572,6 @@ async fn timeline_create_handler(
TimelineCreateRequestMode::Branch {
ancestor_timeline_id,
ancestor_start_lsn,
read_only: _,
pg_version: _,
} => tenant::CreateTimelineParams::Branch(tenant::CreateTimelineParamsBranch {
new_timeline_id,
@@ -3545,7 +3532,10 @@ async fn activate_post_import_handler(
tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?;
tenant.finalize_importing_timeline(timeline_id).await?;
tenant
.finalize_importing_timeline(timeline_id)
.await
.map_err(ApiError::InternalServerError)?;
match tenant.get_timeline(timeline_id, false) {
Ok(_timeline) => {

View File

@@ -10,7 +10,6 @@ pub mod context;
pub mod controller_upcall_client;
pub mod deletion_queue;
pub mod disk_usage_eviction_task;
pub mod feature_resolver;
pub mod http;
pub mod import_datadir;
pub mod l0_flush;

View File

@@ -2234,10 +2234,8 @@ impl BasebackupQueryTimeOngoingRecording<'_> {
// If you want to change how a specific error is categorized, also change it in `log_query_error`.
let metric = match res {
Ok(_) => &self.parent.ok,
Err(QueryError::Shutdown) | Err(QueryError::Reconnect) => {
// Do not observe ok/err for shutdown/reconnect.
// Reconnect error might be raised when the operation is waiting for LSN and the tenant shutdown interrupts
// the operation. A reconnect error will be issued and the client will retry.
Err(QueryError::Shutdown) => {
// Do not observe ok/err for shutdown
return;
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error)))

View File

@@ -43,14 +43,12 @@ use strum_macros::IntoStaticStr;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tonic::service::Interceptor as _;
use tracing::*;
use utils::auth::{Claims, Scope, SwappableJwtAuth};
use utils::failpoint_support;
use utils::id::{TenantId, TenantTimelineId, TimelineId};
use utils::id::{TenantId, TimelineId};
use utils::logging::log_slow;
use utils::lsn::Lsn;
use utils::shard::ShardIndex;
use utils::simple_rcu::RcuReadGuard;
use utils::sync::gate::{Gate, GateGuard};
use utils::sync::spsc_fold;
@@ -202,9 +200,9 @@ pub fn spawn_grpc(
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
// Main page service.
let page_service_handler = PageServerHandler::new(
let page_service = proto::PageServiceServer::new(PageServerHandler::new(
tenant_manager,
auth.clone(),
auth,
PageServicePipeliningConfig::Serial, // TODO: unused with gRPC
conf.get_vectored_concurrent_io,
ConnectionPerfSpanFields::default(),
@@ -212,18 +210,7 @@ pub fn spawn_grpc(
ctx,
cancel.clone(),
gate.enter().expect("just created"),
);
let mut tenant_interceptor = TenantMetadataInterceptor;
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
let interceptors = move |mut req: tonic::Request<()>| {
req = tenant_interceptor.call(req)?;
req = auth_interceptor.call(req)?;
Ok(req)
};
let page_service =
proto::PageServiceServer::with_interceptor(page_service_handler, interceptors);
));
let server = server.add_service(page_service);
// Reflection service for use with e.g. grpcurl.
@@ -769,9 +756,6 @@ struct BatchedGetPageRequest {
timer: SmgrOpTimer,
lsn_range: LsnRange,
ctx: RequestContext,
// If the request is perf enabled, this contains a context
// with a perf span tracking the time spent waiting for the executor.
batch_wait_ctx: Option<RequestContext>,
}
#[cfg(feature = "testing")]
@@ -784,7 +768,6 @@ struct BatchedTestRequest {
/// so that we don't keep the [`Timeline::gate`] open while the batch
/// is being built up inside the [`spsc_fold`] (pagestream pipelining).
#[derive(IntoStaticStr)]
#[allow(clippy::large_enum_variant)]
enum BatchedFeMessage {
Exists {
span: Span,
@@ -1302,22 +1285,6 @@ impl PageServerHandler {
}
};
let batch_wait_ctx = if ctx.has_perf_span() {
Some(
RequestContextBuilder::from(&ctx)
.perf_span(|crnt_perf_span| {
info_span!(
target: PERF_TRACE_TARGET,
parent: crnt_perf_span,
"WAIT_EXECUTOR",
)
})
.attached_child(),
)
} else {
None
};
BatchedFeMessage::GetPage {
span,
shard: shard.downgrade(),
@@ -1329,7 +1296,6 @@ impl PageServerHandler {
request_lsn: req.hdr.request_lsn
},
ctx,
batch_wait_ctx,
}],
// The executor grabs the batch when it becomes idle.
// Hence, [`GetPageBatchBreakReason::ExecutorSteal`] is the
@@ -1485,7 +1451,7 @@ impl PageServerHandler {
let mut flush_timers = Vec::with_capacity(handler_results.len());
for handler_result in &mut handler_results {
let flush_timer = match handler_result {
Ok((_response, timer, _ctx)) => Some(
Ok((_, timer)) => Some(
timer
.observe_execution_end(flushing_start_time)
.expect("we are the first caller"),
@@ -1505,7 +1471,7 @@ impl PageServerHandler {
// Some handler errors cause exit from pagestream protocol.
// Other handler errors are sent back as an error message and we stay in pagestream protocol.
for (handler_result, flushing_timer) in handler_results.into_iter().zip(flush_timers) {
let (response_msg, ctx) = match handler_result {
let response_msg = match handler_result {
Err(e) => match &e.err {
PageStreamError::Shutdown => {
// If we fail to fulfil a request during shutdown, which may be _because_ of
@@ -1530,30 +1496,15 @@ impl PageServerHandler {
error!("error reading relation or page version: {full:#}")
});
(
PagestreamBeMessage::Error(PagestreamErrorResponse {
req: e.req,
message: e.err.to_string(),
}),
None,
)
PagestreamBeMessage::Error(PagestreamErrorResponse {
req: e.req,
message: e.err.to_string(),
})
}
},
Ok((response_msg, _op_timer_already_observed, ctx)) => (response_msg, Some(ctx)),
Ok((response_msg, _op_timer_already_observed)) => response_msg,
};
let ctx = ctx.map(|req_ctx| {
RequestContextBuilder::from(&req_ctx)
.perf_span(|crnt_perf_span| {
info_span!(
target: PERF_TRACE_TARGET,
parent: crnt_perf_span,
"FLUSH_RESPONSE",
)
})
.attached_child()
});
//
// marshal & transmit response message
//
@@ -1576,17 +1527,6 @@ impl PageServerHandler {
)),
None => futures::future::Either::Right(flush_fut),
};
let flush_fut = if let Some(req_ctx) = ctx.as_ref() {
futures::future::Either::Left(
flush_fut.maybe_perf_instrument(req_ctx, |current_perf_span| {
current_perf_span.clone()
}),
)
} else {
futures::future::Either::Right(flush_fut)
};
// do it while respecting cancellation
let _: () = async move {
tokio::select! {
@@ -1616,7 +1556,7 @@ impl PageServerHandler {
ctx: &RequestContext,
) -> Result<
(
Vec<Result<(PagestreamBeMessage, SmgrOpTimer, RequestContext), BatchedPageStreamError>>,
Vec<Result<(PagestreamBeMessage, SmgrOpTimer), BatchedPageStreamError>>,
Span,
),
QueryError,
@@ -1643,7 +1583,7 @@ impl PageServerHandler {
self.handle_get_rel_exists_request(&shard, &req, &ctx)
.instrument(span.clone())
.await
.map(|msg| (msg, timer, ctx))
.map(|msg| (msg, timer))
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
],
span,
@@ -1662,7 +1602,7 @@ impl PageServerHandler {
self.handle_get_nblocks_request(&shard, &req, &ctx)
.instrument(span.clone())
.await
.map(|msg| (msg, timer, ctx))
.map(|msg| (msg, timer))
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
],
span,
@@ -1709,7 +1649,7 @@ impl PageServerHandler {
self.handle_db_size_request(&shard, &req, &ctx)
.instrument(span.clone())
.await
.map(|msg| (msg, timer, ctx))
.map(|msg| (msg, timer))
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
],
span,
@@ -1728,7 +1668,7 @@ impl PageServerHandler {
self.handle_get_slru_segment_request(&shard, &req, &ctx)
.instrument(span.clone())
.await
.map(|msg| (msg, timer, ctx))
.map(|msg| (msg, timer))
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
],
span,
@@ -2080,25 +2020,12 @@ impl PageServerHandler {
return Ok(());
}
};
let mut batch = match batch {
let batch = match batch {
Ok(batch) => batch,
Err(e) => {
return Err(e);
}
};
if let BatchedFeMessage::GetPage {
pages,
span: _,
shard: _,
batch_break_reason: _,
} = &mut batch
{
for req in pages {
req.batch_wait_ctx.take();
}
}
self.pagestream_handle_batched_message(
pgb_writer,
batch,
@@ -2411,8 +2338,7 @@ impl PageServerHandler {
io_concurrency: IoConcurrency,
batch_break_reason: GetPageBatchBreakReason,
ctx: &RequestContext,
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer, RequestContext), BatchedPageStreamError>>
{
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer), BatchedPageStreamError>> {
debug_assert_current_span_has_tenant_and_timeline_id();
timeline
@@ -2519,7 +2445,6 @@ impl PageServerHandler {
page,
}),
req.timer,
req.ctx,
)
})
.map_err(|e| BatchedPageStreamError {
@@ -2564,8 +2489,7 @@ impl PageServerHandler {
timeline: &Timeline,
requests: Vec<BatchedTestRequest>,
_ctx: &RequestContext,
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer, RequestContext), BatchedPageStreamError>>
{
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer), BatchedPageStreamError>> {
// real requests would do something with the timeline
let mut results = Vec::with_capacity(requests.len());
for _req in requests.iter() {
@@ -2592,10 +2516,6 @@ impl PageServerHandler {
req: req.req.clone(),
}),
req.timer,
RequestContext::new(
TaskKind::PageRequestHandler,
DownloadBehavior::Warn,
),
)
})
.map_err(|e| BatchedPageStreamError {
@@ -3370,104 +3290,6 @@ impl From<GetActiveTenantError> for QueryError {
}
}
/// gRPC interceptor that decodes the `neon-tenant-id`, `neon-timeline-id` and
/// `neon-shard-id` metadata entries and stores them as request extensions of type
/// TenantTimelineId and ShardIndex.
///
/// TODO: consider looking up the timeline handle here and storing it.
#[derive(Clone)]
struct TenantMetadataInterceptor;

impl TenantMetadataInterceptor {
    /// Returns the string value of the metadata entry `key`.
    ///
    /// Fails with an `invalid_argument` status if the entry is missing
    /// ("missing <key>") or if `to_str()` rejects its value ("invalid <key>").
    fn metadata_str<'a>(
        req: &'a tonic::Request<()>,
        key: &str,
    ) -> Result<&'a str, tonic::Status> {
        req.metadata()
            .get(key)
            .ok_or_else(|| tonic::Status::invalid_argument(format!("missing {key}")))?
            .to_str()
            .map_err(|_| tonic::Status::invalid_argument(format!("invalid {key}")))
    }
}

impl tonic::service::Interceptor for TenantMetadataInterceptor {
    fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
        // Decode the tenant ID.
        let raw_tenant_id = Self::metadata_str(&req, "neon-tenant-id")?;
        let tenant_id = TenantId::from_str(raw_tenant_id)
            .map_err(|_| tonic::Status::invalid_argument("invalid neon-tenant-id"))?;

        // Decode the timeline ID.
        let raw_timeline_id = Self::metadata_str(&req, "neon-timeline-id")?;
        let timeline_id = TimelineId::from_str(raw_timeline_id)
            .map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?;

        // Decode the shard ID.
        let raw_shard_index = Self::metadata_str(&req, "neon-shard-id")?;
        let shard_index = ShardIndex::from_str(raw_shard_index)
            .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?;

        // Stash the decoded values in the request for later interceptors and handlers.
        let extensions = req.extensions_mut();
        extensions.insert(TenantTimelineId::new(tenant_id, timeline_id));
        extensions.insert(shard_index);

        Ok(req)
    }
}
/// Authenticates gRPC page service requests. Must run after TenantMetadataInterceptor,
/// because it reads the TenantTimelineId request extension.
#[derive(Clone)]
struct TenantAuthInterceptor {
    /// JWT verifier; `None` disables authentication.
    auth: Option<Arc<SwappableJwtAuth>>,
}

impl TenantAuthInterceptor {
    /// Creates an interceptor that checks tokens with `auth`, or lets every request
    /// through when `auth` is `None`.
    fn new(auth: Option<Arc<SwappableJwtAuth>>) -> Self {
        Self { auth }
    }
}

impl tonic::service::Interceptor for TenantAuthInterceptor {
    fn call(&mut self, req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
        // Do nothing if auth is disabled.
        let auth = match self.auth.as_ref() {
            Some(auth) => auth,
            None => return Ok(req),
        };

        // Fetch the tenant ID that's been set by TenantMetadataInterceptor.
        let tenant_id = req
            .extensions()
            .get::<TenantTimelineId>()
            .expect("TenantMetadataInterceptor must run before TenantAuthInterceptor")
            .tenant_id;

        // Fetch the `authorization` header and strip the bearer prefix.
        let header = req
            .metadata()
            .get("authorization")
            .ok_or_else(|| tonic::Status::unauthenticated("no authorization header"))?;
        let header = header
            .to_str()
            .map_err(|_| tonic::Status::invalid_argument("invalid authorization header"))?;
        let jwt = header
            .strip_prefix("Bearer ")
            .ok_or_else(|| tonic::Status::invalid_argument("invalid authorization header"))?
            .trim();

        // Decode the JWT token.
        let token: TokenData<Claims> = auth
            .decode(jwt)
            .map_err(|err| tonic::Status::invalid_argument(format!("invalid JWT token: {err}")))?;

        // Check if the token is valid for this tenant.
        check_permission(&token.claims, Some(tenant_id))
            .map_err(|err| tonic::Status::permission_denied(err.to_string()))?;

        // TODO: consider stashing the claims in the request extensions, if needed.

        Ok(req)
    }
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum GetActiveTimelineError {
#[error(transparent)]

View File

@@ -84,7 +84,6 @@ use crate::context;
use crate::context::RequestContextBuilder;
use crate::context::{DownloadBehavior, RequestContext};
use crate::deletion_queue::{DeletionQueueClient, DeletionQueueError};
use crate::feature_resolver::FeatureResolver;
use crate::l0_flush::L0FlushGlobalState;
use crate::metrics::{
BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS,
@@ -160,7 +159,6 @@ pub struct TenantSharedResources {
pub deletion_queue_client: DeletionQueueClient,
pub l0_flush_global_state: L0FlushGlobalState,
pub basebackup_prepare_sender: BasebackupPrepareSender,
pub feature_resolver: FeatureResolver,
}
/// A [`TenantShard`] is really an _attached_ tenant. The configuration
@@ -300,7 +298,7 @@ pub struct TenantShard {
/// as in progress.
/// * Imported timelines are removed when the storage controller calls the post timeline
/// import activation endpoint.
timelines_importing: std::sync::Mutex<HashMap<TimelineId, Arc<ImportingTimeline>>>,
timelines_importing: std::sync::Mutex<HashMap<TimelineId, ImportingTimeline>>,
/// The last tenant manifest known to be in remote storage. None if the manifest has not yet
/// been either downloaded or uploaded. Always Some after tenant attach.
@@ -382,8 +380,6 @@ pub struct TenantShard {
pub(crate) gc_block: gc_block::GcBlock,
l0_flush_global_state: L0FlushGlobalState,
feature_resolver: FeatureResolver,
}
impl std::fmt::Debug for TenantShard {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -672,7 +668,6 @@ pub enum MaybeOffloaded {
pub enum TimelineOrOffloaded {
Timeline(Arc<Timeline>),
Offloaded(Arc<OffloadedTimeline>),
Importing(Arc<ImportingTimeline>),
}
impl TimelineOrOffloaded {
@@ -684,9 +679,6 @@ impl TimelineOrOffloaded {
TimelineOrOffloaded::Offloaded(offloaded) => {
TimelineOrOffloadedArcRef::Offloaded(offloaded)
}
TimelineOrOffloaded::Importing(importing) => {
TimelineOrOffloadedArcRef::Importing(importing)
}
}
}
pub fn tenant_shard_id(&self) -> TenantShardId {
@@ -699,16 +691,12 @@ impl TimelineOrOffloaded {
match self {
TimelineOrOffloaded::Timeline(timeline) => &timeline.delete_progress,
TimelineOrOffloaded::Offloaded(offloaded) => &offloaded.delete_progress,
TimelineOrOffloaded::Importing(importing) => &importing.delete_progress,
}
}
fn maybe_remote_client(&self) -> Option<Arc<RemoteTimelineClient>> {
match self {
TimelineOrOffloaded::Timeline(timeline) => Some(timeline.remote_client.clone()),
TimelineOrOffloaded::Offloaded(_offloaded) => None,
TimelineOrOffloaded::Importing(importing) => {
Some(importing.timeline.remote_client.clone())
}
}
}
}
@@ -716,7 +704,6 @@ impl TimelineOrOffloaded {
pub enum TimelineOrOffloadedArcRef<'a> {
Timeline(&'a Arc<Timeline>),
Offloaded(&'a Arc<OffloadedTimeline>),
Importing(&'a Arc<ImportingTimeline>),
}
impl TimelineOrOffloadedArcRef<'_> {
@@ -724,14 +711,12 @@ impl TimelineOrOffloadedArcRef<'_> {
match self {
TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.tenant_shard_id,
TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.tenant_shard_id,
TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.tenant_shard_id,
}
}
pub fn timeline_id(&self) -> TimelineId {
match self {
TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.timeline_id,
TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.timeline_id,
TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.timeline_id,
}
}
}
@@ -748,12 +733,6 @@ impl<'a> From<&'a Arc<OffloadedTimeline>> for TimelineOrOffloadedArcRef<'a> {
}
}
impl<'a> From<&'a Arc<ImportingTimeline>> for TimelineOrOffloadedArcRef<'a> {
fn from(timeline: &'a Arc<ImportingTimeline>) -> Self {
Self::Importing(timeline)
}
}
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum GetTimelineError {
#[error("Timeline is shutting down")]
@@ -881,14 +860,6 @@ impl Debug for SetStoppingError {
}
}
/// Reasons why `TenantShard::finalize_importing_timeline` can fail.
#[derive(thiserror::Error, Debug)]
pub(crate) enum FinalizeTimelineImportError {
/// The import task's join handle has not finished yet.
#[error("Import task not done yet")]
ImportTaskStillRunning,
/// Scheduling the finalizing index upload, or waiting for it to complete, failed.
#[error("Shutting down")]
ShuttingDown,
}
/// Arguments to [`TenantShard::create_timeline`].
///
/// Not usable as an idempotency key for timeline creation because if [`CreateTimelineParamsBranch::ancestor_start_lsn`]
@@ -1175,20 +1146,10 @@ impl TenantShard {
ctx,
)?;
let disk_consistent_lsn = timeline.get_disk_consistent_lsn();
if !disk_consistent_lsn.is_valid() {
// As opposed to normal timelines which get initialised with a disk consistent LSN
// via initdb, imported timelines start from 0. If the import task stops before
// it advances disk consistent LSN, allow it to resume.
let in_progress_import = import_pgdata
.as_ref()
.map(|import| !import.is_done())
.unwrap_or(false);
if !in_progress_import {
anyhow::bail!("Timeline {tenant_id}/{timeline_id} has invalid disk_consistent_lsn");
}
}
anyhow::ensure!(
disk_consistent_lsn.is_valid(),
"Timeline {tenant_id}/{timeline_id} has invalid disk_consistent_lsn"
);
assert_eq!(
disk_consistent_lsn,
metadata.disk_consistent_lsn(),
@@ -1282,25 +1243,20 @@ impl TenantShard {
}
}
if disk_consistent_lsn.is_valid() {
// Sanity check: a timeline should have some content.
// Exception: importing timelines might not yet have any
anyhow::ensure!(
ancestor.is_some()
|| timeline
.layers
.read()
.await
.layer_map()
.expect(
"currently loading, layer manager cannot be shutdown already"
)
.iter_historic_layers()
.next()
.is_some(),
"Timeline has no ancestor and no layer files"
);
}
// Sanity check: a timeline should have some content.
anyhow::ensure!(
ancestor.is_some()
|| timeline
.layers
.read()
.await
.layer_map()
.expect("currently loading, layer manager cannot be shutdown already")
.iter_historic_layers()
.next()
.is_some(),
"Timeline has no ancestor and no layer files"
);
Ok(TimelineInitAndSyncResult::ReadyToActivate)
}
@@ -1336,7 +1292,6 @@ impl TenantShard {
deletion_queue_client,
l0_flush_global_state,
basebackup_prepare_sender,
feature_resolver,
} = resources;
let attach_mode = attached_conf.location.attach_mode;
@@ -1353,7 +1308,6 @@ impl TenantShard {
deletion_queue_client,
l0_flush_global_state,
basebackup_prepare_sender,
feature_resolver,
));
// The attach task will carry a GateGuard, so that shutdown() reliably waits for it to drop out if
@@ -1806,25 +1760,20 @@ impl TenantShard {
},
) => {
let timeline_id = timeline.timeline_id;
let import_task_gate = Gate::default();
let import_task_guard = import_task_gate.enter().unwrap();
let import_task_handle =
tokio::task::spawn(self.clone().create_timeline_import_pgdata_task(
timeline.clone(),
import_pgdata,
guard,
import_task_guard,
ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn),
));
let prev = self.timelines_importing.lock().unwrap().insert(
timeline_id,
Arc::new(ImportingTimeline {
ImportingTimeline {
timeline: timeline.clone(),
import_task_handle,
import_task_gate,
delete_progress: TimelineDeleteProgress::default(),
}),
},
);
assert!(prev.is_none());
@@ -2442,17 +2391,6 @@ impl TenantShard {
.collect()
}
/// Lists timelines the tenant contains.
/// It's up to callers to omit certain timelines that are not considered ready for use.
pub fn list_importing_timelines(&self) -> Vec<Arc<ImportingTimeline>> {
self.timelines_importing
.lock()
.unwrap()
.values()
.map(Arc::clone)
.collect()
}
/// Lists timelines the tenant manages, including offloaded ones.
///
/// It's up to callers to omit certain timelines that are not considered ready for use.
@@ -2886,25 +2824,19 @@ impl TenantShard {
let (timeline, timeline_create_guard) = uninit_timeline.finish_creation_myself();
let import_task_gate = Gate::default();
let import_task_guard = import_task_gate.enter().unwrap();
let import_task_handle = tokio::spawn(self.clone().create_timeline_import_pgdata_task(
timeline.clone(),
index_part,
timeline_create_guard,
import_task_guard,
timeline_ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn),
));
let prev = self.timelines_importing.lock().unwrap().insert(
timeline.timeline_id,
Arc::new(ImportingTimeline {
ImportingTimeline {
timeline: timeline.clone(),
import_task_handle,
import_task_gate,
delete_progress: TimelineDeleteProgress::default(),
}),
},
);
// Idempotency is enforced higher up the stack
@@ -2922,13 +2854,13 @@ impl TenantShard {
pub(crate) async fn finalize_importing_timeline(
&self,
timeline_id: TimelineId,
) -> Result<(), FinalizeTimelineImportError> {
) -> anyhow::Result<()> {
let timeline = {
let locked = self.timelines_importing.lock().unwrap();
match locked.get(&timeline_id) {
Some(importing_timeline) => {
if !importing_timeline.import_task_handle.is_finished() {
return Err(FinalizeTimelineImportError::ImportTaskStillRunning);
return Err(anyhow::anyhow!("Import task not done yet"));
}
importing_timeline.timeline.clone()
@@ -2941,13 +2873,8 @@ impl TenantShard {
timeline
.remote_client
.schedule_index_upload_for_import_pgdata_finalize()
.map_err(|_err| FinalizeTimelineImportError::ShuttingDown)?;
timeline
.remote_client
.wait_completion()
.await
.map_err(|_err| FinalizeTimelineImportError::ShuttingDown)?;
.schedule_index_upload_for_import_pgdata_finalize()?;
timeline.remote_client.wait_completion().await?;
self.timelines_importing
.lock()
@@ -2963,7 +2890,6 @@ impl TenantShard {
timeline: Arc<Timeline>,
index_part: import_pgdata::index_part_format::Root,
timeline_create_guard: TimelineCreateGuard,
_import_task_guard: GateGuard,
ctx: RequestContext,
) {
debug_assert_current_span_has_tenant_and_timeline_id();
@@ -3209,18 +3135,11 @@ impl TenantShard {
.or_insert_with(|| Arc::new(GcCompactionQueue::new()))
.clone()
};
let gc_compaction_strategy = self
.feature_resolver
.evaluate_multivariate("gc-comapction-strategy", self.tenant_shard_id.tenant_id)
.ok();
let span = if let Some(gc_compaction_strategy) = gc_compaction_strategy {
info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id, strategy = %gc_compaction_strategy)
} else {
info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id)
};
outcome = queue
.iteration(cancel, ctx, &self.gc_block, &timeline)
.instrument(span)
.instrument(
info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id),
)
.await?;
}
@@ -3552,9 +3471,8 @@ impl TenantShard {
let mut timelines_importing = self.timelines_importing.lock().unwrap();
timelines_importing
.drain()
.for_each(|(timeline_id, importing_timeline)| {
let span = tracing::info_span!("importing_timeline_shutdown", %timeline_id);
js.spawn(async move { importing_timeline.shutdown().instrument(span).await });
.for_each(|(_timeline_id, importing_timeline)| {
importing_timeline.shutdown();
});
}
// test_long_timeline_create_then_tenant_delete is leaning on this message
@@ -3875,9 +3793,6 @@ impl TenantShard {
.build_timeline_client(offloaded.timeline_id, self.remote_storage.clone());
Arc::new(remote_client)
}
TimelineOrOffloadedArcRef::Importing(_) => {
unreachable!("Importing timelines are not included in the iterator")
}
};
// Shut down the timeline's remote client: this means that the indices we write
@@ -4332,7 +4247,6 @@ impl TenantShard {
deletion_queue_client: DeletionQueueClient,
l0_flush_global_state: L0FlushGlobalState,
basebackup_prepare_sender: BasebackupPrepareSender,
feature_resolver: FeatureResolver,
) -> TenantShard {
assert!(!attached_conf.location.generation.is_none());
@@ -4437,7 +4351,6 @@ impl TenantShard {
gc_block: Default::default(),
l0_flush_global_state,
basebackup_prepare_sender,
feature_resolver,
}
}
@@ -5087,14 +5000,6 @@ impl TenantShard {
info!("timeline already exists but is offloaded");
Err(CreateTimelineError::Conflict)
}
Err(TimelineExclusionError::AlreadyExists {
existing: TimelineOrOffloaded::Importing(_existing),
..
}) => {
// If there's a timeline already importing, then we would hit
// the [`TimelineExclusionError::AlreadyCreating`] branch above.
unreachable!("Importing timelines hold the creation guard")
}
Err(TimelineExclusionError::AlreadyExists {
existing: TimelineOrOffloaded::Timeline(existing),
arg,
@@ -5366,7 +5271,6 @@ impl TenantShard {
l0_compaction_trigger: self.l0_compaction_trigger.clone(),
l0_flush_global_state: self.l0_flush_global_state.clone(),
basebackup_prepare_sender: self.basebackup_prepare_sender.clone(),
feature_resolver: self.feature_resolver.clone(),
}
}
@@ -5969,7 +5873,6 @@ pub(crate) mod harness {
// TODO: ideally we should run all unit tests with both configs
L0FlushGlobalState::new(L0FlushConfig::default()),
basebackup_requst_sender,
FeatureResolver::new_disabled(),
));
let preload = tenant
@@ -8411,24 +8314,10 @@ mod tests {
}
tline.freeze_and_flush().await?;
// Force layers to L1
tline
.compact(
&cancel,
{
let mut flags = EnumSet::new();
flags.insert(CompactFlags::ForceL0Compaction);
flags
},
&ctx,
)
.await?;
if iter % 5 == 0 {
let scan_lsn = Lsn(lsn.0 + 1);
info!("scanning at {}", scan_lsn);
let (_, before_delta_file_accessed) =
scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone())
scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone())
.await?;
tline
.compact(
@@ -8437,14 +8326,13 @@ mod tests {
let mut flags = EnumSet::new();
flags.insert(CompactFlags::ForceImageLayerCreation);
flags.insert(CompactFlags::ForceRepartition);
flags.insert(CompactFlags::ForceL0Compaction);
flags
},
&ctx,
)
.await?;
let (_, after_delta_file_accessed) =
scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone())
scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone())
.await?;
assert!(
after_delta_file_accessed < before_delta_file_accessed,
@@ -8885,8 +8773,6 @@ mod tests {
let cancel = CancellationToken::new();
// Image layer creation happens on the disk_consistent_lsn so we need to force set it now.
tline.force_set_disk_consistent_lsn(Lsn(0x40));
tline
.compact(
&cancel,
@@ -8900,7 +8786,8 @@ mod tests {
)
.await
.unwrap();
// Image layers are created at repartition LSN
// Image layers are created at last_record_lsn
let images = tline
.inspect_image_layers(Lsn(0x40), &ctx, io_concurrency.clone())
.await

View File

@@ -1348,21 +1348,6 @@ impl RemoteTimelineClient {
Ok(())
}
/// Locks the upload queue and forwards `names` to
/// `schedule_unlinking_of_layers_from_index_part0`.
///
/// Propagates the error from `initialized_mut()` if the queue cannot be obtained
/// in its initialized state.
pub(crate) fn schedule_unlinking_of_layers_from_index_part<I>(
    self: &Arc<Self>,
    names: I,
) -> Result<(), NotInitialized>
where
    I: IntoIterator<Item = LayerName>,
{
    // The guard stays alive for the whole call so the queue is not touched concurrently.
    let mut queue_guard = self.upload_queue.lock().unwrap();
    let initialized_queue = queue_guard.initialized_mut()?;

    self.schedule_unlinking_of_layers_from_index_part0(initialized_queue, names);

    Ok(())
}
/// Update the remote index file, removing the to-be-deleted files from the index,
/// allowing scheduling of actual deletions later.
fn schedule_unlinking_of_layers_from_index_part0<I>(

View File

@@ -103,7 +103,6 @@ use crate::context::{
DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder,
};
use crate::disk_usage_eviction_task::{DiskUsageEvictionInfo, EvictionCandidate, finite_f32};
use crate::feature_resolver::FeatureResolver;
use crate::keyspace::{KeyPartitioning, KeySpace};
use crate::l0_flush::{self, L0FlushGlobalState};
use crate::metrics::{
@@ -199,7 +198,6 @@ pub struct TimelineResources {
pub l0_compaction_trigger: Arc<Notify>,
pub l0_flush_global_state: l0_flush::L0FlushGlobalState,
pub basebackup_prepare_sender: BasebackupPrepareSender,
pub feature_resolver: FeatureResolver,
}
pub struct Timeline {
@@ -446,8 +444,6 @@ pub struct Timeline {
/// A channel to send async requests to prepare a basebackup for the basebackup cache.
basebackup_prepare_sender: BasebackupPrepareSender,
feature_resolver: FeatureResolver,
}
pub(crate) enum PreviousHeatmap {
@@ -3076,8 +3072,6 @@ impl Timeline {
wait_lsn_log_slow: tokio::sync::Semaphore::new(1),
basebackup_prepare_sender: resources.basebackup_prepare_sender,
feature_resolver: resources.feature_resolver,
};
result.repartition_threshold =
@@ -4912,7 +4906,6 @@ impl Timeline {
LastImageLayerCreationStatus::Initial,
false, // don't yield for L0, we're flushing L0
)
.instrument(info_span!("create_image_layers", mode = %ImageLayerCreationMode::Initial, partition_mode = "initial", lsn = %self.initdb_lsn))
.await?;
debug_assert!(
matches!(is_complete, LastImageLayerCreationStatus::Complete),
@@ -5469,8 +5462,7 @@ impl Timeline {
/// Returns the image layers generated and an enum indicating whether the process is fully completed.
/// true = we have generated all image layers, false = we preempted the process for L0 compaction.
///
/// `partition_mode` is only for logging purpose and is not used anywhere in this function.
#[tracing::instrument(skip_all, fields(%lsn, %mode))]
async fn create_image_layers(
self: &Arc<Timeline>,
partitioning: &KeyPartitioning,

View File

@@ -206,8 +206,8 @@ pub struct GcCompactionQueue {
}
static CONCURRENT_GC_COMPACTION_TASKS: Lazy<Arc<Semaphore>> = Lazy::new(|| {
// Only allow one timeline on one pageserver to run gc compaction at a time.
Arc::new(Semaphore::new(1))
// Only allow two timelines on one pageserver to run gc compaction at a time.
Arc::new(Semaphore::new(2))
});
impl GcCompactionQueue {
@@ -1278,55 +1278,11 @@ impl Timeline {
}
let gc_cutoff = *self.applied_gc_cutoff_lsn.read();
let l0_l1_boundary_lsn = {
// We do the repartition on the L0-L1 boundary. All data below the boundary
// are compacted by L0 with low read amplification, thus making the `repartition`
// function run fast.
let guard = self.layers.read().await;
guard
.all_persistent_layers()
.iter()
.map(|x| {
// Use the end LSN of delta layers OR the start LSN of image layers.
if x.is_delta {
x.lsn_range.end
} else {
x.lsn_range.start
}
})
.max()
};
let (partition_mode, partition_lsn) = if cfg!(test)
|| cfg!(feature = "testing")
|| self
.feature_resolver
.evaluate_boolean("image-compaction-boundary", self.tenant_shard_id.tenant_id)
.is_ok()
{
let last_repartition_lsn = self.partitioning.read().1;
let lsn = match l0_l1_boundary_lsn {
Some(boundary) => gc_cutoff
.max(boundary)
.max(last_repartition_lsn)
.max(self.initdb_lsn)
.max(self.ancestor_lsn),
None => self.get_last_record_lsn(),
};
if lsn <= self.initdb_lsn || lsn <= self.ancestor_lsn {
// Do not attempt to create image layers below the initdb or ancestor LSN -- no data below it
("l0_l1_boundary", self.get_last_record_lsn())
} else {
("l0_l1_boundary", lsn)
}
} else {
("latest_record", self.get_last_record_lsn())
};
// 2. Repartition and create image layers if necessary
match self
.repartition(
partition_lsn,
self.get_last_record_lsn(),
self.get_compaction_target_size(),
options.flags,
ctx,
@@ -1345,19 +1301,18 @@ impl Timeline {
.extend(sparse_partitioning.into_dense().parts);
// 3. Create new image layers for partitions that have been modified "enough".
let mode = if options
.flags
.contains(CompactFlags::ForceImageLayerCreation)
{
ImageLayerCreationMode::Force
} else {
ImageLayerCreationMode::Try
};
let (image_layers, outcome) = self
.create_image_layers(
&partitioning,
lsn,
mode,
if options
.flags
.contains(CompactFlags::ForceImageLayerCreation)
{
ImageLayerCreationMode::Force
} else {
ImageLayerCreationMode::Try
},
&image_ctx,
self.last_image_layer_creation_status
.load()
@@ -1365,7 +1320,6 @@ impl Timeline {
.clone(),
options.flags.contains(CompactFlags::YieldForL0),
)
.instrument(info_span!("create_image_layers", mode = %mode, partition_mode = %partition_mode, lsn = %lsn))
.await
.inspect_err(|err| {
if let CreateImageLayersError::GetVectoredError(
@@ -1390,8 +1344,7 @@ impl Timeline {
}
Ok(_) => {
// This happens very frequently so we don't want to log it.
debug!("skipping repartitioning due to image compaction LSN being below GC cutoff");
info!("skipping repartitioning due to image compaction LSN being below GC cutoff");
}
// Suppress errors when cancelled.

View File

@@ -121,7 +121,6 @@ async fn remove_maybe_offloaded_timeline_from_tenant(
// This observes the locking order between timelines and timelines_offloaded
let mut timelines = tenant.timelines.lock().unwrap();
let mut timelines_offloaded = tenant.timelines_offloaded.lock().unwrap();
let mut timelines_importing = tenant.timelines_importing.lock().unwrap();
let offloaded_children_exist = timelines_offloaded
.iter()
.any(|(_, entry)| entry.ancestor_timeline_id == Some(timeline.timeline_id()));
@@ -151,12 +150,8 @@ async fn remove_maybe_offloaded_timeline_from_tenant(
.expect("timeline that we were deleting was concurrently removed from 'timelines_offloaded' map");
offloaded_timeline.delete_from_ancestor_with_timelines(&timelines);
}
TimelineOrOffloaded::Importing(importing) => {
timelines_importing.remove(&importing.timeline.timeline_id);
}
}
drop(timelines_importing);
drop(timelines_offloaded);
drop(timelines);
@@ -208,17 +203,8 @@ impl DeleteTimelineFlow {
guard.mark_in_progress()?;
// Now that the Timeline is in Stopping state, request all the related tasks to shut down.
// TODO(vlad): shut down imported timeline here
match &timeline {
TimelineOrOffloaded::Timeline(timeline) => {
timeline.shutdown(super::ShutdownMode::Hard).await;
}
TimelineOrOffloaded::Importing(importing) => {
importing.shutdown().await;
}
TimelineOrOffloaded::Offloaded(_offloaded) => {
// Nothing to shut down in this case
}
if let TimelineOrOffloaded::Timeline(timeline) = &timeline {
timeline.shutdown(super::ShutdownMode::Hard).await;
}
tenant.gc_block.before_delete(&timeline.timeline_id());
@@ -403,18 +389,10 @@ impl DeleteTimelineFlow {
Err(anyhow::anyhow!("failpoint: timeline-delete-before-rm"))?
});
match timeline {
TimelineOrOffloaded::Timeline(timeline) => {
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await;
}
TimelineOrOffloaded::Importing(importing) => {
delete_local_timeline_directory(conf, tenant.tenant_shard_id, &importing.timeline)
.await;
}
TimelineOrOffloaded::Offloaded(_offloaded) => {
// Offloaded timelines have no local state
// TODO: once we persist offloaded information, delete the timeline from there, too
}
// Offloaded timelines have no local state
// TODO: once we persist offloaded information, delete the timeline from there, too
if let TimelineOrOffloaded::Timeline(timeline) = timeline {
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await;
}
fail::fail_point!("timeline-delete-after-rm", |_| {
@@ -473,16 +451,12 @@ pub(super) fn make_timeline_delete_guard(
// For more context see this discussion: `https://github.com/neondatabase/neon/pull/4552#discussion_r1253437346`
let timelines = tenant.timelines.lock().unwrap();
let timelines_offloaded = tenant.timelines_offloaded.lock().unwrap();
let timelines_importing = tenant.timelines_importing.lock().unwrap();
let timeline = match timelines.get(&timeline_id) {
Some(t) => TimelineOrOffloaded::Timeline(Arc::clone(t)),
None => match timelines_offloaded.get(&timeline_id) {
Some(t) => TimelineOrOffloaded::Offloaded(Arc::clone(t)),
None => match timelines_importing.get(&timeline_id) {
Some(t) => TimelineOrOffloaded::Importing(Arc::clone(t)),
None => return Err(DeleteTimelineError::NotFound),
},
None => return Err(DeleteTimelineError::NotFound),
},
};

View File

@@ -8,10 +8,8 @@ use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::info;
use utils::lsn::Lsn;
use utils::pausable_failpoint;
use utils::sync::gate::Gate;
use super::{Timeline, TimelineDeleteProgress};
use super::Timeline;
use crate::context::RequestContext;
use crate::controller_upcall_client::{StorageControllerUpcallApi, StorageControllerUpcallClient};
use crate::tenant::metadata::TimelineMetadata;
@@ -21,25 +19,14 @@ mod importbucket_client;
mod importbucket_format;
pub(crate) mod index_part_format;
pub struct ImportingTimeline {
pub(crate) struct ImportingTimeline {
pub import_task_handle: JoinHandle<()>,
pub import_task_gate: Gate,
pub timeline: Arc<Timeline>,
pub delete_progress: TimelineDeleteProgress,
}
/// Terse `Debug` output: prints only the id of the wrapped timeline, in the
/// form `ImportingTimeline<id>`, instead of dumping every field.
impl std::fmt::Debug for ImportingTimeline {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let timeline_id = &self.timeline.timeline_id;
        f.write_fmt(format_args!("ImportingTimeline<{timeline_id}>"))
    }
}
impl ImportingTimeline {
pub async fn shutdown(&self) {
pub(crate) fn shutdown(self) {
self.import_task_handle.abort();
self.import_task_gate.close().await;
self.timeline.remote_client.shutdown().await;
}
}
@@ -106,13 +93,6 @@ pub async fn doit(
);
}
timeline
.remote_client
.schedule_index_upload_for_file_changes()?;
timeline.remote_client.wait_completion().await?;
pausable_failpoint!("import-timeline-pre-success-notify-pausable");
// Communicate that shard is done.
// Ensure at-least-once delivery of the upcall to storage controller
// before we mark the task as done and never come here again.

View File

@@ -30,7 +30,6 @@
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::ops::Range;
use std::sync::Arc;
@@ -101,24 +100,8 @@ async fn run_v1(
tasks: Vec::default(),
};
// Use the job size limit encoded in the progress if we are resuming an import.
// This ensures that imports have stable plans even if the pageserver config changes.
let import_config = {
match &import_progress {
Some(progress) => {
let base = &timeline.conf.timeline_import_config;
TimelineImportConfig {
import_job_soft_size_limit: NonZeroUsize::new(progress.job_soft_size_limit)
.unwrap(),
import_job_concurrency: base.import_job_concurrency,
import_job_checkpoint_threshold: base.import_job_checkpoint_threshold,
}
}
None => timeline.conf.timeline_import_config.clone(),
}
};
let plan = planner.plan(&import_config).await?;
let import_config = &timeline.conf.timeline_import_config;
let plan = planner.plan(import_config).await?;
// Hash the plan and compare with the hash of the plan we got back from the storage controller.
// If the two match, it means that the planning stage had the same output.
@@ -130,20 +113,20 @@ async fn run_v1(
let plan_hash = hasher.finish();
if let Some(progress) = &import_progress {
if plan_hash != progress.import_plan_hash {
anyhow::bail!("Import plan does not match storcon metadata");
}
// Handle collisions on jobs of unequal length
if progress.jobs != plan.jobs.len() {
anyhow::bail!("Import plan job length does not match storcon metadata")
}
if plan_hash != progress.import_plan_hash {
anyhow::bail!("Import plan does not match storcon metadata");
}
}
pausable_failpoint!("import-timeline-pre-execute-pausable");
let start_from_job_idx = import_progress.map(|progress| progress.completed);
plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx)
plan.execute(timeline, start_from_job_idx, plan_hash, import_config, ctx)
.await
}
@@ -235,19 +218,6 @@ impl Planner {
checkpoint_buf,
)));
// Sort the tasks by the key ranges they handle.
// The plan being generated here needs to be stable across invocations
// of this method.
self.tasks.sort_by_key(|task| match task {
AnyImportTask::SingleKey(key) => (key.key, key.key.next()),
AnyImportTask::RelBlocks(rel_blocks) => {
(rel_blocks.key_range.start, rel_blocks.key_range.end)
}
AnyImportTask::SlruBlocks(slru_blocks) => {
(slru_blocks.key_range.start, slru_blocks.key_range.end)
}
});
// Assigns parts of key space to later parallel jobs
let mut last_end_key = Key::MIN;
let mut current_chunk = Vec::new();
@@ -456,8 +426,6 @@ impl Plan {
}));
},
maybe_complete_job_idx = work.next() => {
pausable_failpoint!("import-task-complete-pausable");
match maybe_complete_job_idx {
Some(Ok((job_idx, res))) => {
assert!(last_completed_job_idx.checked_add(1).unwrap() == job_idx);
@@ -470,12 +438,8 @@ impl Plan {
jobs: jobs_in_plan,
completed: last_completed_job_idx,
import_plan_hash,
job_soft_size_limit: import_config.import_job_soft_size_limit.into(),
};
timeline.remote_client.schedule_index_upload_for_file_changes()?;
timeline.remote_client.wait_completion().await?;
storcon_client.put_timeline_import_status(
timeline.tenant_shard_id,
timeline.timeline_id,
@@ -676,11 +640,7 @@ impl Hash for ImportSingleKeyTask {
let ImportSingleKeyTask { key, buf } = self;
key.hash(state);
// The key value might not have a stable binary representation.
// For instance, the db directory uses an unstable hash-map.
// To work around this we are a bit lax here and only hash the
// size of the buffer which must be consistent.
buf.len().hash(state);
buf.hash(state);
}
}
@@ -955,7 +915,7 @@ impl ChunkProcessingJob {
let guard = timeline.layers.read().await;
let existing_layer = guard.try_get_from_key(&desc.key());
if let Some(layer) = existing_layer {
if layer.metadata().generation == timeline.generation {
if layer.metadata().generation != timeline.generation {
return Err(anyhow::anyhow!(
"Import attempted to rewrite layer file in the same generation: {}",
layer.local_path()
@@ -982,15 +942,6 @@ impl ChunkProcessingJob {
.cloned();
match existing_layer {
Some(existing) => {
// Unlink the remote layer from the index without scheduling its deletion.
// When `existing_layer` drops [`LayerInner::drop`] will schedule its deletion from
// remote storage, but that assumes that the layer was unlinked from the index first.
timeline
.remote_client
.schedule_unlinking_of_layers_from_index_part(std::iter::once(
existing.layer_desc().layer_name(),
))?;
guard.open_mut()?.rewrite_layers(
&[(existing.clone(), resident_layer.clone())],
&[],

View File

@@ -155,9 +155,8 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
int written = 0;
written = snprintf((char *) &sk->conninfo, MAXCONNINFO,
"%s host=%s port=%s dbname=replication options='-c timeline_id=%s tenant_id=%s'",
wp->config->safekeeper_conninfo_options, sk->host, sk->port,
wp->config->neon_timeline, wp->config->neon_tenant);
"host=%s port=%s dbname=replication options='-c timeline_id=%s tenant_id=%s'",
sk->host, sk->port, wp->config->neon_timeline, wp->config->neon_tenant);
if (written > MAXCONNINFO || written < 0)
wp_log(FATAL, "could not create connection string for safekeeper %s:%s", sk->host, sk->port);
}

View File

@@ -714,9 +714,6 @@ typedef struct WalProposerConfig
*/
char *safekeepers_list;
/* libpq connection info options. */
char *safekeeper_conninfo_options;
/*
* WalProposer reconnects to offline safekeepers once in this interval.
* Time is in milliseconds.

View File

@@ -64,7 +64,6 @@ char *wal_acceptors_list = "";
int wal_acceptor_reconnect_timeout = 1000;
int wal_acceptor_connection_timeout = 10000;
int safekeeper_proto_version = 3;
char *safekeeper_conninfo_options = "";
/* Set to true in the walproposer bgw. */
static bool am_walproposer;
@@ -120,7 +119,6 @@ init_walprop_config(bool syncSafekeepers)
walprop_config.neon_timeline = neon_timeline;
/* WalProposerCreate scribbles directly on it, so pstrdup */
walprop_config.safekeepers_list = pstrdup(wal_acceptors_list);
walprop_config.safekeeper_conninfo_options = pstrdup(safekeeper_conninfo_options);
walprop_config.safekeeper_reconnect_timeout = wal_acceptor_reconnect_timeout;
walprop_config.safekeeper_connection_timeout = wal_acceptor_connection_timeout;
walprop_config.wal_segment_size = wal_segment_size;
@@ -205,16 +203,6 @@ nwp_register_gucs(void)
* GUC_LIST_QUOTE */
NULL, assign_neon_safekeepers, NULL);
DefineCustomStringVariable(
"neon.safekeeper_conninfo_options",
"libpq keyword parameters and values to apply to safekeeper connections",
NULL,
&safekeeper_conninfo_options,
"",
PGC_POSTMASTER,
0,
NULL, NULL, NULL);
DefineCustomIntVariable(
"neon.safekeeper_reconnect_timeout",
"Walproposer reconnects to offline safekeepers once in this interval.",

View File

@@ -17,7 +17,9 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::auth::{
self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
@@ -135,6 +137,16 @@ impl<'a, T> Backend<'a, T> {
}
}
}
impl<'a, T, E> Backend<'a, Result<T, E>> {
    /// Turns a backend carrying a `Result` into a `Result` carrying a backend,
    /// in the spirit of [`std::option::Option::transpose`].
    ///
    /// A `ControlPlane` payload that is an `Err` becomes the returned error;
    /// a `Local` backend has no fallible payload and is always `Ok`.
    /// Mostly useful for error handling.
    pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
        Ok(match self {
            Self::ControlPlane(api, payload) => Backend::ControlPlane(api, payload?),
            Self::Local(local) => Backend::Local(local),
        })
    }
}
pub(crate) struct ComputeCredentials {
pub(crate) info: ComputeUserInfo,
@@ -272,7 +284,7 @@ async fn auth_quirks(
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
@@ -289,12 +301,15 @@ async fn auth_quirks(
debug!("fetching authentication info and allowlists");
// check allowed list
if config.ip_allowlist_check_enabled {
let allowed_ips = if config.ip_allowlist_check_enabled {
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
}
allowed_ips
} else {
Cached::new_uncached(Arc::new(vec![]))
};
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
@@ -353,7 +368,7 @@ async fn auth_quirks(
)
.await
{
Ok(keys) => Ok(keys),
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
Err(e) => {
if e.is_password_failed() {
// The password could have been changed, so we invalidate the cache.
@@ -405,39 +420,53 @@ async fn authenticate_with_secret(
classic::authenticate(ctx, info, client, config, secret).await
}
impl ControlPlaneClient {
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
/// Get username from the credentials.
///
/// For the `ControlPlane` variant this borrows the `user` field of the carried
/// user info; the `Local` variant carries no user info, so the fixed
/// placeholder `"local"` is returned instead.
pub(crate) fn get_user(&self) -> &str {
match self {
Self::ControlPlane(_, user_info) => &user_info.user,
Self::Local(_) => "local",
}
}
/// Authenticate the client via the requested backend, possibly using credentials.
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub(crate) async fn authenticate(
&self,
self,
ctx: &RequestContext,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
user_info: ComputeUserInfoMaybeEndpoint,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<ComputeCredentials> {
debug!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
let res = match self {
Self::ControlPlane(api, user_info) => {
debug!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
let credentials = auth_quirks(
ctx,
self,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
let (credentials, ip_allowlist) = auth_quirks(
ctx,
&*api,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
}
};
// TODO: replace with some metric
info!("user successfully authenticated");
Ok(credentials)
res
}
}
@@ -507,25 +536,6 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
}
}
/// Bundles a borrowed [`ControlPlaneClient`] with a set of [`ComputeCredentials`]
/// so the pair can be used through the `ComputeConnectBackend` trait.
pub struct ControlPlaneWakeCompute<'a> {
/// Client on which `wake_compute` is invoked.
pub cplane: &'a ControlPlaneClient,
/// Credentials: `creds.info` is handed to `wake_compute`, `creds.keys` is what `get_keys` exposes.
pub creds: ComputeCredentials,
}
#[async_trait::async_trait]
impl ComputeConnectBackend for ControlPlaneWakeCompute<'_> {
    /// Forwards the wake-up request to the wrapped control-plane client,
    /// passing along the `info` part of the stored credentials.
    async fn wake_compute(
        &self,
        ctx: &RequestContext,
    ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
        let user_info = &self.creds.info;
        self.cplane.wake_compute(ctx, user_info).await
    }

    /// Borrows the `keys` part of the stored credentials.
    fn get_keys(&self) -> &ComputeCredentialKeys {
        let credentials = &self.creds;
        &credentials.keys
    }
}
#[cfg(test)]
mod tests {
#![allow(clippy::unimplemented, clippy::unwrap_used)]
@@ -542,7 +552,6 @@ mod tests {
use postgres_protocol::message::backend::Message as PgMessage;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tokio_util::task::TaskTracker;
use super::jwt::JwkCache;
use super::{AuthRateLimiter, auth_quirks};
@@ -693,7 +702,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let mut stream = PqStream::new(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -775,7 +784,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let mut stream = PqStream::new(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -829,7 +838,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let mut stream = PqStream::new(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -878,7 +887,7 @@ mod tests {
.await
.unwrap();
assert_eq!(creds.info.endpoint, "my-endpoint");
assert_eq!(creds.0.info.endpoint, "my-endpoint");
handle.await.unwrap();
}

View File

@@ -1,7 +1,7 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::{Backend, ControlPlaneWakeCompute};
pub use backend::Backend;
mod credentials;
pub(crate) use credentials::{

View File

@@ -18,7 +18,6 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio_rustls::TlsConnector;
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, error, info};
use utils::project_git_version;
use utils::sentry_init::init_sentry;
@@ -227,8 +226,7 @@ pub(super) async fn task_main(
let dest_suffix = Arc::clone(&dest_suffix);
let compute_tls_config = compute_tls_config.clone();
let tracker = connections.token();
tokio::spawn(
connections.spawn(
async move {
socket
.set_nodelay(true)
@@ -251,7 +249,6 @@ pub(super) async fn task_main(
compute_tls_config,
tls_server_end_point,
socket,
tracker,
)
.await
}
@@ -277,11 +274,10 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
raw_stream: S,
tracker: TaskTrackerToken,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<(Stream<S>, TaskTrackerToken)> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream), tracker);
) -> anyhow::Result<Stream<S>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
let msg = stream.read_startup_packet().await?;
use pq_proto::FeStartupPacket::SslRequest;
@@ -295,7 +291,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf, tracker) = stream.into_inner();
let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empty.
// However, you could imagine pipelining of postgres
@@ -306,16 +302,13 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
bail!("data is sent before server replied with EncryptionResponse");
}
Ok((
Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
},
tracker,
))
Ok(Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
})
}
unexpected => {
info!(
@@ -336,10 +329,8 @@ async fn handle_client(
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
tracker: TaskTrackerToken,
) -> anyhow::Result<()> {
let (mut tls_stream, _tracker) =
ssl_handshake(&ctx, stream, tracker, tls_config, tls_server_end_point).await?;
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of

View File

@@ -323,7 +323,7 @@ impl CancellationHandler {
}
}
pub(crate) fn get_key(self: Arc<Self>) -> Session {
pub(crate) fn get_key(self: &Arc<Self>) -> Session {
// we intentionally generate a random "backend pid" and "secret key" here.
// we use the corresponding u64 as an identifier for the
// actual endpoint+pid+secret for postgres/pgbouncer.
@@ -340,7 +340,7 @@ impl CancellationHandler {
Session {
key,
redis_key,
cancellation_handler: self,
cancellation_handler: Arc::clone(self),
}
}

View File

@@ -1,9 +1,8 @@
use std::sync::Arc;
use futures::TryFutureExt;
use futures::{FutureExt, TryFutureExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, debug, error, info};
use crate::auth::backend::ConsoleRedirectBackend;
@@ -15,8 +14,10 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::proxy::passthrough::passthrough;
use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled};
use crate::proxy::passthrough::ProxyPassthrough;
use crate::proxy::{
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
};
pub async fn task_main(
config: &'static ProxyConfig,
@@ -34,6 +35,7 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -47,11 +49,11 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let tracker = connections.token();
tokio::spawn(async move {
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Err(e) => {
error!("per-client task finished with an error: {e:#}");
@@ -101,80 +103,99 @@ pub async fn task_main(
&config.region,
);
let span = ctx.span();
let mut slot = Some(ctx);
let res = handle_client(
config,
backend,
&mut slot,
&ctx,
cancellation_handler,
socket,
conn_gauge,
tracker,
cancellations,
)
.instrument(span)
.instrument(ctx.span())
.boxed()
.await;
match (slot, res) {
(None, _) => {}
(Some(ctx), Ok(())) => {
ctx.success();
}
(Some(ctx), Err(e)) => {
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
ctx_slot: &mut Option<RequestContext>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
conn_gauge: NumClientConnectionsGuard<'static>,
tracker: TaskTrackerToken,
) -> Result<(), ClientRequestError> {
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
debug!(%protocol, "handling interactive connection from client");
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let request_gauge = metrics.connection_requests.guard(protocol);
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let data = {
let ctx = ctx_slot.as_ref().expect("context must be set");
let record_handshake_error = !ctx.has_private_peer_addr();
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
tokio::time::timeout(config.handshake_timeout, do_handshake).await??
};
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
let (mut stream, params) = match data {
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data, tracker) => {
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
tokio::spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx_slot.take().expect("context must be set");
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
let _tracker = tracker;
cancellation_handler_clone
.cancel_session(
cancel_key_data,
@@ -184,17 +205,15 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
backend.get_api(),
)
.await
.inspect_err(|e| debug!(error = ?e, "cancel_session failed"))
.ok();
}
.instrument(cancel_span)
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
return Ok(());
return Ok(None);
}
};
drop(pause);
let ctx = ctx_slot.as_ref().expect("context must be set");
ctx.set_db_options(params.clone());
let (node_info, user_info, _ip_allowlist) = match backend
@@ -209,13 +228,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
let mut node = connect_to_compute(
ctx,
TcpMechanism {
&TcpMechanism {
user_info,
params_compat: true,
params: &params,
locks: &config.connect_compute_locks,
},
node_info,
&node_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
@@ -233,22 +252,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf, tracker) = stream.into_inner();
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
tokio::spawn(passthrough(
ctx,
&config.connect_to_compute,
stream,
node,
session,
request_gauge,
conn_gauge,
tracker,
));
Ok(())
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id: None,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
}

View File

@@ -38,7 +38,7 @@ pub struct RequestContext(
/// I would typically use a RefCell but that would break the `Send` requirements
/// so we need something with thread-safety. `TryLock` is a cheap alternative
/// that offers similar semantics to a `RefCell` but with synchronisation.
TryLock<Box<RequestContextInner>>,
TryLock<RequestContextInner>,
);
struct RequestContextInner {
@@ -89,7 +89,7 @@ pub(crate) enum AuthMethod {
impl Clone for RequestContext {
fn clone(&self) -> Self {
let inner = self.0.try_lock().expect("should not deadlock");
let new = Box::new(RequestContextInner {
let new = RequestContextInner {
conn_info: inner.conn_info.clone(),
session_id: inner.session_id,
protocol: inner.protocol,
@@ -117,7 +117,7 @@ impl Clone for RequestContext {
disconnect_sender: None,
latency_timer: LatencyTimer::noop(inner.protocol),
disconnect_timestamp: inner.disconnect_timestamp,
});
};
Self(TryLock::new(new))
}
@@ -140,7 +140,7 @@ impl RequestContext {
role = tracing::field::Empty,
);
let inner = Box::new(RequestContextInner {
let inner = RequestContextInner {
conn_info,
session_id,
protocol,
@@ -168,7 +168,7 @@ impl RequestContext {
disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
latency_timer: LatencyTimer::new(protocol),
disconnect_timestamp: None,
});
};
Self(TryLock::new(inner))
}
@@ -522,7 +522,7 @@ impl Drop for RequestContextInner {
}
}
pub struct DisconnectLogger(Box<RequestContextInner>);
pub struct DisconnectLogger(RequestContextInner);
impl Drop for DisconnectLogger {
fn drop(&mut self) {

View File

@@ -53,25 +53,6 @@ pub(crate) trait ConnectMechanism {
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
/// Forwarding implementation: a shared reference to a [`ConnectMechanism`]
/// is itself a [`ConnectMechanism`]. Every associated type and method is
/// taken directly from `T`.
#[async_trait]
impl<T: ConnectMechanism + Sync> ConnectMechanism for &T {
type Connection = T::Connection;
type ConnectError = T::ConnectError;
type Error = T::Error;
/// Delegates a single connection attempt to `T::connect_once`.
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
T::connect_once(self, ctx, node_info, config).await
}
/// Delegates to `T::update_connect_config`.
fn update_connect_config(&self, conf: &mut compute::ConnCfg) {
T::update_connect_config(self, conf);
}
}
#[async_trait]
pub(crate) trait ComputeConnectBackend {
async fn wake_compute(
@@ -124,8 +105,8 @@ impl ConnectMechanism for TcpMechanism<'_> {
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
ctx: &RequestContext,
mechanism: M,
backend: B,
mechanism: &M,
user_info: &B,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<M::Connection, M::Error>
@@ -135,9 +116,9 @@ where
{
let mut num_retries = 0;
let mut node_info =
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.set_keys(backend.get_keys());
node_info.set_keys(user_info.get_keys());
mechanism.update_connect_config(&mut node_info.config);
// try once
@@ -178,7 +159,7 @@ where
let old_node_info = invalidate_cache(node_info);
// TODO: increment num_retries?
let mut node_info =
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.reuse_settings(old_node_info);
mechanism.update_connect_config(&mut node_info.config);

View File

@@ -67,6 +67,7 @@ where
}
}
#[tracing::instrument(skip_all)]
pub async fn copy_bidirectional_client_compute<Client, Compute>(
client: &mut Client,
compute: &mut Compute,

View File

@@ -5,7 +5,6 @@ use pq_proto::{
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{debug, info, warn};
use crate::auth::endpoint_sni;
@@ -52,7 +51,7 @@ impl ReportableError for HandshakeError {
pub(crate) enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
Cancel(CancelKeyData, TaskTrackerToken),
Cancel(CancelKeyData),
}
/// Establish a (most probably, secure) connection with the client.
@@ -63,7 +62,6 @@ pub(crate) enum HandshakeData<S> {
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
stream: S,
tracker: TaskTrackerToken,
mut tls: Option<&TlsConfig>,
record_handshake_error: bool,
) -> Result<HandshakeData<S>, HandshakeError> {
@@ -73,7 +71,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let mut stream = PqStream::new(Stream::from_raw(stream), tracker);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
match msg {
@@ -159,13 +157,15 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let (_, tls_server_end_point) =
tls.cert_resolver.resolve(conn_info.server_name());
stream.framed = Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
stream = PqStream {
framed: Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
},
read_buf,
write_buf,
},
read_buf,
write_buf,
};
}
}
@@ -248,7 +248,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
FeStartupPacket::CancelRequest(cancel_key_data) => {
info!(session_type = "cancellation", "successful handshake");
break Ok(HandshakeData::Cancel(cancel_key_data, stream.tracker));
break Ok(HandshakeData::Cancel(cancel_key_data));
}
}
}

View File

@@ -10,27 +10,26 @@ pub(crate) mod wake_compute;
use std::sync::Arc;
pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use futures::TryFutureExt;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use passthrough::passthrough;
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, format_smolstr};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, debug, error, info, warn};
use self::connect_compute::{TcpMechanism, connect_to_compute};
use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
@@ -71,6 +70,7 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -84,12 +84,12 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
let tracker = connections.token();
tokio::spawn(async move {
connections.spawn(async move {
let (socket, conn_info) = match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
@@ -138,41 +138,60 @@ pub async fn task_main(
crate::metrics::Protocol::Tcp,
&config.region,
);
let span = ctx.span();
let mut ctx = Some(ctx);
let res = handle_client(
config,
auth_backend,
&mut ctx,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
tracker,
cancellations,
)
.instrument(span)
.instrument(ctx.span())
.boxed()
.await;
match (ctx, res) {
(None, _) => {}
(Some(ctx), Ok(())) => {
ctx.success();
}
(Some(ctx), Err(e)) => {
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
@@ -239,79 +258,46 @@ impl ReportableError for ClientRequestError {
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx_slot: &mut Option<RequestContext>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
tracker: TaskTrackerToken,
) -> Result<(), ClientRequestError> {
let cplane = match auth_backend {
auth::Backend::ControlPlane(cplane, ()) => &**cplane,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
debug!(%protocol, "handling interactive connection from client");
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let request_gauge = metrics.connection_requests.guard(protocol);
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let handshake_result: Result<_, ClientRequestError> = async {
let tls = config.tls_config.load();
let tls = tls.as_deref();
let tls = config.tls_config.load();
let tls = tls.as_deref();
let ctx = ctx_slot.as_ref().expect("context must be set");
let record_handshake_error = !ctx.has_private_peer_addr();
let data = {
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
tokio::time::timeout(
config.handshake_timeout,
handshake(
ctx,
stream,
tracker,
mode.handshake_tls(tls),
record_handshake_error,
),
)
.await??
};
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
match data {
HandshakeData::Startup(mut stream, params) => {
ctx.set_db_options(params.clone());
let host = mode.hostname(stream.get_ref());
let cn = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, host, cn);
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
};
let session = cancellation_handler.get_key();
Ok(Some((stream, params, session, user_info)))
}
HandshakeData::Cancel(cancel_key_data, tracker) => {
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id());
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
// spawn a task to cancel the session, but don't wait for it
tokio::spawn(async move {
// ensure the proxy doesn't shutdown until we complete this task.
let _tracker = tracker;
cancellation_handler
async move {
cancellation_handler_clone
.cancel_session(
cancel_key_data,
ctx,
@@ -319,108 +305,111 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.instrument(cancel_span)
.await
.unwrap_or_else(|e| debug!(error = ?e, "cancel_session failed"));
});
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
Ok(None)
}
return Ok(None);
}
}
.await;
let Some((mut stream, params, session, user_info)) = handshake_result? else {
return Ok(());
};
let ctx = ctx_slot.as_ref().expect("context must be set");
drop(pause);
let auth_result: Result<_, ClientRequestError> = async {
let user = user_info.user.clone();
ctx.set_db_options(params.clone());
match cplane
.authenticate(
ctx,
&mut stream,
user_info,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
)
.await
{
Ok(auth_result) => Ok(auth_result),
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await?
}
}
}
.await;
let hostname = mode.hostname(stream.get_ref());
let compute_creds = auth_result?;
let common_names = tls.map(|tls| &tls.common_names);
let connect_result: Result<_, ClientRequestError> = async {
let compute_user_info = compute_creds.info.clone();
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
// Extract credentials which we're going to use for auth.
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();
let mut node = connect_to_compute(
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
};
let user = user_info.get_user().to_owned();
let (user_info, _ip_allowlist) = match user_info
.authenticate(
ctx,
TcpMechanism {
user_info: compute_user_info,
params_compat,
params: &params,
locks: &config.connect_compute_locks,
},
auth::ControlPlaneWakeCompute {
cplane,
creds: compute_creds,
},
config.wake_compute_retry_config,
&config.connect_to_compute,
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
)
.or_else(|e| stream.throw_error(e, Some(ctx)))
.await?;
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
return stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await?;
}
};
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf, tracker) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
let compute_user_info = match &user_info {
auth::Backend::ControlPlane(_, info) => &info.info,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
Ok((node, stream, tracker))
}
.await;
let (node, stream, tracker) = connect_result?;
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
tokio::spawn(passthrough(
let mut node = connect_to_compute(
ctx,
&TcpMechanism {
user_info: compute_user_info.clone(),
params_compat,
params: &params,
locks: &config.connect_compute_locks,
},
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
stream,
node,
session,
request_gauge,
conn_gauge,
tracker,
));
)
.or_else(|e| stream.throw_error(e, Some(ctx)))
.await?;
Ok(())
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
}
/// Finish client connection initialization: confirm auth success, send params, etc.

View File

@@ -1,6 +1,5 @@
use smol_str::{SmolStr, ToSmolStr};
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::debug;
use utils::measured_stream::MeasuredStream;
@@ -8,14 +7,13 @@ use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::protocol2::ConnectionInfoExtra;
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub(crate) async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
@@ -63,53 +61,41 @@ pub(crate) async fn proxy_pass(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn passthrough<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
ctx: RequestContext,
compute_config: &'static ComputeConfig,
pub(crate) struct ProxyPassthrough<S> {
pub(crate) client: Stream<S>,
pub(crate) compute: PostgresConnection,
pub(crate) aux: MetricsAuxInfo,
pub(crate) session_id: uuid::Uuid,
pub(crate) private_link_id: Option<SmolStr>,
pub(crate) cancel: cancellation::Session,
client: Stream<S>,
compute: PostgresConnection,
cancel: cancellation::Session,
_req: NumConnectionRequestsGuard<'static>,
_conn: NumClientConnectionsGuard<'static>,
_tracker: TaskTrackerToken,
) {
let session_id = ctx.session_id();
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
let _disconnect = ctx.log_connect();
let res = proxy_pass(client, compute.stream, compute.aux, private_link_id).await;
match res {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
tracing::warn!(
session_id = ?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
tracing::error!(
session_id = ?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
if let Err(err) = compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?session_id, ?err, "could not cancel the query in the database");
}
// we don't need a result. If the queue is full, we just log the error
drop(cancel.remove_cancel_key());
pub(crate) _req: NumConnectionRequestsGuard<'static>,
pub(crate) _conn: NumClientConnectionsGuard<'static>,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
/// Runs the passthrough session to completion, consuming `self`.
///
/// Forwards traffic between `self.client` and `self.compute.stream` via the
/// free function `proxy_pass`. Once that returns (successfully or not), it
/// asks the compute to cancel the query through the stored cancel closure
/// and removes the session's cancel key.
///
/// # Errors
///
/// Returns the `ErrorSource` produced by the forwarding step. A failure of
/// the follow-up cancel attempt is only logged and never returned.
pub(crate) async fn proxy_pass(
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
// Keep the forwarding result aside: cleanup below must run in both the
// success and the error case before the result is handed back.
let res = proxy_pass(
self.client,
self.compute.stream,
self.aux,
self.private_link_id,
)
.await;
// Best-effort cancellation on the compute side; a failure is logged at
// `warn` level together with the session id.
if let Err(err) = self
.compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
}
drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
res
}
}

View File

@@ -38,7 +38,6 @@ async fn proxy_mitm(
let (end_client, startup) = match handshake(
&RequestContext::test(),
client1,
TaskTracker::new().token(),
Some(&server_config1),
false,
)
@@ -46,7 +45,7 @@ async fn proxy_mitm(
.unwrap()
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(_, _) => panic!("cancellation not supported"),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
};
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);

View File

@@ -15,7 +15,6 @@ use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use tokio::io::DuplexStream;
use tokio_util::task::TaskTracker;
use tracing_test::traced_test;
use super::connect_compute::ConnectMechanism;
@@ -179,12 +178,10 @@ async fn dummy_proxy(
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let (client, _) = read_proxy_protocol(client).await?;
let t = TaskTracker::new().token();
let mut stream =
match handshake(&RequestContext::test(), client, t, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_, _) => bail!("cancellation not supported"),
};
let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
auth.authenticate(&mut stream).await?;
@@ -625,7 +622,7 @@ async fn connect_to_compute_success() {
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -639,7 +636,7 @@ async fn connect_to_compute_retry() {
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -654,7 +651,7 @@ async fn connect_to_compute_non_retry_1() {
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -669,7 +666,7 @@ async fn connect_to_compute_non_retry_2() {
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -694,7 +691,7 @@ async fn connect_to_compute_non_retry_3() {
connect_to_compute(
&ctx,
&mechanism,
user_info,
&user_info,
wake_compute_retry_config,
&config,
)
@@ -712,7 +709,7 @@ async fn wake_retry() {
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -727,7 +724,7 @@ async fn wake_non_retry() {
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -746,7 +743,7 @@ async fn fail_but_wake_invalidates_cache() {
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
.await
.unwrap();
@@ -767,7 +764,7 @@ async fn fail_no_wake_skips_cache_invalidation() {
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
.await
.unwrap();
@@ -788,7 +785,7 @@ async fn retry_but_wake_invalidates_cache() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap();
mechanism.verify();
@@ -811,7 +808,7 @@ async fn retry_no_wake_skips_invalidation() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap_err();
mechanism.verify();

View File

@@ -224,13 +224,13 @@ impl PoolingBackend {
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
TokioMechanism {
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
backend,
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
@@ -268,13 +268,13 @@ impl PoolingBackend {
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
HyperMechanism {
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
},
backend,
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)

View File

@@ -70,6 +70,34 @@ pub(crate) enum JsonConversionError {
ParseJsonError(#[from] serde_json::Error),
#[error("unbalanced array")]
UnbalancedArray,
#[error("unbalanced quoted string")]
UnbalancedString,
}
/// Accumulator for the JSON value produced from a row: either a positional
/// array of values or an object keyed by column name.
enum OutputMode {
/// Values collected in insertion order; keys are not stored.
Array(Vec<Value>),
/// Values collected under their string key.
Object(Map<String, Value>),
}
impl OutputMode {
    /// Reserves the slot that the value for `key` should be written into and
    /// returns a mutable reference to it.
    ///
    /// In object mode this is the entry stored under `key`, created as
    /// `Value::Null` when it does not exist yet. In array mode `key` is
    /// ignored and a fresh `Value::Null` is appended to the end.
    fn key(&mut self, key: &str) -> &mut Value {
        match self {
            Self::Object(fields) => fields.entry(key.to_owned()).or_insert(Value::Null),
            Self::Array(items) => push_entry(items, Value::Null),
        }
    }

    /// Consumes the accumulator and wraps its contents in the matching
    /// `Value` variant.
    fn finish(self) -> Value {
        match self {
            Self::Object(fields) => Value::Object(fields),
            Self::Array(items) => Value::Array(items),
        }
    }
}
/// Appends `t` to the end of `arr` and returns a mutable reference to the
/// element that was just stored.
fn push_entry<T>(arr: &mut Vec<T>, t: T) -> &mut T {
    // The new element lands at the index equal to the length before the push.
    let slot = arr.len();
    arr.push(t);
    arr.get_mut(slot).expect("a value was just inserted")
}
//
@@ -77,182 +105,276 @@ pub(crate) enum JsonConversionError {
//
pub(crate) fn pg_text_row_to_json(
row: &Row,
columns: &[Type],
raw_output: bool,
array_mode: bool,
) -> Result<Value, JsonConversionError> {
let iter = row
.columns()
.iter()
.zip(columns)
.enumerate()
.map(|(i, (column, typ))| {
let name = column.name();
let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?;
let json_value = if raw_output {
match pg_value {
Some(v) => Value::String(v.to_string()),
None => Value::Null,
}
} else {
pg_text_to_json(pg_value, typ)?
};
Ok((name.to_string(), json_value))
});
if array_mode {
// drop keys and aggregate into array
let arr = iter
.map(|r| r.map(|(_key, val)| val))
.collect::<Result<Vec<Value>, JsonConversionError>>()?;
Ok(Value::Array(arr))
let mut entries = if array_mode {
OutputMode::Array(Vec::with_capacity(row.columns().len()))
} else {
let obj = iter.collect::<Result<Map<String, Value>, JsonConversionError>>()?;
Ok(Value::Object(obj))
OutputMode::Object(Map::with_capacity(row.columns().len()))
};
for (i, column) in row.columns().iter().enumerate() {
let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?;
let value = entries.key(column.name());
match pg_value {
Some(v) if raw_output => *value = Value::String(v.to_string()),
Some(v) => pg_text_to_json(value, v, column.type_())?,
None => *value = Value::Null,
}
}
Ok(entries.finish())
}
//
// Convert postgres text-encoded value to JSON value
//
fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, JsonConversionError> {
if let Some(val) = pg_value {
if let Kind::Array(elem_type) = pg_type.kind() {
return pg_array_parse(val, elem_type);
}
fn pg_text_to_json(
output: &mut Value,
val: &str,
pg_type: &Type,
) -> Result<(), JsonConversionError> {
if let Kind::Array(elem_type) = pg_type.kind() {
// todo: we should fetch this from postgres.
let delimiter = ',';
match *pg_type {
Type::BOOL => Ok(Value::Bool(val == "t")),
Type::INT2 | Type::INT4 => {
let val = val.parse::<i32>()?;
Ok(Value::Number(serde_json::Number::from(val)))
}
Type::FLOAT4 | Type::FLOAT8 => {
let fval = val.parse::<f64>()?;
let num = serde_json::Number::from_f64(fval);
if let Some(num) = num {
Ok(Value::Number(num))
} else {
// Pass NaN, Inf, -Inf as strings.
// JS JSON.stringify() converts them to null, but we
// want to preserve them, so we pass them as strings
Ok(Value::String(val.to_string()))
}
}
Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
_ => Ok(Value::String(val.to_string())),
}
} else {
Ok(Value::Null)
}
}
//
// Parse postgres array into JSON array.
//
// This is a bit involved because we need to handle nested arrays and quoted
// values. Unlike postgres we don't check that all nested arrays have the same
// dimensions, we just return them as is.
//
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, JsonConversionError> {
pg_array_parse_inner(pg_array, elem_type, false).map(|(v, _)| v)
}
fn pg_array_parse_inner(
pg_array: &str,
elem_type: &Type,
nested: bool,
) -> Result<(Value, usize), JsonConversionError> {
let mut pg_array_chr = pg_array.char_indices();
let mut level = 0;
let mut quote = false;
let mut entries: Vec<Value> = Vec::new();
let mut entry = String::new();
// skip bounds decoration
if let Some('[') = pg_array.chars().next() {
for (_, c) in pg_array_chr.by_ref() {
if c == '=' {
break;
}
}
let mut array = vec![];
pg_array_parse(&mut array, val, elem_type, delimiter)?;
*output = Value::Array(array);
return Ok(());
}
fn push_checked(
entry: &mut String,
entries: &mut Vec<Value>,
elem_type: &Type,
) -> Result<(), JsonConversionError> {
if !entry.is_empty() {
// While in usual postgres response we get nulls as None and everything else
// as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
// string with value 'NULL' will be represented by '"NULL"'). So catch NULLs
// here while we have quotation info and convert them to None.
if entry == "NULL" {
entries.push(pg_text_to_json(None, elem_type)?);
match *pg_type {
Type::BOOL => *output = Value::Bool(val == "t"),
Type::INT2 | Type::INT4 => {
let val = val.parse::<i32>()?;
*output = Value::Number(serde_json::Number::from(val));
}
Type::FLOAT4 | Type::FLOAT8 => {
let fval = val.parse::<f64>()?;
let num = serde_json::Number::from_f64(fval);
if let Some(num) = num {
*output = Value::Number(num);
} else {
entries.push(pg_text_to_json(Some(entry), elem_type)?);
// Pass NaN, Inf, -Inf as strings.
// JS JSON.stringify() converts them to null, but we
// want to preserve them, so we pass them as strings
*output = Value::String(val.to_string());
}
entry.clear();
}
Ok(())
Type::JSON | Type::JSONB => *output = serde_json::from_str(val)?,
_ => *output = Value::String(val.to_string()),
}
while let Some((mut i, mut c)) = pg_array_chr.next() {
let mut escaped = false;
Ok(())
}
if c == '\\' {
escaped = true;
let Some(x) = pg_array_chr.next() else {
return Err(JsonConversionError::UnbalancedArray);
};
(i, c) = x;
}
match c {
'{' if !quote => {
level += 1;
if level > 1 {
let (res, off) = pg_array_parse_inner(&pg_array[i..], elem_type, true)?;
entries.push(res);
for _ in 0..off - 1 {
pg_array_chr.next();
}
}
}
'}' if !quote => {
level -= 1;
if level == 0 {
push_checked(&mut entry, &mut entries, elem_type)?;
if nested {
return Ok((Value::Array(entries), i));
}
}
}
'"' if !escaped => {
if quote {
// end of quoted string, so push it manually without any checks
// for emptiness or nulls
entries.push(pg_text_to_json(Some(&entry), elem_type)?);
entry.clear();
}
quote = !quote;
}
',' if !quote => {
push_checked(&mut entry, &mut entries, elem_type)?;
}
_ => {
entry.push(c);
}
}
/// Parse postgres array into JSON array.
///
/// This is a bit involved because we need to handle nested arrays and quoted
/// values. Unlike postgres we don't check that all nested arrays have the same
/// dimensions, we just return them as is.
///
/// <https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO>
///
/// The external text representation of an array value consists of items that are interpreted
/// according to the I/O conversion rules for the array's element type, plus decoration that
/// indicates the array structure. The decoration consists of curly braces (`{` and `}`) around
/// the array value plus delimiter characters between adjacent items. The delimiter character
/// is usually a comma (,) but can be something else: it is determined by the typdelim setting
/// for the array's element type. Among the standard data types provided in the PostgreSQL
/// distribution, all use a comma, except for type box, which uses a semicolon (;).
///
/// In a multidimensional array, each dimension (row, plane, cube, etc.)
/// gets its own level of curly braces, and delimiters must be written between adjacent
/// curly-braced entities of the same level.
fn pg_array_parse(
elements: &mut Vec<Value>,
mut pg_array: &str,
elem: &Type,
delim: char,
) -> Result<(), JsonConversionError> {
// skip bounds decoration, eg:
// `[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}`
// technically these are significant, but we have no way to represent them in json.
if let Some('[') = pg_array.chars().next() {
let Some((_bounds, array)) = pg_array.split_once('=') else {
return Err(JsonConversionError::UnbalancedArray);
};
pg_array = array;
}
if level != 0 {
// whitespace might precede a `{`.
let pg_array = pg_array.trim_start();
let rest = pg_array_parse_inner(elements, pg_array, elem, delim)?;
if !rest.is_empty() {
return Err(JsonConversionError::UnbalancedArray);
}
Ok((Value::Array(entries), 0))
Ok(())
}
/// Reads a single `{...}` array from the front of `pg_array` and pushes each of its
/// values onto `elements`.
///
/// Returns the rest of the `pg_array` string that was not read, with any whitespace
/// that follows the closing `}` already trimmed.
///
/// An empty array (`{}`) pushes nothing onto `elements`.
///
/// # Errors
///
/// Returns [`JsonConversionError::UnbalancedArray`] if the input does not start with `{`
/// or if the items are not followed by a closing `}`. Any error produced while parsing
/// an individual item is propagated unchanged.
fn pg_array_parse_inner<'a>(
    elements: &mut Vec<Value>,
    mut pg_array: &'a str,
    elem: &Type,
    delim: char,
) -> Result<&'a str, JsonConversionError> {
    // array should have a `{` prefix.
    pg_array = pg_array
        .strip_prefix('{')
        .ok_or(JsonConversionError::UnbalancedArray)?;

    // An empty array is written as `{}` and contains no items at all. It has to be
    // handled before the loop: the loop always parses at least one item, and the empty
    // text in front of `}` would otherwise be converted as if it were an element
    // (a bogus empty string for text types, a parse error for numeric types).
    if let Some(rest) = pg_array.trim_start().strip_prefix('}') {
        // whitespace might follow a `}`.
        return Ok(rest.trim_start());
    }

    // scratch buffer shared by every quoted item of this array.
    let mut q = String::new();

    loop {
        // reserve the slot first, then let the item parser overwrite it in place.
        // NOTE(review): `push_entry` is defined elsewhere; presumably it appends the
        // placeholder and returns a mutable reference to it - confirm.
        let value = push_entry(elements, Value::Null);
        pg_array = pg_array_parse_item(value, &mut q, pg_array, elem, delim)?;

        // check for separator.
        if let Some(next) = pg_array.strip_prefix(delim) {
            // next item.
            pg_array = next;
        } else {
            break;
        }
    }

    let Some(next) = pg_array.strip_prefix('}') else {
        // missing `}` terminator.
        return Err(JsonConversionError::UnbalancedArray);
    };

    // whitespace might follow a `}`.
    Ok(next.trim_start())
}
/// Parses one array item from the front of `pg_array` and stores it in `output`.
///
/// Returns the part of `pg_array` that follows the item.
///
/// An item takes one of three forms:
/// * a nested array (`{...}`), when the array is multi-dimensional;
/// * a double-quoted string representing `elem`;
/// * an unquoted string representing `elem`, where the bare word `NULL` means null.
///
/// `quoted` is a scratch allocation that has no defined output.
fn pg_array_parse_item<'a>(
    output: &mut Value,
    quoted: &mut String,
    mut pg_array: &'a str,
    elem: &Type,
    delim: char,
) -> Result<&'a str, JsonConversionError> {
    // leading whitespace is not part of the item.
    pg_array = pg_array.trim_start();

    // nested array: the sub-parser consumes the `{` itself, so hand it the full slice.
    if pg_array.starts_with('{') {
        let mut children = vec![];
        let rest = pg_array_parse_inner(&mut children, pg_array, elem, delim)?;
        *output = Value::Array(children);
        return Ok(rest);
    }

    // quoted item: unescape into the scratch buffer, then convert the unescaped text.
    if let Some(body) = pg_array.strip_prefix('"') {
        let rest = pg_array_parse_quoted(quoted, body)?;
        pg_text_to_json(output, quoted, elem)?;
        quoted.clear();
        return Ok(rest);
    }

    // unquoted item: it extends up to (not including) the next delimiter or `}`.
    let Some(end) = pg_array.find([delim, '}']) else {
        return Err(JsonConversionError::UnbalancedArray);
    };
    let (raw, rest) = pg_array.split_at(end);

    // trailing whitespace is not part of the item either.
    match raw.trim_end() {
        // only the unquoted word is a null; a quoted "NULL" took the branch above.
        "NULL" => *output = Value::Null,
        text => pg_text_to_json(output, text, elem)?,
    }

    Ok(rest)
}
/// Unescapes one double-quoted array item from the front of `pg_array` into `quoted`.
///
/// The opening `"` must already have been stripped by the caller. The closing `"` is
/// consumed here, along with any whitespace after it, and the remaining input is returned.
///
/// # Errors
///
/// Returns [`JsonConversionError::UnbalancedString`] when the input ends before the
/// closing `"` is found, or directly after a backslash.
fn pg_array_parse_quoted<'a>(
    quoted: &mut String,
    mut pg_array: &'a str,
) -> Result<&'a str, JsonConversionError> {
    // Postgres double-quotes an element when it is empty, contains curly braces, delimiter
    // characters, double quotes, backslashes or white space, or matches the word `NULL`.
    // Inside the quotes, double quotes and backslashes are backslash-escaped.
    //
    // The text is copied in runs that end at a backslash or at the closing quote.
    // For the input `foo\"bar"` that means: copy `foo`, copy the escaped `"`, copy `bar`.
    loop {
        // locate the end of the current run.
        let Some(stop) = pg_array.find(['\\', '"']) else {
            return Err(JsonConversionError::UnbalancedString);
        };

        let (literal, tail) = pg_array
            .split_at_checked(stop)
            .expect("i is guaranteed to be in-bounds of pg_array");
        quoted.push_str(literal);

        // take the character that ended the run.
        let (terminator, tail) =
            split_first_char(tail).expect("pg_array should start with either '\\\\' or '\"'");

        if terminator == '"' {
            // closing quote: done. whitespace might follow the '"'.
            return Ok(tail.trim_start());
        }

        // it was a backslash: whatever character follows is taken literally.
        let Some((literal_char, tail)) = split_first_char(tail) else {
            return Err(JsonConversionError::UnbalancedString);
        };
        quoted.push(literal_char);

        pg_array = tail;
    }
}
/// Splits `s` into its first character and the remainder of the string.
///
/// Returns `None` when `s` is empty.
fn split_first_char(s: &str) -> Option<(char, &str)> {
    let first = s.chars().next()?;
    // the first char occupies exactly `len_utf8()` bytes, so this is a char boundary.
    Some((first, &s[first.len_utf8()..]))
}
#[cfg(test)]
@@ -316,37 +438,33 @@ mod tests {
);
}
fn pg_text_to_json(val: &str, pg_type: &Type) -> Value {
let mut v = Value::Null;
super::pg_text_to_json(&mut v, val, pg_type).unwrap();
v
}
fn pg_array_parse(pg_array: &str, pg_type: &Type) -> Value {
let mut array = vec![];
super::pg_array_parse(&mut array, pg_array, pg_type, ',').unwrap();
Value::Array(array)
}
#[test]
fn test_atomic_types_parse() {
assert_eq!(pg_text_to_json("foo", &Type::TEXT), json!("foo"));
assert_eq!(pg_text_to_json("42", &Type::INT4), json!(42));
assert_eq!(pg_text_to_json("42", &Type::INT2), json!(42));
assert_eq!(pg_text_to_json("42", &Type::INT8), json!("42"));
assert_eq!(pg_text_to_json("42.42", &Type::FLOAT8), json!(42.42));
assert_eq!(pg_text_to_json("42.42", &Type::FLOAT4), json!(42.42));
assert_eq!(pg_text_to_json("NaN", &Type::FLOAT4), json!("NaN"));
assert_eq!(
pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(),
json!("foo")
);
assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42));
assert_eq!(
pg_text_to_json(Some("42"), &Type::INT8).unwrap(),
json!("42")
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(),
json!("NaN")
);
assert_eq!(
pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(),
pg_text_to_json("Infinity", &Type::FLOAT4),
json!("Infinity")
);
assert_eq!(
pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(),
pg_text_to_json("-Infinity", &Type::FLOAT4),
json!("-Infinity")
);
@@ -355,10 +473,9 @@ mod tests {
.unwrap();
assert_eq!(
pg_text_to_json(
Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#,
&Type::JSONB
)
.unwrap(),
),
json
);
}
@@ -366,7 +483,7 @@ mod tests {
#[test]
fn test_pg_array_parse_text() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::TEXT).unwrap()
pg_array_parse(pg_arr, &Type::TEXT)
}
assert_eq!(
pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#),
@@ -389,7 +506,7 @@ mod tests {
#[test]
fn test_pg_array_parse_bool() {
fn pb(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::BOOL).unwrap()
pg_array_parse(pg_arr, &Type::BOOL)
}
assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true]));
assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]]));
@@ -406,7 +523,7 @@ mod tests {
#[test]
fn test_pg_array_parse_numbers() {
fn pn(pg_arr: &str, ty: &Type) -> Value {
pg_array_parse(pg_arr, ty).unwrap()
pg_array_parse(pg_arr, ty)
}
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3]));
@@ -434,7 +551,7 @@ mod tests {
#[test]
fn test_pg_array_with_decoration() {
fn p(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::INT2).unwrap()
pg_array_parse(pg_arr, &Type::INT2)
}
assert_eq!(
p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#),
@@ -445,7 +562,7 @@ mod tests {
#[test]
fn test_pg_array_parse_json() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::JSONB).unwrap()
pg_array_parse(pg_arr, &Type::JSONB)
}
assert_eq!(pt(r#"{"{}"}"#), json!([{}]));
assert_eq!(

View File

@@ -41,7 +41,7 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tokio_util::task::TaskTracker;
use tracing::{Instrument, info, warn};
use crate::cancellation::CancellationHandler;
@@ -124,6 +124,7 @@ pub async fn task_main(
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
@@ -149,11 +150,11 @@ pub async fn task_main(
let conn_token = cancellation_token.child_token();
let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let tracker = connections.token();
tokio::spawn(
let cancellations = cancellations.clone();
connections.spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
@@ -180,7 +181,8 @@ pub async fn task_main(
Box::pin(connection_handler(
config,
backend,
tracker,
connections2,
cancellations,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
@@ -303,7 +305,8 @@ async fn connection_startup(
async fn connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
tracker: TaskTrackerToken,
connections: TaskTracker,
cancellations: TaskTracker,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
@@ -344,17 +347,19 @@ async fn connection_handler(
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let handler = tokio::spawn(
let cancellations = cancellations.clone();
let handler = connections.spawn(
request_handler(
req,
config,
backend.clone(),
tracker.clone(),
connections.clone(),
cancellation_handler.clone(),
session_id,
conn_info2.clone(),
http_request_token,
endpoint_rate_limiter.clone(),
cancellations,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -395,13 +400,14 @@ async fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
tracker: TaskTrackerToken,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandler>,
session_id: uuid::Uuid,
conn_info: ConnectionInfo,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellations: TaskTracker,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
@@ -435,17 +441,10 @@ async fn request_handler(
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
tokio::spawn(
let cancellations = cancellations.clone();
ws_connections.spawn(
async move {
let websocket = match websocket.await {
Err(e) => {
warn!("could not upgrade websocket connection: {e:#}");
return;
}
Ok(websocket) => websocket,
};
websocket::serve_websocket(
if let Err(e) = websocket::serve_websocket(
config,
backend.auth_backend,
ctx,
@@ -453,9 +452,12 @@ async fn request_handler(
cancellation_handler,
endpoint_rate_limiter,
host,
tracker,
cancellations,
)
.await;
.await
{
warn!("error in websocket connection: {e:#}");
}
}
.instrument(span),
);

View File

@@ -1102,7 +1102,6 @@ async fn query_to_json<T: GenericClient>(
let columns_len = row_stream.statement.columns().len();
let mut fields = Vec::with_capacity(columns_len);
let mut types = Vec::with_capacity(columns_len);
for c in row_stream.statement.columns() {
fields.push(json!({
@@ -1114,8 +1113,6 @@ async fn query_to_json<T: GenericClient>(
"dataTypeModifier": c.type_modifier(),
"format": "text",
}));
types.push(c.type_().clone());
}
let raw_output = parsed_headers.raw_output;
@@ -1137,7 +1134,7 @@ async fn query_to_json<T: GenericClient>(
));
}
let row = pg_text_row_to_json(&row, &types, raw_output, array_mode)?;
let row = pg_text_row_to_json(&row, raw_output, array_mode)?;
rows.push(row);
// assumption: parsing pg text and converting to json takes CPU time.

View File

@@ -2,14 +2,14 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};
use anyhow::Context as _;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
use hyper::upgrade::Upgraded;
use hyper::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::warn;
use crate::cancellation::CancellationHandler;
@@ -17,7 +17,7 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::{ClientMode, handle_client};
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::rate_limiter::EndpointRateLimiter;
pin_project! {
@@ -128,12 +128,13 @@ pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
ctx: RequestContext,
websocket: Upgraded,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
tracker: TaskTrackerToken,
) {
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
let conn_gauge = Metrics::get()
@@ -141,28 +142,36 @@ pub(crate) async fn serve_websocket(
.client_connections
.guard(crate::metrics::Protocol::Ws);
let mut ctx_slot = Some(ctx);
let res = handle_client(
let res = Box::pin(handle_client(
config,
auth_backend,
&mut ctx_slot,
&ctx,
cancellation_handler,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
tracker,
)
cancellations,
))
.await;
match (ctx_slot, res) {
(None, _) => {}
(Some(ctx), Err(e)) => {
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
Err(e.into())
}
(Some(ctx), Ok(())) => {
Ok(None) => {
ctx.set_success();
Ok(())
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
}
}
}
}

View File

@@ -10,7 +10,6 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::debug;
use crate::control_plane::messages::ColdStartInfo;
@@ -25,22 +24,19 @@ use crate::tls::TlsServerEndPoint;
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
pub(crate) framed: Framed<S>,
pub(crate) tracker: TaskTrackerToken,
}
impl<S> PqStream<S> {
/// Construct a new libpq protocol wrapper.
pub fn new(stream: S, tracker: TaskTrackerToken) -> Self {
pub fn new(stream: S) -> Self {
Self {
framed: Framed::new(stream),
tracker,
}
}
/// Extract the underlying stream and read buffer.
pub fn into_inner(self) -> (S, BytesMut, TaskTrackerToken) {
let (stream, read) = self.framed.into_inner();
(stream, read, self.tracker)
pub fn into_inner(self) -> (S, BytesMut) {
self.framed.into_inner()
}
/// Get a shared reference to the underlying stream.

View File

@@ -44,7 +44,6 @@ struct GlobalTimelinesState {
// on-demand timeline creation from recreating deleted timelines. This is only soft-enforced, as
// this map is dropped on restart.
tombstones: HashMap<TenantTimelineId, Instant>,
tenant_tombstones: HashMap<TenantId, Instant>,
conf: Arc<SafeKeeperConf>,
broker_active_set: Arc<TimelinesSet>,
@@ -82,25 +81,10 @@ impl GlobalTimelinesState {
}
}
fn has_tombstone(&self, ttid: &TenantTimelineId) -> bool {
self.tombstones.contains_key(ttid) || self.tenant_tombstones.contains_key(&ttid.tenant_id)
}
/// Removes all blocking tombstones for the given timeline ID.
/// Returns `true` if there have been actual changes.
fn remove_tombstone(&mut self, ttid: &TenantTimelineId) -> bool {
self.tombstones.remove(ttid).is_some()
|| self.tenant_tombstones.remove(&ttid.tenant_id).is_some()
}
fn delete(&mut self, ttid: TenantTimelineId) {
self.timelines.remove(&ttid);
self.tombstones.insert(ttid, Instant::now());
}
fn add_tenant_tombstone(&mut self, tenant_id: TenantId) {
self.tenant_tombstones.insert(tenant_id, Instant::now());
}
}
/// A struct used to manage access to the global timelines map.
@@ -115,7 +99,6 @@ impl GlobalTimelines {
state: Mutex::new(GlobalTimelinesState {
timelines: HashMap::new(),
tombstones: HashMap::new(),
tenant_tombstones: HashMap::new(),
conf,
broker_active_set: Arc::new(TimelinesSet::default()),
global_rate_limiter: RateLimiter::new(1, 1),
@@ -262,7 +245,7 @@ impl GlobalTimelines {
return Ok(timeline);
}
if state.has_tombstone(&ttid) {
if state.tombstones.contains_key(&ttid) {
anyhow::bail!("Timeline {ttid} is deleted, refusing to recreate");
}
@@ -312,14 +295,13 @@ impl GlobalTimelines {
_ => {}
}
if check_tombstone {
if state.has_tombstone(&ttid) {
if state.tombstones.contains_key(&ttid) {
anyhow::bail!("timeline {ttid} is deleted, refusing to recreate");
}
} else {
// We may have been asked to load a timeline that was previously deleted (e.g. from `pull_timeline.rs`). We trust
// that the human doing this manual intervention knows what they are doing, and remove its tombstone.
// It's also possible that we enter this when the tenant has been deleted, even if the timeline itself has never existed.
if state.remove_tombstone(&ttid) {
if state.tombstones.remove(&ttid).is_some() {
warn!("un-deleted timeline {ttid}");
}
}
@@ -500,7 +482,6 @@ impl GlobalTimelines {
let tli_res = {
let state = self.state.lock().unwrap();
// Do NOT check tenant tombstones here: those were set earlier
if state.tombstones.contains_key(ttid) {
// Presence of a tombstone guarantees that a previous deletion has completed and there is no work to do.
info!("Timeline {ttid} was already deleted");
@@ -576,10 +557,6 @@ impl GlobalTimelines {
action: DeleteOrExclude,
) -> Result<HashMap<TenantTimelineId, TimelineDeleteResult>> {
info!("deleting all timelines for tenant {}", tenant_id);
// Adding a tombstone before getting the timelines to prevent new timeline additions
self.state.lock().unwrap().add_tenant_tombstone(*tenant_id);
let to_delete = self.get_all_for_tenant(*tenant_id);
let mut err = None;
@@ -623,9 +600,6 @@ impl GlobalTimelines {
state
.tombstones
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
state
.tenant_tombstones
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
}
}

View File

@@ -87,7 +87,6 @@ impl WalProposer {
let config = Config {
ttid,
safekeepers_list: addrs,
safekeeper_conninfo_options: String::new(),
safekeeper_reconnect_timeout: 1000,
safekeeper_connection_timeout: 5000,
sync_safekeepers,

View File

@@ -482,10 +482,6 @@ async fn handle_tenant_timeline_delete(
ForwardOutcome::NotForwarded(_req) => {}
};
service
.maybe_delete_timeline_import(tenant_id, timeline_id)
.await?;
// For timeline deletions, which both implement an "initially return 202, then 404 once
// we're done" semantic, we wrap with a retry loop to expose a simpler API upstream.
async fn deletion_wrapper<R, F>(service: Arc<Service>, f: F) -> Result<Response<Body>, ApiError>

View File

@@ -99,8 +99,8 @@ use crate::tenant_shard::{
ScheduleOptimization, ScheduleOptimizationAction, TenantShard,
};
use crate::timeline_import::{
FinalizingImport, ImportResult, ShardImportStatuses, TimelineImport,
TimelineImportFinalizeError, TimelineImportState, UpcallClient,
ImportResult, ShardImportStatuses, TimelineImport, TimelineImportFinalizeError,
TimelineImportState, UpcallClient,
};
const WAITER_FILL_DRAIN_POLL_TIMEOUT: Duration = Duration::from_millis(500);
@@ -232,9 +232,6 @@ struct ServiceState {
/// Queue of tenants who are waiting for concurrency limits to permit them to reconcile
delayed_reconcile_rx: tokio::sync::mpsc::Receiver<TenantShardId>,
/// Tracks ongoing timeline import finalization tasks
imports_finalizing: BTreeMap<(TenantId, TimelineId), FinalizingImport>,
}
/// Transform an error from a pageserver into an error to return to callers of a storage
@@ -311,7 +308,6 @@ impl ServiceState {
scheduler,
ongoing_operation: None,
delayed_reconcile_rx,
imports_finalizing: Default::default(),
}
}
@@ -3827,13 +3823,6 @@ impl Service {
.await;
failpoint_support::sleep_millis_async!("tenant-create-timeline-shared-lock");
let is_import = create_req.is_import();
let read_only = matches!(
create_req.mode,
models::TimelineCreateRequestMode::Branch {
read_only: true,
..
}
);
if is_import {
// Ensure that there is no split on-going.
@@ -3906,13 +3895,13 @@ impl Service {
}
None
} else if safekeepers || read_only {
} else if safekeepers {
// Note that for imported timelines, we do not create the timeline on the safekeepers
// straight away. Instead, we do it once the import finalized such that we know what
// start LSN to provide for the safekeepers. This is done in
// [`Self::finalize_timeline_import`].
let res = self
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info, read_only)
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info)
.instrument(tracing::info_span!("timeline_create_safekeepers", %tenant_id, timeline_id=%timeline_info.timeline_id))
.await?;
Some(res)
@@ -3926,11 +3915,6 @@ impl Service {
})
}
#[instrument(skip_all, fields(
tenant_id=%req.tenant_shard_id.tenant_id,
shard_id=%req.tenant_shard_id.shard_slug(),
timeline_id=%req.timeline_id,
))]
pub(crate) async fn handle_timeline_shard_import_progress(
self: &Arc<Self>,
req: TimelineImportStatusRequest,
@@ -3980,11 +3964,6 @@ impl Service {
})
}
#[instrument(skip_all, fields(
tenant_id=%req.tenant_shard_id.tenant_id,
shard_id=%req.tenant_shard_id.shard_slug(),
timeline_id=%req.timeline_id,
))]
pub(crate) async fn handle_timeline_shard_import_progress_upcall(
self: &Arc<Self>,
req: PutTimelineImportStatusRequest,
@@ -4101,58 +4080,13 @@ impl Service {
///
/// If this method gets pre-empted by shut down, it will be called again at start-up (on-going
/// imports are stored in the database).
///
/// # Cancel-Safety
/// Not cancel safe.
/// If the caller stops polling, the import will not be removed from
/// [`ServiceState::imports_finalizing`].
#[instrument(skip_all, fields(
tenant_id=%import.tenant_id,
timeline_id=%import.timeline_id,
))]
async fn finalize_timeline_import(
self: &Arc<Self>,
import: TimelineImport,
) -> Result<(), TimelineImportFinalizeError> {
let tenant_timeline = (import.tenant_id, import.timeline_id);
let (_finalize_import_guard, cancel) = {
let mut locked = self.inner.write().unwrap();
let gate = Gate::default();
let cancel = CancellationToken::default();
let guard = gate.enter().unwrap();
locked.imports_finalizing.insert(
tenant_timeline,
FinalizingImport {
gate,
cancel: cancel.clone(),
},
);
(guard, cancel)
};
let res = tokio::select! {
res = self.finalize_timeline_import_impl(import) => {
res
},
_ = cancel.cancelled() => {
Err(TimelineImportFinalizeError::Cancelled)
}
};
let mut locked = self.inner.write().unwrap();
locked.imports_finalizing.remove(&tenant_timeline);
res
}
async fn finalize_timeline_import_impl(
self: &Arc<Self>,
import: TimelineImport,
) -> Result<(), TimelineImportFinalizeError> {
tracing::info!("Finalizing timeline import");
@@ -4352,46 +4286,6 @@ impl Service {
.await;
}
/// Delete a timeline import if it exists
///
/// Firstly, delete the entry from the database. Any updates
/// from pageservers after the update will fail with a 404, so the
/// import cannot progress into finalizing state if it's not there already.
/// Secondly, cancel the finalization if one is in progress.
pub(crate) async fn maybe_delete_timeline_import(
self: &Arc<Self>,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<(), DatabaseError> {
let tenant_has_ongoing_import = {
let locked = self.inner.read().unwrap();
locked
.tenants
.range(TenantShardId::tenant_range(tenant_id))
.any(|(_tid, shard)| shard.importing == TimelineImportState::Importing)
};
if !tenant_has_ongoing_import {
return Ok(());
}
self.persistence
.delete_timeline_import(tenant_id, timeline_id)
.await?;
let maybe_finalizing = {
let mut locked = self.inner.write().unwrap();
locked.imports_finalizing.remove(&(tenant_id, timeline_id))
};
if let Some(finalizing) = maybe_finalizing {
finalizing.cancel.cancel();
finalizing.gate.close().await;
}
Ok(())
}
pub(crate) async fn tenant_timeline_archival_config(
&self,
tenant_id: TenantId,
@@ -8627,9 +8521,8 @@ impl Service {
Some(ShardCount(new_shard_count))
}
/// Fetches the top tenant shards from every available node, in descending order of
/// max logical size. Offline nodes are skipped, and any errors from available nodes
/// will be logged and ignored.
/// Fetches the top tenant shards from every node, in descending order of
/// max logical size. Any node errors will be logged and ignored.
async fn get_top_tenant_shards(
&self,
request: &TopTenantShardsRequest,
@@ -8640,7 +8533,6 @@ impl Service {
.unwrap()
.nodes
.values()
.filter(|node| node.is_available())
.cloned()
.collect_vec();

View File

@@ -208,7 +208,6 @@ impl Service {
self: &Arc<Self>,
tenant_id: TenantId,
timeline_info: &TimelineInfo,
read_only: bool,
) -> Result<SafekeepersInfo, ApiError> {
let timeline_id = timeline_info.timeline_id;
let pg_version = timeline_info.pg_version * 10000;
@@ -221,11 +220,7 @@ impl Service {
let start_lsn = timeline_info.last_record_lsn;
// Choose initial set of safekeepers respecting affinity
let sks = if !read_only {
self.safekeepers_for_new_timeline().await?
} else {
Vec::new()
};
let sks = self.safekeepers_for_new_timeline().await?;
let sks_persistence = sks.iter().map(|sk| sk.id.0 as i64).collect::<Vec<_>>();
// Add timeline to db
let mut timeline_persist = TimelinePersistence {
@@ -258,16 +253,6 @@ impl Service {
)));
}
}
let ret = SafekeepersInfo {
generation: timeline_persist.generation as u32,
safekeepers: sks.clone(),
tenant_id,
timeline_id,
};
if read_only {
return Ok(ret);
}
// Create the timeline on a quorum of safekeepers
let remaining = self
.tenant_timeline_create_safekeepers_quorum(
@@ -331,7 +316,12 @@ impl Service {
}
}
Ok(ret)
Ok(SafekeepersInfo {
generation: timeline_persist.generation as u32,
safekeepers: sks,
tenant_id,
timeline_id,
})
}
pub(crate) async fn tenant_timeline_create_safekeepers_until_success(
@@ -346,10 +336,8 @@ impl Service {
return Err(TimelineImportFinalizeError::ShuttingDown);
}
// This function is only used in non-read-only scenarios
let read_only = false;
let res = self
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info, read_only)
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info)
.await;
match res {
@@ -422,18 +410,6 @@ impl Service {
.chain(tl.sk_set.iter())
.collect::<HashSet<_>>();
// The timeline has no safekeepers: we need to delete it from the db manually,
// as no safekeeper reconciler will get to it
if all_sks.is_empty() {
if let Err(err) = self
.persistence
.delete_timeline(tenant_id, timeline_id)
.await
{
tracing::warn!(%tenant_id, %timeline_id, "couldn't delete timeline from db: {err}");
}
}
// Schedule reconciliations
for &sk_id in all_sks.iter() {
let pending_op = TimelinePendingOpPersistence {

View File

@@ -7,7 +7,6 @@ use serde::{Deserialize, Serialize};
use pageserver_api::models::{ShardImportProgress, ShardImportStatus};
use tokio_util::sync::CancellationToken;
use utils::sync::gate::Gate;
use utils::{
id::{TenantId, TimelineId},
shard::ShardIndex,
@@ -56,8 +55,6 @@ pub(crate) enum TimelineImportUpdateFollowUp {
pub(crate) enum TimelineImportFinalizeError {
#[error("Shut down interrupted import finalize")]
ShuttingDown,
#[error("Import finalization was cancelled")]
Cancelled,
#[error("Mismatched shard detected during import finalize: {0}")]
MismatchedShards(ShardIndex),
}
@@ -167,11 +164,6 @@ impl TimelineImport {
}
}
pub(crate) struct FinalizingImport {
pub(crate) gate: Gate,
pub(crate) cancel: CancellationToken,
}
pub(crate) type ImportResult = Result<(), String>;
pub(crate) struct UpcallClient {

View File

@@ -1,4 +1,3 @@
import json
import os
import shutil
import subprocess
@@ -12,7 +11,6 @@ from _pytest.config import Config
from fixtures.log_helper import log
from fixtures.neon_cli import AbstractNeonCli
from fixtures.neon_fixtures import Endpoint, VanillaPostgres
from fixtures.pg_version import PgVersion
from fixtures.remote_storage import MockS3Server
@@ -163,57 +161,3 @@ def fast_import(
f.write(fi.cmd.stderr)
log.info("Written logs to %s", test_output_dir)
def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path):
    """
    Lay out `path` as a local stand-in for the import S3 bucket of `vanilla_pg`.

    The instance must be stopped: its data directory is moved (not copied) to
    `path / "pgdata"`. `path` itself is created here, so it must not exist yet.
    """
    assert not vanilla_pg.is_running()

    path.mkdir()

    # Written by cplane before it schedules fast_import.
    spec = {"branch_id": "somebranch", "project_id": "someproject"}
    (path / "spec.json").write_text(json.dumps(spec))

    # Everything below is what fast_import itself produces.
    vanilla_pg.pgdatadir.rename(path / "pgdata")

    status_root = path / "status"
    status_root.mkdir()
    status_files = {
        "pgdata": {"done": True},
        "fast_import": {"command": "pgdata", "done": True},
    }
    for file_name, payload in status_files.items():
        (status_root / file_name).write_text(json.dumps(payload))
def populate_vanilla_pg(vanilla_pg: VanillaPostgres, target_relblock_size: int) -> int:
    """
    Fill table `t` in a running vanilla PG instance until its relation size is at
    least `target_relblock_size` bytes.

    Creates the `cloud_admin` superuser and the table `t`, then inserts the
    integers 1..N in batches. Returns N, the total number of rows inserted, so the
    caller can later validate the data from the count alone.
    """
    page_size = 8192

    assert vanilla_pg.is_running()

    vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
    # fillfactor so we don't need to produce that much data
    # 900 byte per row is > 10% => 1 row per page
    vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")

    nrows = 0
    while True:
        relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
        log.info(
            f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages"
        )
        if relblock_size >= target_relblock_size:
            break
        # Round the missing bytes UP to whole pages. With floor division a target
        # that is not a multiple of the page size leaves a remainder smaller than
        # one page, which yields 0 rows and trips the forward-progress assertion.
        missing_bytes = target_relblock_size - relblock_size
        addrows = int((missing_bytes + page_size - 1) // page_size)
        assert addrows >= 1, "forward progress"
        vanilla_pg.safe_psql(
            f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
        )
        nrows += addrows
    return nrows
def validate_import_from_vanilla_pg(endpoint: Endpoint, nrows: int):
    """
    Assert that table `t` on `endpoint` has `nrows` rows whose values sum to
    1 + 2 + ... + `nrows`.

    The first two statements only tune the session (IO concurrency and statement
    timeout) and are expected to return no rows.
    """
    statements = [
        "set effective_io_concurrency=32;",
        "SET statement_timeout='300s';",
        "select count(*), sum(data::bigint)::bigint from t",
    ]
    # Closed form for the sum of the integers 1..nrows.
    expected_sum = nrows * (nrows + 1) // 2
    expected = [[], [], [(nrows, expected_sum)]]

    results = endpoint.safe_psql_many(statements)
    assert results == expected

View File

@@ -404,29 +404,6 @@ class PageserverTracingConfig:
return ("tracing", value)
@dataclass
class PageserverImportConfig:
    """
    Values for the pageserver's `timeline_import_config` configuration section.
    """

    # NOTE(review): units and exact semantics of these knobs are defined by the
    # pageserver and are not visible here.
    import_job_concurrency: int
    import_job_soft_size_limit: int
    import_job_checkpoint_threshold: int

    @staticmethod
    def default() -> PageserverImportConfig:
        """Build the configuration with the stock values."""
        stock_values = {
            "import_job_concurrency": 4,
            "import_job_soft_size_limit": 512 * 1024,
            "import_job_checkpoint_threshold": 4,
        }
        return PageserverImportConfig(**stock_values)

    def to_config_key_value(self) -> tuple[str, dict[str, Any]]:
        """Return the `(key, value)` pair to store in a pageserver config dict."""
        field_names = (
            "import_job_concurrency",
            "import_job_soft_size_limit",
            "import_job_checkpoint_threshold",
        )
        section = {field: getattr(self, field) for field in field_names}
        return "timeline_import_config", section
class NeonEnvBuilder:
"""
Builder object to create a Neon runtime environment
@@ -477,7 +454,6 @@ class NeonEnvBuilder:
pageserver_wal_receiver_protocol: PageserverWalReceiverProtocol | None = None,
pageserver_get_vectored_concurrent_io: str | None = None,
pageserver_tracing_config: PageserverTracingConfig | None = None,
pageserver_import_config: PageserverImportConfig | None = None,
):
self.repo_dir = repo_dir
self.rust_log_override = rust_log_override
@@ -535,7 +511,6 @@ class NeonEnvBuilder:
)
self.pageserver_tracing_config = pageserver_tracing_config
self.pageserver_import_config = pageserver_import_config
self.pageserver_default_tenant_config_compaction_algorithm: dict[str, Any] | None = (
pageserver_default_tenant_config_compaction_algorithm
@@ -1204,10 +1179,6 @@ class NeonEnv:
self.pageserver_wal_receiver_protocol = config.pageserver_wal_receiver_protocol
self.pageserver_get_vectored_concurrent_io = config.pageserver_get_vectored_concurrent_io
self.pageserver_tracing_config = config.pageserver_tracing_config
if config.pageserver_import_config is None:
self.pageserver_import_config = PageserverImportConfig.default()
else:
self.pageserver_import_config = config.pageserver_import_config
# Create the neon_local's `NeonLocalInitConf`
cfg: dict[str, Any] = {
@@ -1287,6 +1258,12 @@ class NeonEnv:
"no_sync": True,
# Look for gaps in WAL received from safekeepers
"validate_wal_contiguity": True,
# TODO(vlad): make these configurable through the builder
"timeline_import_config": {
"import_job_concurrency": 4,
"import_job_soft_size_limit": 512 * 1024,
"import_job_checkpoint_threshold": 4,
},
}
# Batching (https://github.com/neondatabase/neon/issues/9377):
@@ -1348,12 +1325,6 @@ class NeonEnv:
ps_cfg[key] = value
if self.pageserver_import_config is not None:
key, value = self.pageserver_import_config.to_config_key_value()
if key not in ps_cfg:
ps_cfg[key] = value
# Create a corresponding NeonPageserver object
ps = NeonPageserver(
self, ps_id, port=pageserver_port, az_id=ps_cfg["availability_zone"]
@@ -2337,22 +2308,6 @@ class NeonStorageController(MetricsGetter, LogUtils):
headers=self.headers(TokenScope.ADMIN),
)
def import_status(
    self, tenant_shard_id: TenantShardId, timeline_id: TimelineId, generation: int
):
    """
    Issue the `timeline_import_status` upcall for one shard's timeline.

    The request is authorised with a GENERATIONS_API-scoped token and carries the
    shard id, timeline id and generation as a JSON body. The response is
    discarded; nothing is returned.
    """
    self.request(
        "GET",
        f"{self.api}/upcall/v1/timeline_import_status",
        headers=self.headers(TokenScope.GENERATIONS_API),
        json={
            "tenant_shard_id": str(tenant_shard_id),
            "timeline_id": str(timeline_id),
            "generation": generation,
        },
    )
def reconcile_all(self):
r = self.request(
"POST",
@@ -2829,11 +2784,6 @@ class NeonPageserver(PgProtocol, LogUtils):
if self.running:
self.http_client().configure_failpoints([(name, action)])
def clear_persistent_failpoint(self, name: str):
    """
    Forget the persistent failpoint `name` and, if the pageserver is currently
    running, switch that failpoint off on it as well.

    Raises `KeyError` if `name` is not a registered persistent failpoint.
    """
    self._persistent_failpoints.pop(name)
    if not self.running:
        return
    self.http_client().configure_failpoints([(name, "off")])
def timeline_dir(
self,
tenant_shard_id: TenantId | TenantShardId,

View File

@@ -675,7 +675,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
def timeline_delete(
self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, **kwargs
) -> int:
):
"""
Note that deletion is not instant, it is scheduled and performed mostly in the background.
So if you need to wait for it to complete use `timeline_delete_wait_completed`.
@@ -688,8 +688,6 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
res_json = res.json()
assert res_json is None
return res.status_code
def timeline_gc(
self,
tenant_id: TenantId | TenantShardId,

View File

@@ -1,41 +1,31 @@
from __future__ import annotations
import enum
import json
import time
from collections import Counter
from dataclasses import dataclass
from enum import StrEnum
from threading import Event
from typing import TYPE_CHECKING
import pytest
from fixtures.common_types import Lsn, TenantId, TimelineId
from fixtures.fast_import import mock_import_bucket, populate_vanilla_pg
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnv,
NeonEnvBuilder,
NeonPageserver,
PgBin,
VanillaPostgres,
wait_for_last_flush_lsn,
)
from fixtures.pageserver.http import (
ImportPgdataIdemptencyKey,
)
from fixtures.pageserver.utils import wait_for_upload_queue_empty
from fixtures.remote_storage import RemoteStorageKind
from fixtures.utils import human_bytes, run_only_on_default_postgres, wait_until
from werkzeug.wrappers.response import Response
from fixtures.utils import human_bytes, wait_until
if TYPE_CHECKING:
from collections.abc import Iterable
from typing import Any
from fixtures.pageserver.http import PageserverHttpClient
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
GLOBAL_LRU_LOG_LINE = "tenant_min_resident_size-respecting LRU would not relieve pressure, evicting more following global LRU policy"
@@ -174,7 +164,6 @@ class EvictionEnv:
min_avail_bytes,
mock_behavior,
eviction_order: EvictionOrder,
wait_logical_size: bool = True,
):
"""
Starts pageserver up with mocked statvfs setup. The startup is
@@ -212,12 +201,11 @@ class EvictionEnv:
pageserver.start()
# we now do initial logical size calculation on startup, which on debug builds can fight with disk usage based eviction
if wait_logical_size:
for tenant_id, timeline_id in self.timelines:
tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id)
# Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test
if tenant_ps is not None:
tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id)
for tenant_id, timeline_id in self.timelines:
tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id)
# Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test
if tenant_ps is not None:
tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id)
def statvfs_called():
pageserver.assert_log_contains(".*running mocked statvfs.*")
@@ -894,121 +882,3 @@ def test_secondary_mode_eviction(eviction_env_ha: EvictionEnv):
assert total_size - post_eviction_total_size >= evict_bytes, (
"we requested at least evict_bytes worth of free space"
)
@run_only_on_default_postgres(reason="PG version is irrelevant here")
def test_import_timeline_disk_pressure_eviction(
    neon_env_builder: NeonEnvBuilder,
    vanilla_pg: VanillaPostgres,
    make_httpserver: HTTPServer,
    pg_bin: PgBin,
):
    """
    Check that disk-usage-based eviction can relieve pressure while a timeline
    import is still in progress, and that the import completes afterwards.

    Flow:
    1. Populate a vanilla PG and start an import of its data directory.
    2. Hold the import at a failpoint just before it reports success.
    3. Restart the pageserver with a mocked statvfs that reports disk pressure
       and wait for the "disk usage pressure relieved" log line.
    4. Release the failpoint and wait for the control plane mock to be notified.
    """
    # Set up mock control plane HTTP server to listen for import completions
    import_completion_signaled = Event()

    def handler(request: Request) -> Response:
        log.info(f"control plane /import_complete request: {request.json}")
        import_completion_signaled.set()
        return Response(json.dumps({}), status=200)

    cplane_mgmt_api_server = make_httpserver
    cplane_mgmt_api_server.expect_request(
        "/storage/api/v1/import_complete", method="PUT"
    ).respond_with_handler(handler)

    # Plug the cplane mock in
    neon_env_builder.control_plane_hooks_api = (
        f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/"
    )

    # The import will specify a local filesystem path mocking remote storage
    neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)

    vanilla_pg.start()
    target_relblock_size = 1024 * 1024 * 128
    populate_vanilla_pg(vanilla_pg, target_relblock_size)
    vanilla_pg.stop()

    env = neon_env_builder.init_configs()
    env.start()

    importbucket_path = neon_env_builder.repo_dir / "test_import_completion_bucket"
    mock_import_bucket(vanilla_pg, importbucket_path)

    tenant_id = TenantId.generate()
    timeline_id = TimelineId.generate()
    idempotency = ImportPgdataIdemptencyKey.random()

    eviction_env = EvictionEnv(
        timelines=[(tenant_id, timeline_id)],
        neon_env=env,
        pageserver_http=env.pageserver.http_client(),
        layer_size=5 * 1024 * 1024,  # Doesn't apply here
        pg_bin=pg_bin,  # Not used here
        pgbench_init_lsns={},  # Not used here
    )

    # Pause before delivering the final notification to storcon.
    # This keeps the import in progress.
    failpoint_name = "import-timeline-pre-success-notify-pausable"
    env.pageserver.add_persistent_failpoint(failpoint_name, "pause")

    env.storage_controller.tenant_create(tenant_id)
    env.storage_controller.timeline_create(
        tenant_id,
        {
            "new_timeline_id": str(timeline_id),
            "import_pgdata": {
                "idempotency_key": str(idempotency),
                "location": {"LocalFs": {"path": str(importbucket_path.absolute())}},
            },
        },
    )

    def hit_failpoint():
        log.info("Checking log for pattern...")
        try:
            assert env.pageserver.log_contains(f".*at failpoint {failpoint_name}.*")
        except Exception:
            log.exception("Failed to find pattern in log")
            raise

    # The import data is on disk once the failpoint is reached, but completion
    # must not have been reported yet.
    wait_until(hit_failpoint)
    assert not import_completion_signaled.is_set()

    env.pageserver.stop()

    # Size the mocked filesystem from the current on-disk usage of the timelines,
    # rounded up to whole blocks.
    total_size, _, _ = eviction_env.timelines_du(env.pageserver)
    blocksize = 512
    total_blocks = (total_size + (blocksize - 1)) // blocksize

    # wait_logical_size=False: the helper then skips its wait for the timelines'
    # logical size after starting the pageserver.
    # NOTE(review): presumably because the timeline is still mid-import — confirm.
    eviction_env.pageserver_start_with_disk_usage_eviction(
        env.pageserver,
        period="1s",
        max_usage_pct=33,
        min_avail_bytes=0,
        mock_behavior={
            "type": "Success",
            "blocksize": blocksize,
            "total_blocks": total_blocks,
            # Only count layer files towards used bytes in the mock_statvfs.
            # This avoids accounting for metadata files & tenant conf in the tests.
            "name_filter": ".*__.*",
        },
        eviction_order=EvictionOrder.RELATIVE_ORDER_SPARE,
        wait_logical_size=False,
    )

    wait_until(lambda: env.pageserver.assert_log_contains(".*disk usage pressure relieved"))

    # Let the paused import proceed to its success notification.
    env.pageserver.clear_persistent_failpoint(failpoint_name)

    def cplane_notified():
        assert import_completion_signaled.is_set()

    wait_until(cplane_notified)

    env.pageserver.allowed_errors.append(r".* running disk usage based eviction due to pressure.*")

View File

@@ -1,10 +1,7 @@
import base64
import concurrent.futures
import json
import random
import threading
import time
from enum import Enum, StrEnum
from enum import Enum
from pathlib import Path
from threading import Event
@@ -12,22 +9,9 @@ import psycopg2
import psycopg2.errors
import pytest
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
from fixtures.fast_import import (
FastImport,
mock_import_bucket,
populate_vanilla_pg,
validate_import_from_vanilla_pg,
)
from fixtures.fast_import import FastImport
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnvBuilder,
PageserverImportConfig,
PgBin,
PgProtocol,
StorageControllerApiException,
StorageControllerMigrationConfig,
VanillaPostgres,
)
from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, PgProtocol, VanillaPostgres
from fixtures.pageserver.http import (
ImportPgdataIdemptencyKey,
)
@@ -65,6 +49,24 @@ smoke_params = [
]
def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path):
    """
    Lay out `path` as a local stand-in for the import S3 bucket of `vanilla_pg`.

    The instance must be stopped: its data directory is moved (not copied) to
    `path / "pgdata"`. `path` itself is created here, so it must not exist yet.
    """
    assert not vanilla_pg.is_running()

    path.mkdir()

    # Written by cplane before it schedules fast_import.
    spec = {"branch_id": "somebranch", "project_id": "someproject"}
    (path / "spec.json").write_text(json.dumps(spec))

    # Everything below is what fast_import itself produces.
    vanilla_pg.pgdatadir.rename(path / "pgdata")

    status_root = path / "status"
    status_root.mkdir()
    status_files = {
        "pgdata": {"done": True},
        "fast_import": {"command": "pgdata", "done": True},
    }
    for file_name, payload in status_files.items():
        (status_root / file_name).write_text(json.dumps(payload))
@skip_in_debug_build("MULTIPLE_RELATION_SEGMENTS has non trivial amount of data")
@pytest.mark.parametrize("shard_count,stripe_size,rel_block_size", smoke_params)
def test_pgdata_import_smoke(
@@ -119,6 +121,10 @@ def test_pgdata_import_smoke(
# Put data in vanilla pg
#
vanilla_pg.start()
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
log.info("create relblock data")
if rel_block_size == RelBlockSize.ONE_STRIPE_SIZE:
target_relblock_size = stripe_size * 8192
elif rel_block_size == RelBlockSize.TWO_STRPES_PER_SHARD:
@@ -129,8 +135,45 @@ def test_pgdata_import_smoke(
else:
raise ValueError
vanilla_pg.start()
rows_inserted = populate_vanilla_pg(vanilla_pg, target_relblock_size)
# fillfactor so we don't need to produce that much data
# 900 byte per row is > 10% => 1 row per page
vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")
nrows = 0
while True:
relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
log.info(
f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages"
)
if relblock_size >= target_relblock_size:
break
addrows = int((target_relblock_size - relblock_size) // 8192)
assert addrows >= 1, "forward progress"
vanilla_pg.safe_psql(
f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
)
nrows += addrows
expect_nrows = nrows
expect_sum = (
(nrows) * (nrows + 1) // 2
) # https://stackoverflow.com/questions/43901484/sum-of-the-integers-from-1-to-n
def validate_vanilla_equivalence(ep):
# TODO: would be nicer to just compare pgdump
# Enable IO concurrency for batching on large sequential scan, to avoid making
# this test unnecessarily onerous on CPU. Especially on debug mode, it's still
# pretty onerous though, so increase statement_timeout to avoid timeouts.
assert ep.safe_psql_many(
[
"set effective_io_concurrency=32;",
"SET statement_timeout='300s';",
"select count(*), sum(data::bigint)::bigint from t",
]
) == [[], [], [(expect_nrows, expect_sum)]]
validate_vanilla_equivalence(vanilla_pg)
vanilla_pg.stop()
#
@@ -221,14 +264,14 @@ def test_pgdata_import_smoke(
config_lines=ep_config,
)
validate_import_from_vanilla_pg(ro_endpoint, rows_inserted)
validate_vanilla_equivalence(ro_endpoint)
# ensure the import survives restarts
ro_endpoint.stop()
env.pageserver.stop(immediate=True)
env.pageserver.start()
ro_endpoint.start()
validate_import_from_vanilla_pg(ro_endpoint, rows_inserted)
validate_vanilla_equivalence(ro_endpoint)
#
# validate the layer files in each shard only have the shard-specific data
@@ -268,7 +311,7 @@ def test_pgdata_import_smoke(
child_workload = workload.branch(timeline_id=child_timeline_id, branch_name="br-tip")
child_workload.validate()
validate_import_from_vanilla_pg(child_workload.endpoint(), rows_inserted)
validate_vanilla_equivalence(child_workload.endpoint())
# ... at the initdb lsn
_ = env.create_branch(
@@ -283,7 +326,7 @@ def test_pgdata_import_smoke(
tenant_id=tenant_id,
config_lines=ep_config,
)
validate_import_from_vanilla_pg(br_initdb_endpoint, rows_inserted)
validate_vanilla_equivalence(br_initdb_endpoint)
with pytest.raises(psycopg2.errors.UndefinedTable):
br_initdb_endpoint.safe_psql(f"select * from {workload.table}")
@@ -370,12 +413,8 @@ def test_import_completion_on_restart(
@run_only_on_default_postgres(reason="PG version is irrelevant here")
@pytest.mark.parametrize("action", ["restart", "delete"])
def test_import_respects_timeline_lifecycle(
neon_env_builder: NeonEnvBuilder,
vanilla_pg: VanillaPostgres,
make_httpserver: HTTPServer,
action: str,
def test_import_respects_tenant_shutdown(
neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer
):
"""
Validate that importing timelines respect the usual timeline life cycle:
@@ -443,265 +482,16 @@ def test_import_respects_timeline_lifecycle(
wait_until(hit_failpoint)
assert not import_completion_signaled.is_set()
if action == "restart":
# Restart the pageserver while an import job is in progress.
# This clears the failpoint and we expect that the import starts up afresh
# after the restart and eventually completes.
env.pageserver.stop()
env.pageserver.start()
# Restart the pageserver while an import job is in progress.
# This clears the failpoint and we expect that the import starts up afresh
# after the restart and eventually completes.
env.pageserver.stop()
env.pageserver.start()
def cplane_notified():
assert import_completion_signaled.is_set()
def cplane_notified():
assert import_completion_signaled.is_set()
wait_until(cplane_notified)
elif action == "delete":
status = env.storage_controller.pageserver_api().timeline_delete(tenant_id, timeline_id)
assert status == 200
timeline_path = env.pageserver.timeline_dir(tenant_id, timeline_id)
assert not timeline_path.exists(), "Timeline dir exists after deletion"
shard_zero = TenantShardId(tenant_id, 0, 0)
location = env.storage_controller.inspect(shard_zero)
assert location is not None
generation = location[0]
with pytest.raises(StorageControllerApiException, match="not found"):
env.storage_controller.import_status(shard_zero, timeline_id, generation)
else:
raise RuntimeError(f"{action} param not recognized")
@skip_in_debug_build("Validation query takes too long in debug builds")
def test_import_chaos(
    neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer
):
    """
    Perform a timeline import while injecting chaos in the environment.
    We expect that the import completes eventually, that it passes validation and
    the resulting timeline can be written to.

    Chaos is injected from a background thread roughly once per second and
    consists of shard migrations, graceful and immediate pageserver restarts,
    and immediate storage controller restarts.
    """
    TARGET_RELBOCK_SIZE = 512 * 1024 * 1024  # 512 MiB
    ALLOWED_IMPORT_RUNTIME = 90  # seconds
    SHARD_COUNT = 4

    # One pageserver per shard.
    neon_env_builder.num_pageservers = SHARD_COUNT
    neon_env_builder.pageserver_import_config = PageserverImportConfig(
        import_job_concurrency=1,
        import_job_soft_size_limit=64 * 1024,
        import_job_checkpoint_threshold=4,
    )

    # Set up mock control plane HTTP server to listen for import completions
    import_completion_signaled = Event()
    # The handler below is a nested function: it can mutate this list in place,
    # whereas rebinding a plain optional variable from inside it would require
    # `nonlocal`. Hence errors are collected in a list.
    import_error = []

    def handler(request: Request) -> Response:
        assert request.json is not None

        body = request.json
        if "error" in body:
            if body["error"]:
                import_error.append(body["error"])

        log.info(f"control plane /import_complete request: {request.json}")
        import_completion_signaled.set()
        return Response(json.dumps({}), status=200)

    cplane_mgmt_api_server = make_httpserver
    cplane_mgmt_api_server.expect_request(
        "/storage/api/v1/import_complete", method="PUT"
    ).respond_with_handler(handler)

    # Plug the cplane mock in
    neon_env_builder.control_plane_hooks_api = (
        f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/"
    )

    # The import will specify a local filesystem path mocking remote storage
    neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)

    vanilla_pg.start()

    inserted_rows = populate_vanilla_pg(vanilla_pg, TARGET_RELBOCK_SIZE)

    vanilla_pg.stop()

    env = neon_env_builder.init_configs()
    env.start()

    # Pause after every import task to extend the test runtime and allow
    # for more chaos injection.
    for ps in env.pageservers:
        ps.add_persistent_failpoint("import-task-complete-pausable", "sleep(5)")

    env.storage_controller.allowed_errors.extend(
        [
            # The shard might have moved or the pageserver hosting the shard restarted
            ".*Call to node.*management API.*failed.*",
            # Migrations have their time outs set to 0
            ".*Timed out after.*downloading layers.*",
            ".*Failed to prepare by downloading layers.*",
            # The test may kill the storage controller or pageservers
            ".*request was dropped before completing.*",
        ]
    )
    for ps in env.pageservers:
        ps.allowed_errors.extend(
            [
                # We might re-write a layer in a different generation if the import
                # needs to redo some of the progress since not each job is checkpointed.
                ".*was unlinked but was not dangling.*",
                # The test may kill the storage controller or pageservers
                ".*request was dropped before completing.*",
                # Test can SIGTERM pageserver while it is downloading
                ".*removing local file.*temp_download.*",
                ".*Failed to flush heatmap.*",
                # Test can SIGTERM the storage controller while pageserver
                # is attempting to upcall.
                ".*storage controller upcall failed.*timeline_import_status.*",
                # TODO(vlad): TenantManager::reset_tenant returns a blanked anyhow error.
                # It should return ResourceUnavailable or something that doesn't error log.
                ".*activate_post_import.*InternalServerError.*tenant map is shutting down.*",
                # TODO(vlad): How can this happen?
                ".*Failed to download a remote file: deserialize index part file.*",
                ".*Cancelled request finished with an error.*",
            ]
        )

    importbucket_path = neon_env_builder.repo_dir / "test_import_chaos_bucket"
    mock_import_bucket(vanilla_pg, importbucket_path)

    tenant_id = TenantId.generate()
    timeline_id = TimelineId.generate()
    idempotency = ImportPgdataIdemptencyKey.random()

    env.storage_controller.tenant_create(
        tenant_id, shard_count=SHARD_COUNT, placement_policy={"Attached": 1}
    )
    env.storage_controller.reconcile_until_idle()

    env.storage_controller.timeline_create(
        tenant_id,
        {
            "new_timeline_id": str(timeline_id),
            "import_pgdata": {
                "idempotency_key": str(idempotency),
                "location": {"LocalFs": {"path": str(importbucket_path.absolute())}},
            },
        },
    )

    def chaos(stop_chaos: threading.Event):
        # Runs on a worker thread until `stop_chaos` is set; each iteration picks
        # one disruption uniformly at random.
        class ChaosType(StrEnum):
            MIGRATE_SHARD = "migrate_shard"
            RESTART_IMMEDIATE = "restart_immediate"
            RESTART = "restart"
            STORCON_RESTART_IMMEDIATE = "storcon_restart_immediate"

        while not stop_chaos.is_set():
            chaos_type = random.choices(
                population=[
                    ChaosType.MIGRATE_SHARD,
                    ChaosType.RESTART,
                    ChaosType.RESTART_IMMEDIATE,
                    ChaosType.STORCON_RESTART_IMMEDIATE,
                ],
                weights=[0.25, 0.25, 0.25, 0.25],
                k=1,
            )[0]

            try:
                if chaos_type == ChaosType.MIGRATE_SHARD:
                    target_shard_number = random.randint(0, SHARD_COUNT - 1)
                    target_shard = TenantShardId(tenant_id, target_shard_number, SHARD_COUNT)

                    placements = env.storage_controller.get_tenants_placement()
                    log.info(f"{placements=}")

                    # Migrate from the attached pageserver to the first secondary,
                    # if the shard currently has both.
                    target_ps = placements[str(target_shard)]["intent"]["attached"]
                    if len(placements[str(target_shard)]["intent"]["secondary"]) == 0:
                        dest_ps = None
                    else:
                        dest_ps = placements[str(target_shard)]["intent"]["secondary"][0]

                    if target_ps is None or dest_ps is None:
                        continue

                    config = StorageControllerMigrationConfig(
                        secondary_warmup_timeout="0s",
                        secondary_download_request_timeout="0s",
                        prewarm=False,
                    )

                    env.storage_controller.tenant_shard_migrate(target_shard, dest_ps, config)

                    # Note: this is logged after the migrate call has returned.
                    log.info(
                        f"CHAOS: Migrating shard {target_shard} from pageserver {target_ps} to {dest_ps}"
                    )
                elif chaos_type == ChaosType.RESTART_IMMEDIATE:
                    target_ps = random.choice(env.pageservers)
                    log.info(f"CHAOS: Immediate restart of pageserver {target_ps.id}")
                    target_ps.stop(immediate=True)
                    target_ps.start()
                elif chaos_type == ChaosType.RESTART:
                    target_ps = random.choice(env.pageservers)
                    log.info(f"CHAOS: Normal restart of pageserver {target_ps.id}")
                    target_ps.stop(immediate=False)
                    target_ps.start()
                elif chaos_type == ChaosType.STORCON_RESTART_IMMEDIATE:
                    log.info("CHAOS: Immediate restart of storage controller")
                    env.storage_controller.stop(immediate=True)
                    env.storage_controller.start()
            except Exception as e:
                # A failed chaos action is logged and otherwise ignored: only the
                # import outcome decides whether the test passes.
                log.warning(f"CHAOS: Error during chaos operation {chaos_type}: {e}")

            # Sleep before next chaos event
            time.sleep(1)

        log.info("Chaos injector stopped")

    def wait_for_import_completion():
        # Block until the control plane mock is notified, or fail after
        # ALLOWED_IMPORT_RUNTIME seconds.
        start = time.time()
        done = import_completion_signaled.wait(ALLOWED_IMPORT_RUNTIME)
        if not done:
            raise TimeoutError(f"Import did not signal completion within {ALLOWED_IMPORT_RUNTIME}")

        end = time.time()

        log.info(f"Import completion signalled after {end - start}s {import_error=}")

        if import_error:
            raise RuntimeError(f"Import error: {import_error}")

    with concurrent.futures.ThreadPoolExecutor() as executor:
        stop_chaos = threading.Event()

        wait_for_import_completion_fut = executor.submit(wait_for_import_completion)
        chaos_fut = executor.submit(chaos, stop_chaos)

        try:
            wait_for_import_completion_fut.result()
        except Exception as e:
            raise e
        finally:
            # Always stop and join the chaos thread, whether or not the wait failed.
            stop_chaos.set()
            chaos_fut.result()

    import_branch_name = "imported"
    env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id)
    endpoint = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id)

    # Validate the imported data is legit
    validate_import_from_vanilla_pg(endpoint, inserted_rows)

    endpoint.stop()

    # Validate writes
    workload = Workload(env, tenant_id, timeline_id, branch_name=import_branch_name)
    workload.init()
    workload.write_rows(64)
    workload.validate()
wait_until(cplane_notified)
def test_fast_import_with_pageserver_ingest(

View File

@@ -20,9 +20,6 @@ from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind
from fixtures.utils import query_scalar, wait_until
@pytest.mark.skip(
reason="We won't create future layers any more after https://github.com/neondatabase/neon/pull/10548"
)
@pytest.mark.parametrize(
"attach_mode",
["default_generation", "same_generation"],

View File

@@ -4158,12 +4158,17 @@ def test_storcon_create_delete_sk_down(
env.storage_controller.stop()
env.storage_controller.start()
with env.endpoints.create("main", tenant_id=tenant_id) as ep:
config_lines = [
"neon.safekeeper_proto_version = 3",
]
with env.endpoints.create("main", tenant_id=tenant_id, config_lines=config_lines) as ep:
# endpoint should start.
ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3])
ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)")
with env.endpoints.create("child_of_main", tenant_id=tenant_id) as ep:
with env.endpoints.create(
"child_of_main", tenant_id=tenant_id, config_lines=config_lines
) as ep:
# endpoint should start.
ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3])
ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)")
@@ -4192,10 +4197,10 @@ def test_storcon_create_delete_sk_down(
# ensure the safekeeper deleted the timeline
def timeline_deleted_on_active_sks():
env.safekeepers[0].assert_log_contains(
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
)
env.safekeepers[2].assert_log_contains(
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
)
wait_until(timeline_deleted_on_active_sks)
@@ -4210,7 +4215,7 @@ def test_storcon_create_delete_sk_down(
# ensure that there is log msgs for the third safekeeper too
def timeline_deleted_on_sk():
env.safekeepers[1].assert_log_contains(
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
)
wait_until(timeline_deleted_on_sk)
@@ -4244,12 +4249,17 @@ def test_storcon_few_sk(
env.safekeepers[0].assert_log_contains(f"creating new timeline {tenant_id}/{timeline_id}")
with env.endpoints.create("main", tenant_id=tenant_id) as ep:
config_lines = [
"neon.safekeeper_proto_version = 3",
]
with env.endpoints.create("main", tenant_id=tenant_id, config_lines=config_lines) as ep:
# endpoint should start.
ep.start(safekeeper_generation=1, safekeepers=safekeeper_list)
ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)")
with env.endpoints.create("child_of_main", tenant_id=tenant_id) as ep:
with env.endpoints.create(
"child_of_main", tenant_id=tenant_id, config_lines=config_lines
) as ep:
# endpoint should start.
ep.start(safekeeper_generation=1, safekeepers=safekeeper_list)
ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)")

View File

@@ -10,7 +10,6 @@ from queue import Empty, Queue
from threading import Barrier
import pytest
import requests
from fixtures.common_types import Lsn, TimelineArchivalState, TimelineId
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
@@ -402,25 +401,8 @@ def test_ancestor_detach_behavior_v2(neon_env_builder: NeonEnvBuilder, snapshots
"earlier", ancestor_branch_name="main", ancestor_start_lsn=branchpoint_pipe
)
snapshot_branchpoint_old = TimelineId.generate()
env.storage_controller.timeline_create(
env.initial_tenant,
{
"new_timeline_id": str(snapshot_branchpoint_old),
"ancestor_start_lsn": str(branchpoint_y),
"ancestor_timeline_id": str(env.initial_timeline),
"read_only": True,
},
)
sk = env.safekeepers[0]
assert sk
with pytest.raises(requests.exceptions.HTTPError, match="Not Found"):
sk.http_client().timeline_status(
tenant_id=env.initial_tenant, timeline_id=snapshot_branchpoint_old
)
env.neon_cli.mappings_map_branch(
"snapshot_branchpoint_old", env.initial_tenant, snapshot_branchpoint_old
snapshot_branchpoint_old = env.create_branch(
"snapshot_branchpoint_old", ancestor_branch_name="main", ancestor_start_lsn=branchpoint_y
)
snapshot_branchpoint = env.create_branch(

View File

@@ -2012,7 +2012,10 @@ def test_explicit_timeline_creation(neon_env_builder: NeonEnvBuilder):
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline
ep = env.endpoints.create("main")
config_lines = [
"neon.safekeeper_proto_version = 3",
]
ep = env.endpoints.create("main", config_lines=config_lines)
# expected to fail because timeline is not created on safekeepers
with pytest.raises(Exception, match=r".*timed out.*"):
@@ -2040,7 +2043,10 @@ def test_explicit_timeline_creation_storcon(neon_env_builder: NeonEnvBuilder):
}
env = neon_env_builder.init_start()
ep = env.endpoints.create("main")
config_lines = [
"neon.safekeeper_proto_version = 3",
]
ep = env.endpoints.create("main", config_lines=config_lines)
# endpoint should start.
ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3])

View File

@@ -637,7 +637,10 @@ async def quorum_sanity_single(
# create timeline on `members_sks`
Safekeeper.create_timeline(tenant_id, timeline_id, env.pageservers[0], mconf, members_sks)
ep = env.endpoints.create(branch_name)
config_lines = [
"neon.safekeeper_proto_version = 3",
]
ep = env.endpoints.create(branch_name, config_lines=config_lines)
ep.start(safekeeper_generation=1, safekeepers=compute_sks_ids)
ep.safe_psql("create table t(key int, value text)")

View File

@@ -41,8 +41,10 @@ env_logger = { version = "0.11" }
fail = { version = "0.5", default-features = false, features = ["failpoints"] }
form_urlencoded = { version = "1" }
futures-channel = { version = "0.3", features = ["sink"] }
futures-core = { version = "0.3" }
futures-executor = { version = "0.3" }
futures-io = { version = "0.3" }
futures-task = { version = "0.3", default-features = false, features = ["std"] }
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] }
getrandom = { version = "0.2", default-features = false, features = ["std"] }
@@ -72,6 +74,7 @@ num-traits = { version = "0.2", features = ["i128", "libm"] }
once_cell = { version = "1" }
p256 = { version = "0.13", features = ["jwk"] }
parquet = { version = "53", default-features = false, features = ["zstd"] }
percent-encoding = { version = "2" }
prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] }
rand = { version = "0.8", features = ["small_rng"] }
regex = { version = "1" }