Mirror of https://github.com/neondatabase/neon.git (synced 2026-05-14 11:40:38 +00:00)

Compare commits: refactor_i ... problame/i (15 commits)
| SHA1 |
|---|
| 85445cde14 |
| 8625466144 |
| 1ab0cfc8cb |
| ca469be1cf |
| 286f34dfce |
| f290b27378 |
| 4cd18fcebd |
| 4c29e0594e |
| 3c56a4dd18 |
| 316309c85b |
| e09bb9974c |
| 5289f341ce |
| 683ec2417c |
| a76a503b8b |
| 92bc2bb132 |
.github/workflows/build_and_test.yml (vendored), 3 changes:

```diff
@@ -404,7 +404,7 @@ jobs:
         uses: ./.github/actions/save-coverage-data

   regress-tests:
-    needs: [ check-permissions, build-neon ]
+    needs: [ check-permissions, build-neon, tag ]
    runs-on: [ self-hosted, gen3, large ]
    container:
      image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
@@ -436,6 +436,7 @@ jobs:
        env:
          TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
          CHECK_ONDISK_DATA_COMPATIBILITY: nonempty
+         BUILD_TAG: ${{ needs.tag.outputs.build-tag }}

      - name: Merge and upload coverage data
        if: matrix.build_type == 'debug' && matrix.pg_version == 'v14'
```
Cargo.lock (generated), 3 changes:

```diff
@@ -1126,6 +1126,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "async-compression",
  "bytes",
  "cfg-if",
  "chrono",
  "clap",
@@ -3010,6 +3011,7 @@ dependencies = [
  "serde_with",
  "strum",
  "strum_macros",
  "thiserror",
  "utils",
  "workspace_hack",
 ]
@@ -3504,6 +3506,7 @@ dependencies = [
  "pbkdf2",
  "pin-project-lite",
  "postgres-native-tls",
  "postgres-protocol",
  "postgres_backend",
  "pq_proto",
  "prometheus",
```
```diff
@@ -714,6 +714,24 @@ RUN wget https://github.com/pksunkara/pgx_ulid/archive/refs/tags/v0.1.3.tar.gz -
     cargo pgrx install --release && \
     echo "trusted = true" >> /usr/local/pgsql/share/extension/ulid.control

+#########################################################################################
+#
+# Layer "wal2json-build"
+# Compile "wal2json" extension
+#
+#########################################################################################
+
+FROM build-deps AS wal2json-pg-build
+COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
+
+ENV PATH "/usr/local/pgsql/bin/:$PATH"
+RUN wget https://github.com/eulerto/wal2json/archive/refs/tags/wal2json_2_5.tar.gz && \
+    echo "b516653575541cf221b99cf3f8be9b6821f6dbcfc125675c85f35090f824f00e wal2json_2_5.tar.gz" | sha256sum --check && \
+    mkdir wal2json-src && cd wal2json-src && tar xvzf ../wal2json_2_5.tar.gz --strip-components=1 -C . && \
+    make -j $(getconf _NPROCESSORS_ONLN) && \
+    make -j $(getconf _NPROCESSORS_ONLN) install && \
+    echo 'trusted = true' >> /usr/local/pgsql/share/extension/wal2json.control
+
 #########################################################################################
 #
 # Layer "neon-pg-ext-build"
@@ -750,6 +768,7 @@ COPY --from=rdkit-pg-build /usr/local/pgsql/ /usr/local/pgsql/
 COPY --from=pg-uuidv7-pg-build /usr/local/pgsql/ /usr/local/pgsql/
 COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
 COPY --from=pg-embedding-pg-build /usr/local/pgsql/ /usr/local/pgsql/
+COPY --from=wal2json-pg-build /usr/local/pgsql /usr/local/pgsql
 COPY pgxn/ pgxn/

 RUN make -j $(getconf _NPROCESSORS_ONLN) \
```
```diff
@@ -38,3 +38,4 @@ toml_edit.workspace = true
 remote_storage = { version = "0.1", path = "../libs/remote_storage/" }
 vm_monitor = { version = "0.1", path = "../libs/vm_monitor/" }
 zstd = "0.12.4"
+bytes = "1.0"
```
```diff
@@ -31,7 +31,7 @@
 //!     -C 'postgresql://cloud_admin@localhost/postgres' \
 //!     -S /var/db/postgres/specs/current.json \
 //!     -b /usr/local/bin/postgres \
-//!     -r {"bucket": "neon-dev-extensions-eu-central-1", "region": "eu-central-1"}
+//!     -r http://pg-ext-s3-gateway
 //! ```
 //!
 use std::collections::HashMap;
@@ -51,7 +51,7 @@ use compute_api::responses::ComputeStatus;

 use compute_tools::compute::{ComputeNode, ComputeState, ParsedSpec};
 use compute_tools::configurator::launch_configurator;
-use compute_tools::extension_server::{get_pg_version, init_remote_storage};
+use compute_tools::extension_server::get_pg_version;
 use compute_tools::http::api::launch_http_server;
 use compute_tools::logger::*;
 use compute_tools::monitor::launch_monitor;
@@ -60,7 +60,7 @@ use compute_tools::spec::*;

 // this is an arbitrary build tag. Fine as a default / for testing purposes
 // in-case of not-set environment var
-const BUILD_TAG_DEFAULT: &str = "5670669815";
+const BUILD_TAG_DEFAULT: &str = "latest";

 fn main() -> Result<()> {
     init_tracing_and_logging(DEFAULT_LOG_LEVEL)?;
@@ -74,10 +74,18 @@ fn main() -> Result<()> {
     let pgbin_default = String::from("postgres");
     let pgbin = matches.get_one::<String>("pgbin").unwrap_or(&pgbin_default);

-    let remote_ext_config = matches.get_one::<String>("remote-ext-config");
-    let ext_remote_storage = remote_ext_config.map(|x| {
-        init_remote_storage(x).expect("cannot initialize remote extension storage from config")
-    });
+    let ext_remote_storage = matches
+        .get_one::<String>("remote-ext-config")
+        // Compatibility hack: if the control plane specified any remote-ext-config
+        // use the default value for extension storage proxy gateway.
+        // Remove this once the control plane is updated to pass the gateway URL
+        .map(|conf| {
+            if conf.starts_with("http") {
+                conf.trim_end_matches('/')
+            } else {
+                "http://pg-ext-s3-gateway"
+            }
+        });

     let http_port = *matches
         .get_one::<u16>("http-port")
@@ -198,7 +206,7 @@ fn main() -> Result<()> {
         live_config_allowed,
         state: Mutex::new(new_state),
         state_changed: Condvar::new(),
-        ext_remote_storage,
+        ext_remote_storage: ext_remote_storage.map(|s| s.to_string()),
         ext_download_progress: RwLock::new(HashMap::new()),
         build_tag,
     };
```
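The compatibility hack above maps whatever the control plane passes for `--remote-ext-config` onto a gateway base URL: an http(s) value is used as-is (minus any trailing slash), anything else falls back to the default in-cluster gateway. A minimal stand-alone sketch of that mapping (the helper name is hypothetical; the real code inlines the closure in `main`):

```rust
/// Normalize the --remote-ext-config value into a gateway base URL.
fn normalize_remote_ext_config(conf: &str) -> &str {
    if conf.starts_with("http") {
        conf.trim_end_matches('/')
    } else {
        // Legacy JSON config ({"bucket": ...}) or any other value:
        // fall back to the extension storage proxy gateway.
        "http://pg-ext-s3-gateway"
    }
}

fn main() {
    assert_eq!(
        normalize_remote_ext_config("http://pg-ext-s3-gateway/"),
        "http://pg-ext-s3-gateway"
    );
    assert_eq!(
        normalize_remote_ext_config(r#"{"bucket": "neon-dev-extensions-eu-central-1"}"#),
        "http://pg-ext-s3-gateway"
    );
}
```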
```diff
@@ -25,7 +25,7 @@ use compute_api::responses::{ComputeMetrics, ComputeStatus};
 use compute_api::spec::{ComputeMode, ComputeSpec};
 use utils::measured_stream::MeasuredReader;

-use remote_storage::{DownloadError, GenericRemoteStorage, RemotePath};
+use remote_storage::{DownloadError, RemotePath};

 use crate::checker::create_availability_check_data;
 use crate::pg_helpers::*;
@@ -59,8 +59,8 @@ pub struct ComputeNode {
     pub state: Mutex<ComputeState>,
     /// `Condvar` to allow notifying waiters about state changes.
     pub state_changed: Condvar,
-    /// the S3 bucket that we search for extensions in
-    pub ext_remote_storage: Option<GenericRemoteStorage>,
+    /// the address of extension storage proxy gateway
+    pub ext_remote_storage: Option<String>,
     // key: ext_archive_name, value: started download time, download_completed?
     pub ext_download_progress: RwLock<HashMap<String, (DateTime<Utc>, bool)>>,
     pub build_tag: String,
@@ -957,12 +957,12 @@ LIMIT 100",
         real_ext_name: String,
         ext_path: RemotePath,
     ) -> Result<u64, DownloadError> {
-        let remote_storage = self
-            .ext_remote_storage
-            .as_ref()
-            .ok_or(DownloadError::BadInput(anyhow::anyhow!(
-                "Remote extensions storage is not configured",
-            )))?;
+        let ext_remote_storage =
+            self.ext_remote_storage
+                .as_ref()
+                .ok_or(DownloadError::BadInput(anyhow::anyhow!(
+                    "Remote extensions storage is not configured",
+                )))?;

         let ext_archive_name = ext_path.object_name().expect("bad path");

@@ -1018,7 +1018,7 @@ LIMIT 100",
         let download_size = extension_server::download_extension(
             &real_ext_name,
             &ext_path,
-            remote_storage,
+            ext_remote_storage,
             &self.pgbin,
         )
         .await
```
```diff
@@ -71,18 +71,16 @@ More specifically, here is an example ext_index.json
 }
 }
 */
 use anyhow::Context;
 use anyhow::{self, Result};
 use anyhow::{bail, Context};
 use bytes::Bytes;
 use compute_api::spec::RemoteExtSpec;
 use regex::Regex;
 use remote_storage::*;
 use serde_json;
 use std::io::Read;
 use std::num::NonZeroUsize;
 use reqwest::StatusCode;
 use std::path::Path;
 use std::str;
 use tar::Archive;
 use tokio::io::AsyncReadExt;
 use tracing::info;
 use tracing::log::warn;
 use zstd::stream::read::Decoder;
@@ -138,23 +136,31 @@ fn parse_pg_version(human_version: &str) -> &str {
 pub async fn download_extension(
     ext_name: &str,
     ext_path: &RemotePath,
-    remote_storage: &GenericRemoteStorage,
+    ext_remote_storage: &str,
     pgbin: &str,
 ) -> Result<u64> {
     info!("Download extension {:?} from {:?}", ext_name, ext_path);
-    let mut download = remote_storage.download(ext_path).await?;
-    let mut download_buffer = Vec::new();
-    download
-        .download_stream
-        .read_to_end(&mut download_buffer)
-        .await?;
+
+    // TODO add retry logic
+    let download_buffer =
+        match download_extension_tar(ext_remote_storage, &ext_path.to_string()).await {
+            Ok(buffer) => buffer,
+            Err(error_message) => {
+                return Err(anyhow::anyhow!(
+                    "error downloading extension {:?}: {:?}",
+                    ext_name,
+                    error_message
+                ));
+            }
+        };

     let download_size = download_buffer.len() as u64;
     info!("Download size {:?}", download_size);
     // it's unclear whether it is more performant to decompress into memory or not
-    let mut decoder = Decoder::new(download_buffer.as_slice())?;
-    let mut decompress_buffer = Vec::new();
-    decoder.read_to_end(&mut decompress_buffer)?;
-    let mut archive = Archive::new(decompress_buffer.as_slice());
+    // TODO: decompressing into memory can be avoided
+    let decoder = Decoder::new(download_buffer.as_ref())?;
+    let mut archive = Archive::new(decoder);

     let unzip_dest = pgbin
         .strip_suffix("/bin/postgres")
         .expect("bad pgbin")
@@ -222,29 +228,32 @@ pub fn create_control_files(remote_extensions: &RemoteExtSpec, pgbin: &str) {
     }
 }

-// This function initializes the necessary structs to use remote storage
-pub fn init_remote_storage(remote_ext_config: &str) -> anyhow::Result<GenericRemoteStorage> {
-    #[derive(Debug, serde::Deserialize)]
-    struct RemoteExtJson {
-        bucket: String,
-        region: String,
-        endpoint: Option<String>,
-        prefix: Option<String>,
-    }
-    let remote_ext_json = serde_json::from_str::<RemoteExtJson>(remote_ext_config)?;
+// Do request to extension storage proxy, i.e.
+// curl http://pg-ext-s3-gateway/latest/v15/extensions/anon.tar.zst
+// using HHTP GET
+// and return the response body as bytes
+//
+async fn download_extension_tar(ext_remote_storage: &str, ext_path: &str) -> Result<Bytes> {
+    let uri = format!("{}/{}", ext_remote_storage, ext_path);

-    let config = S3Config {
-        bucket_name: remote_ext_json.bucket,
-        bucket_region: remote_ext_json.region,
-        prefix_in_bucket: remote_ext_json.prefix,
-        endpoint: remote_ext_json.endpoint,
-        concurrency_limit: NonZeroUsize::new(100).expect("100 != 0"),
-        max_keys_per_list_response: None,
-    };
-    let config = RemoteStorageConfig {
-        storage: RemoteStorageKind::AwsS3(config),
-    };
-    GenericRemoteStorage::from_config(&config)
+    info!("Download extension {:?} from uri {:?}", ext_path, uri);
+
+    let resp = reqwest::get(uri).await?;
+
+    match resp.status() {
+        StatusCode::OK => match resp.bytes().await {
+            Ok(resp) => {
+                info!("Download extension {:?} completed successfully", ext_path);
+                Ok(resp)
+            }
+            Err(e) => bail!("could not deserialize remote extension response: {}", e),
+        },
+        StatusCode::SERVICE_UNAVAILABLE => bail!("remote extension is temporarily unavailable"),
+        _ => bail!(
+            "unexpected remote extension response status code: {}",
+            resp.status()
+        ),
+    }
 }

 #[cfg(test)]
```
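With `init_remote_storage` gone, the compute no longer needs S3 credentials or bucket configuration; it just issues an HTTP GET against the gateway. A rough sketch of how the request URI is assembled, following the `curl` example quoted in the comment above (`http://pg-ext-s3-gateway/latest/v15/extensions/anon.tar.zst`); the helper name and exact path layout here are illustrative assumptions, not the crate's API:

```rust
// Illustrative only: compose the URL that download_extension_tar() ends up fetching,
// assuming the path layout `{gateway}/{build_tag}/v{pg_major}/extensions/{name}.tar.zst`.
fn extension_tar_url(gateway: &str, build_tag: &str, pg_major: &str, ext: &str) -> String {
    format!("{gateway}/{build_tag}/v{pg_major}/extensions/{ext}.tar.zst")
}

fn main() {
    let url = extension_tar_url("http://pg-ext-s3-gateway", "latest", "15", "anon");
    assert_eq!(
        url,
        "http://pg-ext-s3-gateway/latest/v15/extensions/anon.tar.zst"
    );
}
```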
```diff
@@ -123,7 +123,7 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
         }
     }

-    // download extension files from S3 on demand
+    // download extension files from remote extension storage on demand
     (&Method::POST, route) if route.starts_with("/extension_server/") => {
         info!("serving {:?} POST request", route);
         info!("req.uri {:?}", req.uri());
```
```diff
@@ -687,6 +687,9 @@ pub fn handle_extension_neon(client: &mut Client) -> Result<()> {
     info!("create neon extension with query: {}", query);
     client.simple_query(query)?;

     query = "UPDATE pg_extension SET extrelocatable = true WHERE extname = 'neon'";
     client.simple_query(query)?;

+    query = "ALTER EXTENSION neon SET SCHEMA neon";
+    info!("alter neon extension schema with query: {}", query);
+    client.simple_query(query)?;
```
```diff
@@ -1252,7 +1252,7 @@ fn cli() -> Command {
     let remote_ext_config_args = Arg::new("remote-ext-config")
         .long("remote-ext-config")
         .num_args(1)
-        .help("Configure the S3 bucket that we search for extensions in.")
+        .help("Configure the remote extensions storage proxy gateway to request for extensions.")
         .required(false);

     let lsn_arg = Arg::new("lsn")
```
```diff
@@ -45,6 +45,7 @@ use std::sync::Arc;
 use std::time::Duration;

 use anyhow::{anyhow, bail, Context, Result};
+use compute_api::spec::RemoteExtSpec;
 use serde::{Deserialize, Serialize};
 use utils::id::{NodeId, TenantId, TimelineId};

@@ -476,6 +477,18 @@ impl Endpoint {
             }
         }

+        // check for file remote_extensions_spec.json
+        // if it is present, read it and pass to compute_ctl
+        let remote_extensions_spec_path = self.endpoint_path().join("remote_extensions_spec.json");
+        let remote_extensions_spec = std::fs::File::open(remote_extensions_spec_path);
+        let remote_extensions: Option<RemoteExtSpec>;
+
+        if let Ok(spec_file) = remote_extensions_spec {
+            remote_extensions = serde_json::from_reader(spec_file).ok();
+        } else {
+            remote_extensions = None;
+        };
+
         // Create spec file
         let spec = ComputeSpec {
             skip_pg_catalog_updates: self.skip_pg_catalog_updates,
@@ -497,7 +510,7 @@ impl Endpoint {
             pageserver_connstring: Some(pageserver_connstring),
             safekeeper_connstrings,
             storage_auth_token: auth_token.clone(),
-            remote_extensions: None,
+            remote_extensions,
         };
         let spec_path = self.endpoint_path().join("spec.json");
         std::fs::write(spec_path, serde_json::to_string_pretty(&spec)?)?;
```
```diff
@@ -14,7 +14,6 @@ use pageserver_api::models::{
 use std::collections::HashMap;
 use std::time::Duration;
 use utils::{
-    generation::Generation,
     id::{TenantId, TimelineId},
     lsn::Lsn,
 };
@@ -93,6 +92,22 @@ pub fn migrate_tenant(
     // Get a new generation
     let attachment_service = AttachmentService::from_env(env);

+    fn build_location_config(
+        mode: LocationConfigMode,
+        generation: Option<u32>,
+        secondary_conf: Option<LocationConfigSecondary>,
+    ) -> LocationConfig {
+        LocationConfig {
+            mode,
+            generation,
+            secondary_conf,
+            tenant_conf: TenantConfig::default(),
+            shard_number: 0,
+            shard_count: 0,
+            shard_stripe_size: 0,
+        }
+    }
+
     let previous = attachment_service.inspect(tenant_id)?;
     let mut baseline_lsns = None;
     if let Some((generation, origin_ps_id)) = &previous {
@@ -101,12 +116,7 @@ pub fn migrate_tenant(
         if origin_ps_id == &dest_ps.conf.id {
             println!("🔁 Already attached to {origin_ps_id}, freshening...");
             let gen = attachment_service.attach_hook(tenant_id, dest_ps.conf.id)?;
-            let dest_conf = LocationConfig {
-                mode: LocationConfigMode::AttachedSingle,
-                generation: gen.map(Generation::new),
-                secondary_conf: None,
-                tenant_conf: TenantConfig::default(),
-            };
+            let dest_conf = build_location_config(LocationConfigMode::AttachedSingle, gen, None);
             dest_ps.location_config(tenant_id, dest_conf)?;
             println!("✅ Migration complete");
             return Ok(());
@@ -114,24 +124,15 @@ pub fn migrate_tenant(

         println!("🔁 Switching origin pageserver {origin_ps_id} to stale mode");

-        let stale_conf = LocationConfig {
-            mode: LocationConfigMode::AttachedStale,
-            generation: Some(Generation::new(*generation)),
-            secondary_conf: None,
-            tenant_conf: TenantConfig::default(),
-        };
+        let stale_conf =
+            build_location_config(LocationConfigMode::AttachedStale, Some(*generation), None);
         origin_ps.location_config(tenant_id, stale_conf)?;

         baseline_lsns = Some(get_lsns(tenant_id, &origin_ps)?);
     }

     let gen = attachment_service.attach_hook(tenant_id, dest_ps.conf.id)?;
-    let dest_conf = LocationConfig {
-        mode: LocationConfigMode::AttachedMulti,
-        generation: gen.map(Generation::new),
-        secondary_conf: None,
-        tenant_conf: TenantConfig::default(),
-    };
+    let dest_conf = build_location_config(LocationConfigMode::AttachedMulti, gen, None);

     println!("🔁 Attaching to pageserver {}", dest_ps.conf.id);
     dest_ps.location_config(tenant_id, dest_conf)?;
@@ -170,12 +171,11 @@ pub fn migrate_tenant(
     }

     // Downgrade to a secondary location
-    let secondary_conf = LocationConfig {
-        mode: LocationConfigMode::Secondary,
-        generation: None,
-        secondary_conf: Some(LocationConfigSecondary { warm: true }),
-        tenant_conf: TenantConfig::default(),
-    };
+    let secondary_conf = build_location_config(
+        LocationConfigMode::Secondary,
+        None,
+        Some(LocationConfigSecondary { warm: true }),
+    );

     println!(
         "💤 Switching to secondary mode on pageserver {}",
@@ -188,12 +188,7 @@ pub fn migrate_tenant(
         "🔁 Switching to AttachedSingle mode on pageserver {}",
         dest_ps.conf.id
     );
-    let dest_conf = LocationConfig {
-        mode: LocationConfigMode::AttachedSingle,
-        generation: gen.map(Generation::new),
-        secondary_conf: None,
-        tenant_conf: TenantConfig::default(),
-    };
+    let dest_conf = build_location_config(LocationConfigMode::AttachedSingle, gen, None);
     dest_ps.location_config(tenant_id, dest_conf)?;

     println!("✅ Migration complete");
```
```diff
@@ -18,6 +18,7 @@ enum-map.workspace = true
 strum.workspace = true
 strum_macros.workspace = true
 hex.workspace = true
+thiserror.workspace = true

 workspace_hack.workspace = true

```
```diff
@@ -10,7 +10,6 @@ use serde_with::serde_as;
 use strum_macros;
 use utils::{
     completion,
-    generation::Generation,
     history_buffer::HistoryBufferWithDropCounter,
     id::{NodeId, TenantId, TimelineId},
     lsn::Lsn,
@@ -262,10 +261,19 @@ pub struct LocationConfig {
     pub mode: LocationConfigMode,
     /// If attaching, in what generation?
     #[serde(default)]
-    pub generation: Option<Generation>,
+    pub generation: Option<u32>,
     #[serde(default)]
     pub secondary_conf: Option<LocationConfigSecondary>,

+    // Shard parameters: if shard_count is nonzero, then other shard_* fields
+    // must be set accurately.
+    #[serde(default)]
+    pub shard_number: u8,
+    #[serde(default)]
+    pub shard_count: u8,
+    #[serde(default)]
+    pub shard_stripe_size: u32,
+
     // If requesting mode `Secondary`, configuration for that.
     // Custom storage configuration for the tenant, if any
     pub tenant_conf: TenantConfig,
```
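All of the new `shard_*` fields carry `#[serde(default)]`, so location configs sent by an older control plane (with no shard fields at all) still deserialize, with the shard parameters defaulting to zero, i.e. unsharded. A small sketch of that behaviour with a stand-in struct (not the real `LocationConfig`, which has more fields; assumes the `serde` and `serde_json` crates):

```rust
use serde::Deserialize;

// Stand-in for the shard-related part of LocationConfig.
#[derive(Deserialize, Debug)]
struct ShardParams {
    #[serde(default)]
    shard_number: u8,
    #[serde(default)]
    shard_count: u8,
    #[serde(default)]
    shard_stripe_size: u32,
}

fn main() {
    // A legacy request body with no shard fields still parses; everything
    // defaults to 0, which means "unsharded".
    let legacy: ShardParams = serde_json::from_str("{}").unwrap();
    assert_eq!(legacy.shard_count, 0);

    let sharded: ShardParams = serde_json::from_str(
        r#"{"shard_number": 1, "shard_count": 4, "shard_stripe_size": 32768}"#,
    )
    .unwrap();
    assert_eq!(sharded.shard_count, 4);
}
```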
```diff
@@ -2,6 +2,7 @@ use std::{ops::RangeInclusive, str::FromStr};

 use hex::FromHex;
 use serde::{Deserialize, Serialize};
+use thiserror;
 use utils::id::TenantId;

 #[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Debug)]
@@ -139,6 +140,89 @@ impl From<[u8; 18]> for TenantShardId {
     }
 }

+/// For use within the context of a particular tenant, when we need to know which
+/// shard we're dealing with, but do not need to know the full ShardIdentity (because
+/// we won't be doing any page->shard mapping), and do not need to know the fully qualified
+/// TenantShardId.
+#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy)]
+pub struct ShardIndex {
+    pub shard_number: ShardNumber,
+    pub shard_count: ShardCount,
+}
+
+impl ShardIndex {
+    pub fn new(number: ShardNumber, count: ShardCount) -> Self {
+        Self {
+            shard_number: number,
+            shard_count: count,
+        }
+    }
+    pub fn unsharded() -> Self {
+        Self {
+            shard_number: ShardNumber(0),
+            shard_count: ShardCount(0),
+        }
+    }
+
+    pub fn is_unsharded(&self) -> bool {
+        self.shard_number == ShardNumber(0) && self.shard_count == ShardCount(0)
+    }
+
+    /// For use in constructing remote storage paths: concatenate this with a TenantId
+    /// to get a fully qualified TenantShardId.
+    ///
+    /// Backward compat: this function returns an empty string if Self::is_unsharded, such
+    /// that the legacy pre-sharding remote key format is preserved.
+    pub fn get_suffix(&self) -> String {
+        if self.is_unsharded() {
+            "".to_string()
+        } else {
+            format!("-{:02x}{:02x}", self.shard_number.0, self.shard_count.0)
+        }
+    }
+}
+
+impl std::fmt::Display for ShardIndex {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:02x}{:02x}", self.shard_number.0, self.shard_count.0)
+    }
+}
+
+impl std::fmt::Debug for ShardIndex {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Debug is the same as Display: the compact hex representation
+        write!(f, "{}", self)
+    }
+}
+
+impl std::str::FromStr for ShardIndex {
+    type Err = hex::FromHexError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        // Expect format: 1 byte shard number, 1 byte shard count
+        if s.len() == 4 {
+            let bytes = s.as_bytes();
+            let mut shard_parts: [u8; 2] = [0u8; 2];
+            hex::decode_to_slice(bytes, &mut shard_parts)?;
+            Ok(Self {
+                shard_number: ShardNumber(shard_parts[0]),
+                shard_count: ShardCount(shard_parts[1]),
+            })
+        } else {
+            Err(hex::FromHexError::InvalidStringLength)
+        }
+    }
+}
+
+impl From<[u8; 2]> for ShardIndex {
+    fn from(b: [u8; 2]) -> Self {
+        Self {
+            shard_number: ShardNumber(b[0]),
+            shard_count: ShardCount(b[1]),
+        }
+    }
+}
+
 impl Serialize for TenantShardId {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
```
```diff
@@ -209,6 +293,151 @@ impl<'de> Deserialize<'de> for TenantShardId {
     }
 }

+/// Stripe size in number of pages
+#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
+pub struct ShardStripeSize(pub u32);
+
+/// Layout version: for future upgrades where we might change how the key->shard mapping works
+#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
+pub struct ShardLayout(u8);
+
+const LAYOUT_V1: ShardLayout = ShardLayout(1);
+
+/// Default stripe size in pages: 256MiB divided by 8kiB page size.
+const DEFAULT_STRIPE_SIZE: ShardStripeSize = ShardStripeSize(256 * 1024 / 8);
+
+/// The ShardIdentity contains the information needed for one member of map
+/// to resolve a key to a shard, and then check whether that shard is ==self.
+#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)]
+pub struct ShardIdentity {
+    pub layout: ShardLayout,
+    pub number: ShardNumber,
+    pub count: ShardCount,
+    pub stripe_size: ShardStripeSize,
+}
+
+#[derive(thiserror::Error, Debug, PartialEq, Eq)]
+pub enum ShardConfigError {
+    #[error("Invalid shard count")]
+    InvalidCount,
+    #[error("Invalid shard number")]
+    InvalidNumber,
+    #[error("Invalid stripe size")]
+    InvalidStripeSize,
+}
+
+impl ShardIdentity {
+    /// An identity with number=0 count=0 is a "none" identity, which represents legacy
+    /// tenants. Modern single-shard tenants should not use this: they should
+    /// have number=0 count=1.
+    pub fn unsharded() -> Self {
+        Self {
+            number: ShardNumber(0),
+            count: ShardCount(0),
+            layout: LAYOUT_V1,
+            stripe_size: DEFAULT_STRIPE_SIZE,
+        }
+    }
+
+    pub fn is_unsharded(&self) -> bool {
+        self.number == ShardNumber(0) && self.count == ShardCount(0)
+    }
+
+    /// Count must be nonzero, and number must be < count. To construct
+    /// the legacy case (count==0), use Self::unsharded instead.
+    pub fn new(
+        number: ShardNumber,
+        count: ShardCount,
+        stripe_size: ShardStripeSize,
+    ) -> Result<Self, ShardConfigError> {
+        if count.0 == 0 {
+            Err(ShardConfigError::InvalidCount)
+        } else if number.0 > count.0 - 1 {
+            Err(ShardConfigError::InvalidNumber)
+        } else if stripe_size.0 == 0 {
+            Err(ShardConfigError::InvalidStripeSize)
+        } else {
+            Ok(Self {
+                number,
+                count,
+                layout: LAYOUT_V1,
+                stripe_size,
+            })
+        }
+    }
+}
+
+impl Serialize for ShardIndex {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        if serializer.is_human_readable() {
+            serializer.collect_str(self)
+        } else {
+            // Binary encoding is not used in index_part.json, but is included in anticipation of
+            // switching various structures (e.g. inter-process communication, remote metadata) to more
+            // compact binary encodings in future.
+            let mut packed: [u8; 2] = [0; 2];
+            packed[0] = self.shard_number.0;
+            packed[1] = self.shard_count.0;
+            packed.serialize(serializer)
+        }
+    }
+}
+
+impl<'de> Deserialize<'de> for ShardIndex {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        struct IdVisitor {
+            is_human_readable_deserializer: bool,
+        }
+
+        impl<'de> serde::de::Visitor<'de> for IdVisitor {
+            type Value = ShardIndex;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+                if self.is_human_readable_deserializer {
+                    formatter.write_str("value in form of hex string")
+                } else {
+                    formatter.write_str("value in form of integer array([u8; 2])")
+                }
+            }
+
+            fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
+            where
+                A: serde::de::SeqAccess<'de>,
+            {
+                let s = serde::de::value::SeqAccessDeserializer::new(seq);
+                let id: [u8; 2] = Deserialize::deserialize(s)?;
+                Ok(ShardIndex::from(id))
+            }
+
+            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
+            where
+                E: serde::de::Error,
+            {
+                ShardIndex::from_str(v).map_err(E::custom)
+            }
+        }
+
+        if deserializer.is_human_readable() {
+            deserializer.deserialize_str(IdVisitor {
+                is_human_readable_deserializer: true,
+            })
+        } else {
+            deserializer.deserialize_tuple(
+                2,
+                IdVisitor {
+                    is_human_readable_deserializer: false,
+                },
+            )
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::str::FromStr;
```
```diff
@@ -318,4 +547,66 @@ mod tests {

         Ok(())
     }
+
+    #[test]
+    fn shard_identity_validation() -> Result<(), ShardConfigError> {
+        // Happy cases
+        ShardIdentity::new(ShardNumber(0), ShardCount(1), DEFAULT_STRIPE_SIZE)?;
+        ShardIdentity::new(ShardNumber(0), ShardCount(1), ShardStripeSize(1))?;
+        ShardIdentity::new(ShardNumber(254), ShardCount(255), ShardStripeSize(1))?;
+
+        assert_eq!(
+            ShardIdentity::new(ShardNumber(0), ShardCount(0), DEFAULT_STRIPE_SIZE),
+            Err(ShardConfigError::InvalidCount)
+        );
+        assert_eq!(
+            ShardIdentity::new(ShardNumber(10), ShardCount(10), DEFAULT_STRIPE_SIZE),
+            Err(ShardConfigError::InvalidNumber)
+        );
+        assert_eq!(
+            ShardIdentity::new(ShardNumber(11), ShardCount(10), DEFAULT_STRIPE_SIZE),
+            Err(ShardConfigError::InvalidNumber)
+        );
+        assert_eq!(
+            ShardIdentity::new(ShardNumber(255), ShardCount(255), DEFAULT_STRIPE_SIZE),
+            Err(ShardConfigError::InvalidNumber)
+        );
+        assert_eq!(
+            ShardIdentity::new(ShardNumber(0), ShardCount(1), ShardStripeSize(0)),
+            Err(ShardConfigError::InvalidStripeSize)
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn shard_index_human_encoding() -> Result<(), hex::FromHexError> {
+        let example = ShardIndex {
+            shard_number: ShardNumber(13),
+            shard_count: ShardCount(17),
+        };
+        let expected: String = "0d11".to_string();
+        let encoded = format!("{example}");
+        assert_eq!(&encoded, &expected);
+
+        let decoded = ShardIndex::from_str(&encoded)?;
+        assert_eq!(example, decoded);
+        Ok(())
+    }
+
+    #[test]
+    fn shard_index_binary_encoding() -> Result<(), hex::FromHexError> {
+        let example = ShardIndex {
+            shard_number: ShardNumber(13),
+            shard_count: ShardCount(17),
+        };
+        let expected: [u8; 2] = [0x0d, 0x11];
+
+        let encoded = bincode::serialize(&example).unwrap();
+        assert_eq!(Hex(&encoded), Hex(&expected));
+        let decoded = bincode::deserialize(&encoded).unwrap();
+        assert_eq!(example, decoded);
+
+        Ok(())
+    }
 }
```
```diff
@@ -10,6 +10,7 @@ use crate::control_plane_client::ControlPlaneGenerationsApi;
 use crate::metrics;
 use crate::tenant::remote_timeline_client::remote_layer_path;
 use crate::tenant::remote_timeline_client::remote_timeline_path;
+use crate::tenant::remote_timeline_client::LayerFileMetadata;
 use crate::virtual_file::MaybeFatalIo;
 use crate::virtual_file::VirtualFile;
 use anyhow::Context;
@@ -509,18 +510,19 @@ impl DeletionQueueClient {
         tenant_id: TenantId,
         timeline_id: TimelineId,
         current_generation: Generation,
-        layers: Vec<(LayerFileName, Generation)>,
+        layers: Vec<(LayerFileName, LayerFileMetadata)>,
     ) -> Result<(), DeletionQueueError> {
         if current_generation.is_none() {
             debug!("Enqueuing deletions in legacy mode, skipping queue");

             let mut layer_paths = Vec::new();
-            for (layer, generation) in layers {
+            for (layer, meta) in layers {
                 layer_paths.push(remote_layer_path(
                     &tenant_id,
                     &timeline_id,
+                    meta.shard,
                     &layer,
-                    generation,
+                    meta.generation,
                 ));
             }
             self.push_immediate(layer_paths).await?;
@@ -540,7 +542,7 @@ impl DeletionQueueClient {
         tenant_id: TenantId,
         timeline_id: TimelineId,
         current_generation: Generation,
-        layers: Vec<(LayerFileName, Generation)>,
+        layers: Vec<(LayerFileName, LayerFileMetadata)>,
     ) -> Result<(), DeletionQueueError> {
         metrics::DELETION_QUEUE
             .keys_submitted
@@ -751,6 +753,7 @@ impl DeletionQueue {
 mod test {
     use camino::Utf8Path;
     use hex_literal::hex;
+    use pageserver_api::shard::ShardIndex;
     use std::{io::ErrorKind, time::Duration};
     use tracing::info;

@@ -990,6 +993,8 @@ mod test {
         // we delete, and the generation of the running Tenant.
         let layer_generation = Generation::new(0xdeadbeef);
         let now_generation = Generation::new(0xfeedbeef);
+        let layer_metadata =
+            LayerFileMetadata::new(0xf00, layer_generation, ShardIndex::unsharded());

         let remote_layer_file_name_1 =
             format!("{}{}", layer_file_name_1, layer_generation.get_suffix());
@@ -1013,7 +1018,7 @@ mod test {
             tenant_id,
             TIMELINE_ID,
             now_generation,
-            [(layer_file_name_1.clone(), layer_generation)].to_vec(),
+            [(layer_file_name_1.clone(), layer_metadata)].to_vec(),
         )
         .await?;
         assert_remote_files(&[&remote_layer_file_name_1], &remote_timeline_path);
@@ -1052,6 +1057,8 @@ mod test {
         let stale_generation = latest_generation.previous();
         // Generation that our example layer file was written with
         let layer_generation = stale_generation.previous();
+        let layer_metadata =
+            LayerFileMetadata::new(0xf00, layer_generation, ShardIndex::unsharded());

         ctx.set_latest_generation(latest_generation);

@@ -1069,7 +1076,7 @@ mod test {
             tenant_id,
             TIMELINE_ID,
             stale_generation,
-            [(EXAMPLE_LAYER_NAME.clone(), layer_generation)].to_vec(),
+            [(EXAMPLE_LAYER_NAME.clone(), layer_metadata.clone())].to_vec(),
         )
         .await?;

@@ -1084,7 +1091,7 @@ mod test {
             tenant_id,
             TIMELINE_ID,
             latest_generation,
-            [(EXAMPLE_LAYER_NAME.clone(), layer_generation)].to_vec(),
+            [(EXAMPLE_LAYER_NAME.clone(), layer_metadata.clone())].to_vec(),
         )
         .await?;

@@ -1111,6 +1118,8 @@ mod test {

         let layer_generation = Generation::new(0xdeadbeef);
         let now_generation = Generation::new(0xfeedbeef);
+        let layer_metadata =
+            LayerFileMetadata::new(0xf00, layer_generation, ShardIndex::unsharded());

         // Inject a deletion in the generation before generation_now: after restart,
         // this deletion should _not_ get executed (only the immediately previous
@@ -1122,7 +1131,7 @@ mod test {
             tenant_id,
             TIMELINE_ID,
             now_generation.previous(),
-            [(EXAMPLE_LAYER_NAME.clone(), layer_generation)].to_vec(),
+            [(EXAMPLE_LAYER_NAME.clone(), layer_metadata.clone())].to_vec(),
         )
         .await?;

@@ -1136,7 +1145,7 @@ mod test {
             tenant_id,
             TIMELINE_ID,
             now_generation,
-            [(EXAMPLE_LAYER_NAME_ALT.clone(), layer_generation)].to_vec(),
+            [(EXAMPLE_LAYER_NAME_ALT.clone(), layer_metadata.clone())].to_vec(),
         )
         .await?;

@@ -1226,12 +1235,13 @@ pub(crate) mod mock {
                 match msg {
                     ListWriterQueueMessage::Delete(op) => {
                         let mut objects = op.objects;
-                        for (layer, generation) in op.layers {
+                        for (layer, meta) in op.layers {
                             objects.push(remote_layer_path(
                                 &op.tenant_id,
                                 &op.timeline_id,
+                                meta.shard,
                                 &layer,
-                                generation,
+                                meta.generation,
                             ));
                         }

```
```diff
@@ -33,6 +33,7 @@ use crate::config::PageServerConf;
 use crate::deletion_queue::TEMP_SUFFIX;
 use crate::metrics;
 use crate::tenant::remote_timeline_client::remote_layer_path;
+use crate::tenant::remote_timeline_client::LayerFileMetadata;
 use crate::tenant::storage_layer::LayerFileName;
 use crate::virtual_file::on_fatal_io_error;
 use crate::virtual_file::MaybeFatalIo;
@@ -58,7 +59,7 @@ pub(super) struct DeletionOp {
     // `layers` and `objects` are both just lists of objects. `layers` is used if you do not
     // have a config object handy to project it to a remote key, and need the consuming worker
     // to do it for you.
-    pub(super) layers: Vec<(LayerFileName, Generation)>,
+    pub(super) layers: Vec<(LayerFileName, LayerFileMetadata)>,
     pub(super) objects: Vec<RemotePath>,

     /// The _current_ generation of the Tenant attachment in which we are enqueuing
@@ -387,12 +388,13 @@ impl ListWriter {
                         );

                         let mut layer_paths = Vec::new();
-                        for (layer, generation) in op.layers {
+                        for (layer, meta) in op.layers {
                             layer_paths.push(remote_layer_path(
                                 &op.tenant_id,
                                 &op.timeline_id,
+                                meta.shard,
                                 &layer,
-                                generation,
+                                meta.generation,
                             ));
                         }
                         layer_paths.extend(op.objects);
```
```diff
@@ -19,6 +19,7 @@ use futures::FutureExt;
 use pageserver_api::models::TimelineState;
 use remote_storage::DownloadError;
 use remote_storage::GenericRemoteStorage;
+use std::fmt;
 use storage_broker::BrokerClientChannel;
 use tokio::runtime::Handle;
 use tokio::sync::watch;
@@ -31,26 +32,6 @@ use utils::crashsafe::path_with_suffix_extension;
 use utils::fs_ext;
 use utils::sync::gate::Gate;

-use std::cmp::min;
-use std::collections::hash_map::Entry;
-use std::collections::BTreeSet;
-use std::collections::HashMap;
-use std::collections::HashSet;
-use std::fmt::Debug;
-use std::fmt::Display;
-use std::fs;
-use std::fs::File;
-use std::io;
-use std::ops::Bound::Included;
-use std::process::Command;
-use std::process::Stdio;
-use std::sync::atomic::AtomicU64;
-use std::sync::atomic::Ordering;
-use std::sync::Arc;
-use std::sync::MutexGuard;
-use std::sync::{Mutex, RwLock};
-use std::time::{Duration, Instant};
-
 use self::config::AttachedLocationConfig;
 use self::config::AttachmentMode;
 use self::config::LocationConf;
@@ -84,14 +65,35 @@ use crate::tenant::remote_timeline_client::MaybeDeletedIndexPart;
 use crate::tenant::storage_layer::DeltaLayer;
 use crate::tenant::storage_layer::ImageLayer;
 use crate::InitializationOrder;
+use std::cmp::min;
+use std::collections::hash_map::Entry;
+use std::collections::BTreeSet;
+use std::collections::HashMap;
+use std::collections::HashSet;
+use std::fmt::Debug;
+use std::fmt::Display;
+use std::fs;
+use std::fs::File;
+use std::io;
+use std::ops::Bound::Included;
+use std::process::Stdio;
+use std::sync::atomic::AtomicU64;
+use std::sync::atomic::Ordering;
+use std::sync::Arc;
+use std::sync::MutexGuard;
+use std::sync::{Mutex, RwLock};
+use std::time::{Duration, Instant};

 use crate::tenant::timeline::delete::DeleteTimelineFlow;
 use crate::tenant::timeline::uninit::cleanup_timeline_directory;
 use crate::virtual_file::VirtualFile;
 use crate::walredo::PostgresRedoManager;
 use crate::TEMP_FILE_SUFFIX;
 use once_cell::sync::Lazy;
 pub use pageserver_api::models::TenantState;
+use tokio::sync::Semaphore;
+
+static INIT_DB_SEMAPHORE: Lazy<Semaphore> = Lazy::new(|| Semaphore::new(8));
 use toml_edit;
 use utils::{
     crashsafe,
```
```diff
@@ -403,6 +405,36 @@ pub enum CreateTimelineError {
     Other(#[from] anyhow::Error),
 }

+#[derive(thiserror::Error, Debug)]
+enum InitdbError {
+    Other(anyhow::Error),
+    Cancelled,
+    Spawn(std::io::Result<()>),
+    Failed(std::process::ExitStatus, Vec<u8>),
+}
+
+impl fmt::Display for InitdbError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            InitdbError::Cancelled => write!(f, "Operation was cancelled"),
+            InitdbError::Spawn(e) => write!(f, "Spawn error: {:?}", e),
+            InitdbError::Failed(status, stderr) => write!(
+                f,
+                "Command failed with status {:?}: {}",
+                status,
+                String::from_utf8_lossy(stderr)
+            ),
+            InitdbError::Other(e) => write!(f, "Error: {:?}", e),
+        }
+    }
+}
+
+impl From<std::io::Error> for InitdbError {
+    fn from(error: std::io::Error) -> Self {
+        InitdbError::Spawn(Err(error))
+    }
+}
+
 struct TenantDirectoryScan {
     sorted_timelines_to_load: Vec<(TimelineId, TimelineMetadata)>,
     timelines_to_resume_deletion: Vec<(TimelineId, Option<TimelineMetadata>)>,
@@ -618,12 +650,21 @@ impl Tenant {
                 // Remote preload is complete.
                 drop(remote_load_completion);

-                let pending_deletion = DeleteTenantFlow::should_resume_deletion(
-                    conf,
-                    preload.as_ref().map(|p| p.deleting).unwrap_or(false),
-                    &tenant_clone,
-                )
-                .await;
+                let pending_deletion = {
+                    match DeleteTenantFlow::should_resume_deletion(
+                        conf,
+                        preload.as_ref().map(|p| p.deleting).unwrap_or(false),
+                        &tenant_clone,
+                    )
+                    .await
+                    {
+                        Ok(should_resume_deletion) => should_resume_deletion,
+                        Err(err) => {
+                            make_broken(&tenant_clone, anyhow::anyhow!(err));
+                            return Ok(());
+                        }
+                    }
+                };

                 info!("pending_deletion {}", pending_deletion.is_some());

@@ -724,7 +765,7 @@ impl Tenant {
     ///
     async fn attach(
         self: &Arc<Tenant>,
-        mut init_order: Option<InitializationOrder>,
+        init_order: Option<InitializationOrder>,
         preload: Option<TenantPreload>,
         ctx: &RequestContext,
     ) -> anyhow::Result<()> {
@@ -741,11 +782,6 @@ impl Tenant {
             }
         };

-        // Signal that we have completed remote phase
-        init_order
-            .as_mut()
-            .and_then(|x| x.initial_tenant_load_remote.take());
-
         let mut timelines_to_resume_deletions = vec![];

         let mut remote_index_and_client = HashMap::new();
```
```diff
@@ -2904,7 +2940,7 @@ impl Tenant {
         };
         // create a `tenant/{tenant_id}/timelines/basebackup-{timeline_id}.{TEMP_FILE_SUFFIX}/`
         // temporary directory for basebackup files for the given timeline.
-        let initdb_path = path_with_suffix_extension(
+        let pgdata_path = path_with_suffix_extension(
             self.conf
                 .timelines_path(&self.tenant_id)
                 .join(format!("basebackup-{timeline_id}")),
@@ -2913,26 +2949,25 @@ impl Tenant {

         // an uninit mark was placed before, nothing else can access this timeline files
         // current initdb was not run yet, so remove whatever was left from the previous runs
-        if initdb_path.exists() {
-            fs::remove_dir_all(&initdb_path).with_context(|| {
-                format!("Failed to remove already existing initdb directory: {initdb_path}")
+        if pgdata_path.exists() {
+            fs::remove_dir_all(&pgdata_path).with_context(|| {
+                format!("Failed to remove already existing initdb directory: {pgdata_path}")
             })?;
         }
         // Init temporarily repo to get bootstrap data, this creates a directory in the `initdb_path` path
-        run_initdb(self.conf, &initdb_path, pg_version)?;
+        run_initdb(self.conf, &pgdata_path, pg_version, &self.cancel).await?;
         // this new directory is very temporary, set to remove it immediately after bootstrap, we don't need it
         scopeguard::defer! {
-            if let Err(e) = fs::remove_dir_all(&initdb_path) {
+            if let Err(e) = fs::remove_dir_all(&pgdata_path) {
                 // this is unlikely, but we will remove the directory on pageserver restart or another bootstrap call
-                error!("Failed to remove temporary initdb directory '{initdb_path}': {e}");
+                error!("Failed to remove temporary initdb directory '{pgdata_path}': {e}");
             }
         }
-        let pgdata_path = &initdb_path;
-        let pgdata_lsn = import_datadir::get_lsn_from_controlfile(pgdata_path)?.align();
+        let pgdata_lsn = import_datadir::get_lsn_from_controlfile(&pgdata_path)?.align();

         // Upload the created data dir to S3
         if let Some(storage) = &self.remote_storage {
-            let pgdata_zstd = import_datadir::create_tar_zst(pgdata_path).await?;
+            let pgdata_zstd = import_datadir::create_tar_zst(&pgdata_path).await?;
             let pgdata_zstd = Bytes::from(pgdata_zstd);
             backoff::retry(
                 || async {
@@ -2982,7 +3017,7 @@ impl Tenant {

         import_datadir::import_timeline_from_postgres_datadir(
             unfinished_timeline,
-            pgdata_path,
+            &pgdata_path,
             pgdata_lsn,
             ctx,
         )
```
```diff
@@ -3384,42 +3419,54 @@ fn rebase_directory(

 /// Create the cluster temporarily in 'initdbpath' directory inside the repository
 /// to get bootstrap data for timeline initialization.
-fn run_initdb(
+async fn run_initdb(
     conf: &'static PageServerConf,
     initdb_target_dir: &Utf8Path,
     pg_version: u32,
-) -> anyhow::Result<()> {
-    let initdb_bin_path = conf.pg_bin_dir(pg_version)?.join("initdb");
-    let initdb_lib_dir = conf.pg_lib_dir(pg_version)?;
+    cancel: &CancellationToken,
+) -> Result<(), InitdbError> {
+    let initdb_bin_path = conf
+        .pg_bin_dir(pg_version)
+        .map_err(InitdbError::Other)?
+        .join("initdb");
+    let initdb_lib_dir = conf.pg_lib_dir(pg_version).map_err(InitdbError::Other)?;
     info!(
         "running {} in {}, libdir: {}",
         initdb_bin_path, initdb_target_dir, initdb_lib_dir,
     );

-    let initdb_output = Command::new(&initdb_bin_path)
+    let _permit = INIT_DB_SEMAPHORE.acquire().await;
+
+    let initdb_command = tokio::process::Command::new(&initdb_bin_path)
         .args(["-D", initdb_target_dir.as_ref()])
         .args(["-U", &conf.superuser])
         .args(["-E", "utf8"])
         .arg("--no-instructions")
         // This is only used for a temporary installation that is deleted shortly after,
         // so no need to fsync it
         .arg("--no-sync")
         .env_clear()
         .env("LD_LIBRARY_PATH", &initdb_lib_dir)
         .env("DYLD_LIBRARY_PATH", &initdb_lib_dir)
-        .stdout(Stdio::null())
-        .output()
-        .with_context(|| {
-            format!(
-                "failed to execute {} at target dir {}",
-                initdb_bin_path, initdb_target_dir,
-            )
-        })?;
-    if !initdb_output.status.success() {
-        bail!(
-            "initdb failed: '{}'",
-            String::from_utf8_lossy(&initdb_output.stderr)
-        );
+        .stdout(Stdio::piped())
+        .stderr(Stdio::piped())
+        // If the `select!` below doesn't finish the `wait_with_output`,
+        // let the task get `wait()`ed for asynchronously by tokio.
+        // This means there is a slim chance we can go over the INIT_DB_SEMAPHORE.
+        // TODO: fix for this is non-trivial, see
+        // https://github.com/neondatabase/neon/pull/5921#pullrequestreview-1750858021
+        //
+        .kill_on_drop(true)
+        .spawn()?;
+
+    tokio::select! {
+        initdb_output = initdb_command.wait_with_output() => {
+            let initdb_output = initdb_output?;
+            if !initdb_output.status.success() {
+                return Err(InitdbError::Failed(initdb_output.status, initdb_output.stderr));
+            }
+        }
+        _ = cancel.cancelled() => {
+            return Err(InitdbError::Cancelled);
+        }
     }

     Ok(())
```
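The new `run_initdb` combines three things: a semaphore to bound concurrent initdb runs, piped output so failures can surface through `InitdbError::Failed`, and cancellation via `tokio::select!` with `kill_on_drop`. A stand-alone distillation of just the cancellation part, as a sketch (assumes the `tokio` and `tokio-util` crates and a Unix `sleep` binary; this is not the pageserver code itself):

```rust
use tokio::process::Command;
use tokio_util::sync::CancellationToken;

// Spawn a child process with kill_on_drop, then race its completion against a
// cancellation token. On cancellation the child handle is dropped and killed.
async fn run_until_cancelled(cancel: &CancellationToken) -> Result<(), String> {
    let child = Command::new("sleep")
        .arg("5")
        .kill_on_drop(true)
        .spawn()
        .map_err(|e| format!("spawn failed: {e}"))?;

    tokio::select! {
        out = child.wait_with_output() => {
            let out = out.map_err(|e| format!("wait failed: {e}"))?;
            if !out.status.success() {
                return Err(format!("command failed: {:?}", out.status));
            }
            Ok(())
        }
        _ = cancel.cancelled() => Err("cancelled".to_string()),
    }
}

#[tokio::main]
async fn main() {
    let cancel = CancellationToken::new();
    cancel.cancel();
    // The child cannot finish instantly, so the cancelled branch wins.
    assert_eq!(run_until_cancelled(&cancel).await, Err("cancelled".to_string()));
}
```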
```diff
@@ -3465,6 +3512,7 @@ pub async fn dump_layerfile_from_path(
 pub(crate) mod harness {
     use bytes::{Bytes, BytesMut};
     use once_cell::sync::OnceCell;
+    use pageserver_api::shard::ShardIndex;
     use std::fs;
     use std::sync::Arc;
     use utils::logging;
@@ -3531,6 +3579,7 @@ pub(crate) mod harness {
         pub tenant_conf: TenantConf,
         pub tenant_id: TenantId,
         pub generation: Generation,
+        pub shard: ShardIndex,
         pub remote_storage: GenericRemoteStorage,
         pub remote_fs_dir: Utf8PathBuf,
         pub deletion_queue: MockDeletionQueue,
@@ -3590,6 +3639,7 @@ pub(crate) mod harness {
             tenant_conf,
             tenant_id,
             generation: Generation::new(0xdeadbeef),
+            shard: ShardIndex::unsharded(),
             remote_storage,
             remote_fs_dir,
             deletion_queue,
```
@@ -10,6 +10,7 @@
|
||||
//!
|
||||
use anyhow::Context;
|
||||
use pageserver_api::models;
|
||||
use pageserver_api::shard::{ShardCount, ShardIdentity, ShardNumber, ShardStripeSize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::num::NonZeroU64;
|
||||
use std::time::Duration;
|
||||
@@ -88,6 +89,14 @@ pub(crate) struct LocationConf {
|
||||
/// The location-specific part of the configuration, describes the operating
|
||||
/// mode of this pageserver for this tenant.
|
||||
pub(crate) mode: LocationMode,
|
||||
|
||||
/// The detailed shard identity. This structure is already scoped within
|
||||
/// a TenantShardId, but we need the full ShardIdentity to enable calculating
|
||||
/// key->shard mappings.
|
||||
#[serde(default = "ShardIdentity::unsharded")]
|
||||
#[serde(skip_serializing_if = "ShardIdentity::is_unsharded")]
|
||||
pub(crate) shard: ShardIdentity,
|
||||
|
||||
/// The pan-cluster tenant configuration, the same on all locations
|
||||
pub(crate) tenant_conf: TenantConfOpt,
|
||||
}
|
||||
@@ -160,6 +169,8 @@ impl LocationConf {
|
||||
generation,
|
||||
attach_mode: AttachmentMode::Single,
|
||||
}),
|
||||
// Legacy configuration loads are always from tenants created before sharding existed.
|
||||
shard: ShardIdentity::unsharded(),
|
||||
tenant_conf,
|
||||
}
|
||||
}
|
||||
@@ -187,6 +198,7 @@ impl LocationConf {
|
||||
|
||||
fn get_generation(conf: &'_ models::LocationConfig) -> Result<Generation, anyhow::Error> {
|
||||
conf.generation
|
||||
.map(Generation::new)
|
||||
.ok_or_else(|| anyhow::anyhow!("Generation must be set when attaching"))
|
||||
}
|
||||
|
||||
@@ -226,7 +238,21 @@ impl LocationConf {
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self { mode, tenant_conf })
|
||||
let shard = if conf.shard_count == 0 {
|
||||
ShardIdentity::unsharded()
|
||||
} else {
|
||||
ShardIdentity::new(
|
||||
ShardNumber(conf.shard_number),
|
||||
ShardCount(conf.shard_count),
|
||||
ShardStripeSize(conf.shard_stripe_size),
|
||||
)?
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
shard,
|
||||
mode,
|
||||
tenant_conf,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,6 +267,7 @@ impl Default for LocationConf {
|
||||
attach_mode: AttachmentMode::Single,
|
||||
}),
|
||||
tenant_conf: TenantConfOpt::default(),
|
||||
shard: ShardIdentity::unsharded(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -361,17 +361,25 @@ impl DeleteTenantFlow {
|
||||
conf: &'static PageServerConf,
|
||||
remote_mark_exists: bool,
|
||||
tenant: &Tenant,
|
||||
) -> Option<DeletionGuard> {
|
||||
let tenant_id = tenant.tenant_id;
|
||||
|
||||
if remote_mark_exists || conf.tenant_deleted_mark_file_path(&tenant_id).exists() {
|
||||
) -> Result<Option<DeletionGuard>, DeleteTenantError> {
|
||||
let acquire = |t: &Tenant| {
|
||||
Some(
|
||||
Arc::clone(&tenant.delete_progress)
|
||||
Arc::clone(&t.delete_progress)
|
||||
.try_lock_owned()
|
||||
.expect("we're the only owner during init"),
|
||||
)
|
||||
};
|
||||
|
||||
if remote_mark_exists {
|
||||
return Ok(acquire(tenant));
|
||||
}
|
||||
|
||||
let tenant_id = tenant.tenant_id;
|
||||
// Check local mark first, if its there there is no need to go to s3 to check whether remote one exists.
|
||||
if conf.tenant_deleted_mark_file_path(&tenant_id).exists() {
|
||||
Ok(acquire(tenant))
|
||||
} else {
|
||||
None
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -188,6 +188,7 @@ use anyhow::Context;
|
||||
use camino::Utf8Path;
|
||||
use chrono::{NaiveDateTime, Utc};
|
||||
|
||||
use pageserver_api::shard::ShardIndex;
|
||||
use scopeguard::ScopeGuard;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
pub(crate) use upload::upload_initdb_dir;
|
||||
@@ -402,6 +403,11 @@ impl RemoteTimelineClient {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn get_shard_index(&self) -> ShardIndex {
|
||||
// TODO: carry this on the struct
|
||||
ShardIndex::unsharded()
|
||||
}
|
||||
|
||||
pub fn remote_consistent_lsn_projected(&self) -> Option<Lsn> {
|
||||
match &mut *self.upload_queue.lock().unwrap() {
|
||||
UploadQueue::Uninitialized => None,
|
||||
@@ -465,6 +471,7 @@ impl RemoteTimelineClient {
|
||||
&self.storage_impl,
|
||||
&self.tenant_id,
|
||||
&self.timeline_id,
|
||||
self.get_shard_index(),
|
||||
self.generation,
|
||||
cancel,
|
||||
)
|
||||
@@ -657,10 +664,10 @@ impl RemoteTimelineClient {
|
||||
let mut guard = self.upload_queue.lock().unwrap();
|
||||
let upload_queue = guard.initialized_mut()?;
|
||||
|
||||
let with_generations =
|
||||
let with_metadata =
|
||||
self.schedule_unlinking_of_layers_from_index_part0(upload_queue, names.iter().cloned());
|
||||
|
||||
self.schedule_deletion_of_unlinked0(upload_queue, with_generations);
|
||||
self.schedule_deletion_of_unlinked0(upload_queue, with_metadata);
|
||||
|
||||
// Launch the tasks immediately, if possible
|
||||
self.launch_queued_tasks(upload_queue);
|
||||
@@ -695,7 +702,7 @@ impl RemoteTimelineClient {
|
||||
self: &Arc<Self>,
|
||||
upload_queue: &mut UploadQueueInitialized,
|
||||
names: I,
|
||||
) -> Vec<(LayerFileName, Generation)>
|
||||
) -> Vec<(LayerFileName, LayerFileMetadata)>
|
||||
where
|
||||
I: IntoIterator<Item = LayerFileName>,
|
||||
{
|
||||
@@ -703,16 +710,17 @@ impl RemoteTimelineClient {
|
||||
// so we don't need update it. Just serialize it.
|
||||
let metadata = upload_queue.latest_metadata.clone();
|
||||
|
||||
// Decorate our list of names with each name's generation, dropping
|
||||
// names that are unexpectedly missing from our metadata.
|
||||
let with_generations: Vec<_> = names
|
||||
// Decorate our list of names with each name's metadata, dropping
|
||||
// names that are unexpectedly missing from our metadata. This metadata
|
||||
// is later used when physically deleting layers, to construct key paths.
|
||||
let with_metadata: Vec<_> = names
|
||||
.into_iter()
|
||||
.filter_map(|name| {
|
||||
let meta = upload_queue.latest_files.remove(&name);
|
||||
|
||||
if let Some(meta) = meta {
|
||||
upload_queue.latest_files_changes_since_metadata_upload_scheduled += 1;
|
||||
Some((name, meta.generation))
|
||||
Some((name, meta))
|
||||
} else {
|
||||
// This can only happen if we forgot to to schedule the file upload
|
||||
// before scheduling the delete. Log it because it is a rare/strange
|
||||
@@ -725,9 +733,10 @@ impl RemoteTimelineClient {
|
||||
.collect();
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
for (name, gen) in &with_generations {
|
||||
if let Some(unexpected) = upload_queue.dangling_files.insert(name.to_owned(), *gen) {
|
||||
if &unexpected == gen {
|
||||
for (name, metadata) in &with_metadata {
|
||||
let gen = metadata.generation;
|
||||
if let Some(unexpected) = upload_queue.dangling_files.insert(name.to_owned(), gen) {
|
||||
if unexpected == gen {
|
||||
tracing::error!("{name} was unlinked twice with same generation");
|
||||
} else {
|
||||
tracing::error!("{name} was unlinked twice with different generations {gen:?} and {unexpected:?}");
|
||||
@@ -742,14 +751,14 @@ impl RemoteTimelineClient {
|
||||
self.schedule_index_upload(upload_queue, metadata);
|
||||
}
|
||||
|
||||
with_generations
|
||||
with_metadata
|
||||
}
|
||||
|
||||
/// Schedules deletion for layer files which have previously been unlinked from the
/// `index_part.json` with [`Self::schedule_gc_update`] or [`Self::schedule_compaction_update`].
pub(crate) fn schedule_deletion_of_unlinked(
self: &Arc<Self>,
layers: Vec<(LayerFileName, Generation)>,
layers: Vec<(LayerFileName, LayerFileMetadata)>,
) -> anyhow::Result<()> {
let mut guard = self.upload_queue.lock().unwrap();
let upload_queue = guard.initialized_mut()?;
@@ -762,16 +771,22 @@ impl RemoteTimelineClient {
fn schedule_deletion_of_unlinked0(
self: &Arc<Self>,
upload_queue: &mut UploadQueueInitialized,
with_generations: Vec<(LayerFileName, Generation)>,
with_metadata: Vec<(LayerFileName, LayerFileMetadata)>,
) {
for (name, gen) in &with_generations {
info!("scheduling deletion of layer {}{}", name, gen.get_suffix());
for (name, meta) in &with_metadata {
info!(
"scheduling deletion of layer {}{} (shard {})",
name,
meta.generation.get_suffix(),
meta.shard
);
}

#[cfg(feature = "testing")]
for (name, gen) in &with_generations {
for (name, meta) in &with_metadata {
let gen = meta.generation;
match upload_queue.dangling_files.remove(name) {
Some(same) if &same == gen => { /* expected */ }
Some(same) if same == gen => { /* expected */ }
Some(other) => {
tracing::error!("{name} was unlinked with {other:?} but deleted with {gen:?}");
}
@@ -783,7 +798,7 @@ impl RemoteTimelineClient {

// schedule the actual deletions
let op = UploadOp::Delete(Delete {
layers: with_generations,
layers: with_metadata,
});
self.calls_unfinished_metric_begin(&op);
upload_queue.queued_operations.push_back(op);
@@ -904,6 +919,7 @@ impl RemoteTimelineClient {
|
||||
&self.storage_impl,
|
||||
&self.tenant_id,
|
||||
&self.timeline_id,
|
||||
self.get_shard_index(),
|
||||
self.generation,
|
||||
&index_part_with_deleted_at,
|
||||
)
|
||||
@@ -962,6 +978,7 @@ impl RemoteTimelineClient {
|
||||
remote_layer_path(
|
||||
&self.tenant_id,
|
||||
&self.timeline_id,
|
||||
meta.shard,
|
||||
&file_name,
|
||||
meta.generation,
|
||||
)
|
||||
@@ -1010,7 +1027,12 @@ impl RemoteTimelineClient {
|
||||
.unwrap_or(
|
||||
// No generation-suffixed indices, assume we are dealing with
|
||||
// a legacy index.
|
||||
remote_index_path(&self.tenant_id, &self.timeline_id, Generation::none()),
|
||||
remote_index_path(
|
||||
&self.tenant_id,
|
||||
&self.timeline_id,
|
||||
self.get_shard_index(),
|
||||
Generation::none(),
|
||||
),
|
||||
);
|
||||
|
||||
let remaining_layers: Vec<RemotePath> = remaining
|
||||
@@ -1219,6 +1241,7 @@ impl RemoteTimelineClient {
|
||||
&self.storage_impl,
|
||||
&self.tenant_id,
|
||||
&self.timeline_id,
|
||||
self.get_shard_index(),
|
||||
self.generation,
|
||||
index_part,
|
||||
)
|
||||
@@ -1527,12 +1550,14 @@ pub fn remote_timeline_path(tenant_id: &TenantId, timeline_id: &TimelineId) -> R
pub fn remote_layer_path(
tenant_id: &TenantId,
timeline_id: &TimelineId,
shard: ShardIndex,
layer_file_name: &LayerFileName,
generation: Generation,
) -> RemotePath {
// Generation-aware key format
let path = format!(
"tenants/{tenant_id}/{TIMELINES_SEGMENT_NAME}/{timeline_id}/{0}{1}",
"tenants/{tenant_id}{0}/{TIMELINES_SEGMENT_NAME}/{timeline_id}/{1}{2}",
shard.get_suffix(),
layer_file_name.file_name(),
generation.get_suffix()
);
@@ -1550,10 +1575,12 @@ pub fn remote_initdb_archive_path(tenant_id: &TenantId, timeline_id: &TimelineId
pub fn remote_index_path(
tenant_id: &TenantId,
timeline_id: &TimelineId,
shard: ShardIndex,
generation: Generation,
) -> RemotePath {
RemotePath::from_string(&format!(
"tenants/{tenant_id}/{TIMELINES_SEGMENT_NAME}/{timeline_id}/{0}{1}",
"tenants/{tenant_id}{0}/{TIMELINES_SEGMENT_NAME}/{timeline_id}/{1}{2}",
shard.get_suffix(),
IndexPart::FILE_NAME,
generation.get_suffix()
))
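Editor's note: to make the key layout change above concrete, the shard suffix now lands on the tenant path segment while the generation suffix stays at the end of the object name. The sketch below is purely illustrative; the identifiers and the exact `get_suffix()` encodings are placeholders, not values taken from this change.

// Illustration only: ids and suffix encodings are made-up placeholders.
fn example_remote_keys() -> (String, String) {
    let tenant_id = "aaaabbbbccccddddeeeeffff00001111"; // placeholder TenantId
    let timeline_id = "11110000ffffeeeeddddccccbbbbaaaa"; // placeholder TimelineId
    let shard_suffix = "-0102"; // assumed ShardIndex::get_suffix() output
    let generation_suffix = "-00000005"; // assumed Generation::get_suffix() output

    // Layer object key in the new shard- and generation-aware layout:
    let layer_key = format!(
        "tenants/{tenant_id}{shard_suffix}/timelines/{timeline_id}/a_layer_file_name{generation_suffix}"
    );
    // index_part.json key in the same layout:
    let index_key = format!(
        "tenants/{tenant_id}{shard_suffix}/timelines/{timeline_id}/index_part.json{generation_suffix}"
    );
    (layer_key, index_key)
}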
@@ -1778,6 +1805,7 @@ mod tests {
|
||||
println!("remote_timeline_dir: {remote_timeline_dir}");
|
||||
|
||||
let generation = harness.generation;
|
||||
let shard = harness.shard;
|
||||
|
||||
// Create a couple of dummy files, schedule upload for them
|
||||
|
||||
@@ -1794,7 +1822,7 @@ mod tests {
|
||||
harness.conf,
|
||||
&timeline,
|
||||
name,
|
||||
LayerFileMetadata::new(contents.len() as u64, generation),
|
||||
LayerFileMetadata::new(contents.len() as u64, generation, shard),
|
||||
)
|
||||
}).collect::<Vec<_>>();
|
||||
|
||||
@@ -1943,7 +1971,7 @@ mod tests {
|
||||
harness.conf,
|
||||
&timeline,
|
||||
layer_file_name_1.clone(),
|
||||
LayerFileMetadata::new(content_1.len() as u64, harness.generation),
|
||||
LayerFileMetadata::new(content_1.len() as u64, harness.generation, harness.shard),
|
||||
);
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
@@ -2008,7 +2036,11 @@ mod tests {
|
||||
assert_eq!(actual_c, expected_c);
|
||||
}
|
||||
|
||||
async fn inject_index_part(test_state: &TestSetup, generation: Generation) -> IndexPart {
|
||||
async fn inject_index_part(
|
||||
test_state: &TestSetup,
|
||||
generation: Generation,
|
||||
shard: ShardIndex,
|
||||
) -> IndexPart {
|
||||
// An empty IndexPart, just sufficient to ensure deserialization will succeed
|
||||
let example_metadata = TimelineMetadata::example();
|
||||
let example_index_part = IndexPart::new(
|
||||
@@ -2029,7 +2061,13 @@ mod tests {
|
||||
std::fs::create_dir_all(remote_timeline_dir).expect("creating test dir should work");
|
||||
|
||||
let index_path = test_state.harness.remote_fs_dir.join(
|
||||
remote_index_path(&test_state.harness.tenant_id, &TIMELINE_ID, generation).get_path(),
|
||||
remote_index_path(
|
||||
&test_state.harness.tenant_id,
|
||||
&TIMELINE_ID,
|
||||
shard,
|
||||
generation,
|
||||
)
|
||||
.get_path(),
|
||||
);
|
||||
eprintln!("Writing {index_path}");
|
||||
std::fs::write(&index_path, index_part_bytes).unwrap();
|
||||
@@ -2066,7 +2104,12 @@ mod tests {
|
||||
|
||||
// Simple case: we are in generation N, load the index from generation N - 1
|
||||
let generation_n = 5;
|
||||
let injected = inject_index_part(&test_state, Generation::new(generation_n - 1)).await;
|
||||
let injected = inject_index_part(
|
||||
&test_state,
|
||||
Generation::new(generation_n - 1),
|
||||
ShardIndex::unsharded(),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_got_index_part(&test_state, Generation::new(generation_n), &injected).await;
|
||||
|
||||
@@ -2084,22 +2127,34 @@ mod tests {
|
||||
|
||||
// A generation-less IndexPart exists in the bucket, we should find it
|
||||
let generation_n = 5;
|
||||
let injected_none = inject_index_part(&test_state, Generation::none()).await;
|
||||
let injected_none =
|
||||
inject_index_part(&test_state, Generation::none(), ShardIndex::unsharded()).await;
|
||||
assert_got_index_part(&test_state, Generation::new(generation_n), &injected_none).await;
|
||||
|
||||
// If a more recent-than-none generation exists, we should prefer to load that
|
||||
let injected_1 = inject_index_part(&test_state, Generation::new(1)).await;
|
||||
let injected_1 =
|
||||
inject_index_part(&test_state, Generation::new(1), ShardIndex::unsharded()).await;
|
||||
assert_got_index_part(&test_state, Generation::new(generation_n), &injected_1).await;
|
||||
|
||||
// If a more-recent-than-me generation exists, we should ignore it.
|
||||
let _injected_10 = inject_index_part(&test_state, Generation::new(10)).await;
|
||||
let _injected_10 =
|
||||
inject_index_part(&test_state, Generation::new(10), ShardIndex::unsharded()).await;
|
||||
assert_got_index_part(&test_state, Generation::new(generation_n), &injected_1).await;
|
||||
|
||||
// If a directly previous generation exists, _and_ an index exists in my own
|
||||
// generation, I should prefer my own generation.
|
||||
let _injected_prev =
|
||||
inject_index_part(&test_state, Generation::new(generation_n - 1)).await;
|
||||
let injected_current = inject_index_part(&test_state, Generation::new(generation_n)).await;
|
||||
let _injected_prev = inject_index_part(
|
||||
&test_state,
|
||||
Generation::new(generation_n - 1),
|
||||
ShardIndex::unsharded(),
|
||||
)
|
||||
.await;
|
||||
let injected_current = inject_index_part(
|
||||
&test_state,
|
||||
Generation::new(generation_n),
|
||||
ShardIndex::unsharded(),
|
||||
)
|
||||
.await;
|
||||
assert_got_index_part(
|
||||
&test_state,
|
||||
Generation::new(generation_n),
|
||||
|
||||
@@ -9,6 +9,7 @@ use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use camino::Utf8Path;
|
||||
use pageserver_api::shard::ShardIndex;
|
||||
use tokio::fs;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -53,6 +54,7 @@ pub async fn download_layer_file<'a>(
|
||||
let remote_path = remote_layer_path(
|
||||
&tenant_id,
|
||||
&timeline_id,
|
||||
layer_metadata.shard,
|
||||
layer_file_name,
|
||||
layer_metadata.generation,
|
||||
);
|
||||
@@ -213,10 +215,11 @@ async fn do_download_index_part(
|
||||
storage: &GenericRemoteStorage,
|
||||
tenant_id: &TenantId,
|
||||
timeline_id: &TimelineId,
|
||||
shard: ShardIndex,
|
||||
index_generation: Generation,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<IndexPart, DownloadError> {
|
||||
let remote_path = remote_index_path(tenant_id, timeline_id, index_generation);
|
||||
let remote_path = remote_index_path(tenant_id, timeline_id, shard, index_generation);
|
||||
|
||||
let index_part_bytes = download_retry_forever(
|
||||
|| async {
|
||||
@@ -254,6 +257,7 @@ pub(super) async fn download_index_part(
|
||||
storage: &GenericRemoteStorage,
|
||||
tenant_id: &TenantId,
|
||||
timeline_id: &TimelineId,
|
||||
shard: ShardIndex,
|
||||
my_generation: Generation,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<IndexPart, DownloadError> {
|
||||
@@ -261,8 +265,15 @@ pub(super) async fn download_index_part(
|
||||
|
||||
if my_generation.is_none() {
|
||||
// Operating without generations: just fetch the generation-less path
|
||||
return do_download_index_part(storage, tenant_id, timeline_id, my_generation, cancel)
|
||||
.await;
|
||||
return do_download_index_part(
|
||||
storage,
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
shard,
|
||||
my_generation,
|
||||
cancel,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Stale case: If we were intentionally attached in a stale generation, there may already be a remote
|
||||
@@ -273,6 +284,7 @@ pub(super) async fn download_index_part(
|
||||
storage,
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
shard,
|
||||
my_generation,
|
||||
cancel.clone(),
|
||||
)
|
||||
@@ -300,6 +312,7 @@ pub(super) async fn download_index_part(
|
||||
storage,
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
shard,
|
||||
my_generation.previous(),
|
||||
cancel.clone(),
|
||||
)
|
||||
@@ -320,8 +333,9 @@ pub(super) async fn download_index_part(
}

// General case/fallback: if there is no index at my_generation or prev_generation, then list all index_part.json
// objects, and select the highest one with a generation <= my_generation.
let index_prefix = remote_index_path(tenant_id, timeline_id, Generation::none());
// objects, and select the highest one with a generation <= my_generation. Constructing the prefix is equivalent
// to constructing a full index path with no generation, because the generation is a suffix.
let index_prefix = remote_index_path(tenant_id, timeline_id, shard, Generation::none());
let indices = backoff::retry(
|| async { storage.list_files(Some(&index_prefix)).await },
|_| false,
@@ -347,14 +361,21 @@ pub(super) async fn download_index_part(
match max_previous_generation {
Some(g) => {
tracing::debug!("Found index_part in generation {g:?}");
do_download_index_part(storage, tenant_id, timeline_id, g, cancel).await
do_download_index_part(storage, tenant_id, timeline_id, shard, g, cancel).await
}
None => {
// Migration from legacy pre-generation state: we have a generation but no prior
// attached pageservers did. Try to load from a no-generation path.
tracing::info!("No index_part.json* found");
do_download_index_part(storage, tenant_id, timeline_id, Generation::none(), cancel)
.await
do_download_index_part(
storage,
tenant_id,
timeline_id,
shard,
Generation::none(),
cancel,
)
.await
}
}
}

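Editor's note: the fallback above reduces to listing every `index_part.json*` object under the prefix, parsing the generation suffix out of each key, and keeping the newest generation that is not newer than our own. A minimal sketch of that selection step, with generations modelled as plain `u32` and a hypothetical `parse_generation_suffix` helper (the real code derives this from its `Generation` type and suffix format, which are not shown here):

/// Hypothetical helper: pull the generation out of a listed key such as
/// ".../index_part.json-00000005"; the suffix encoding here is an assumption.
fn parse_generation_suffix(key: &str) -> Option<u32> {
    let (_, suffix) = key.rsplit_once("index_part.json-")?;
    u32::from_str_radix(suffix, 16).ok()
}

/// Keep the newest index generation that is not newer than our own.
fn max_usable_generation(keys: &[String], my_generation: u32) -> Option<u32> {
    keys.iter()
        .filter_map(|k| parse_generation_suffix(k))
        .filter(|&g| g <= my_generation)
        .max()
}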
@@ -12,6 +12,7 @@ use crate::tenant::metadata::TimelineMetadata;
use crate::tenant::storage_layer::LayerFileName;
use crate::tenant::upload_queue::UploadQueueInitialized;
use crate::tenant::Generation;
use pageserver_api::shard::ShardIndex;

use utils::lsn::Lsn;

@@ -25,6 +26,8 @@ pub struct LayerFileMetadata {
file_size: u64,

pub(crate) generation: Generation,

pub(crate) shard: ShardIndex,
}

impl From<&'_ IndexLayerMetadata> for LayerFileMetadata {
@@ -32,15 +35,17 @@ impl From<&'_ IndexLayerMetadata> for LayerFileMetadata {
LayerFileMetadata {
file_size: other.file_size,
generation: other.generation,
shard: other.shard,
}
}
}

impl LayerFileMetadata {
pub fn new(file_size: u64, generation: Generation) -> Self {
pub fn new(file_size: u64, generation: Generation, shard: ShardIndex) -> Self {
LayerFileMetadata {
file_size,
generation,
shard,
}
}

@@ -161,6 +166,10 @@ pub struct IndexLayerMetadata {
#[serde(default = "Generation::none")]
#[serde(skip_serializing_if = "Generation::is_none")]
pub generation: Generation,

#[serde(default = "ShardIndex::unsharded")]
#[serde(skip_serializing_if = "ShardIndex::is_unsharded")]
pub shard: ShardIndex,
}

impl From<LayerFileMetadata> for IndexLayerMetadata {
@@ -168,6 +177,7 @@ impl From<LayerFileMetadata> for IndexLayerMetadata {
IndexLayerMetadata {
file_size: other.file_size,
generation: other.generation,
shard: other.shard,
}
}
}
@@ -195,13 +205,15 @@ mod tests {
|
||||
layer_metadata: HashMap::from([
|
||||
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
|
||||
file_size: 25600000,
|
||||
generation: Generation::none()
|
||||
generation: Generation::none(),
|
||||
shard: ShardIndex::unsharded()
|
||||
}),
|
||||
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
|
||||
// serde_json should always parse this but this might be a double with jq for
|
||||
// example.
|
||||
file_size: 9007199254741001,
|
||||
generation: Generation::none()
|
||||
generation: Generation::none(),
|
||||
shard: ShardIndex::unsharded()
|
||||
})
|
||||
]),
|
||||
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
|
||||
@@ -233,13 +245,15 @@ mod tests {
|
||||
layer_metadata: HashMap::from([
|
||||
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
|
||||
file_size: 25600000,
|
||||
generation: Generation::none()
|
||||
generation: Generation::none(),
|
||||
shard: ShardIndex::unsharded()
|
||||
}),
|
||||
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
|
||||
// serde_json should always parse this but this might be a double with jq for
|
||||
// example.
|
||||
file_size: 9007199254741001,
|
||||
generation: Generation::none()
|
||||
generation: Generation::none(),
|
||||
shard: ShardIndex::unsharded()
|
||||
})
|
||||
]),
|
||||
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
|
||||
@@ -272,13 +286,15 @@ mod tests {
|
||||
layer_metadata: HashMap::from([
|
||||
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
|
||||
file_size: 25600000,
|
||||
generation: Generation::none()
|
||||
generation: Generation::none(),
|
||||
shard: ShardIndex::unsharded()
|
||||
}),
|
||||
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
|
||||
// serde_json should always parse this but this might be a double with jq for
|
||||
// example.
|
||||
file_size: 9007199254741001,
|
||||
generation: Generation::none()
|
||||
generation: Generation::none(),
|
||||
shard: ShardIndex::unsharded()
|
||||
})
|
||||
]),
|
||||
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
|
||||
@@ -354,19 +370,21 @@ mod tests {
|
||||
layer_metadata: HashMap::from([
|
||||
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
|
||||
file_size: 25600000,
|
||||
generation: Generation::none()
|
||||
generation: Generation::none(),
|
||||
shard: ShardIndex::unsharded()
|
||||
}),
|
||||
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
|
||||
// serde_json should always parse this but this might be a double with jq for
|
||||
// example.
|
||||
file_size: 9007199254741001,
|
||||
generation: Generation::none()
|
||||
generation: Generation::none(),
|
||||
shard: ShardIndex::unsharded()
|
||||
})
|
||||
]),
|
||||
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
|
||||
metadata: TimelineMetadata::from_bytes(&[113,11,159,210,0,54,0,4,0,0,0,0,1,105,96,232,1,0,0,0,0,1,105,96,112,0,0,0,0,0,0,0,0,0,0,0,0,0,1,105,96,112,0,0,0,0,1,105,96,112,0,0,0,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]).unwrap(),
|
||||
deleted_at: Some(chrono::NaiveDateTime::parse_from_str(
|
||||
"2023-07-31T09:00:00.123000000", "%Y-%m-%dT%H:%M:%S.%f").unwrap())
|
||||
"2023-07-31T09:00:00.123000000", "%Y-%m-%dT%H:%M:%S.%f").unwrap()),
|
||||
};
|
||||
|
||||
let part = IndexPart::from_s3_bytes(example.as_bytes()).unwrap();
|
||||
|
||||
@@ -4,6 +4,7 @@ use anyhow::{bail, Context};
|
||||
use bytes::Bytes;
|
||||
use camino::Utf8Path;
|
||||
use fail::fail_point;
|
||||
use pageserver_api::shard::ShardIndex;
|
||||
use std::io::ErrorKind;
|
||||
use tokio::fs;
|
||||
|
||||
@@ -26,6 +27,7 @@ pub(super) async fn upload_index_part<'a>(
|
||||
storage: &'a GenericRemoteStorage,
|
||||
tenant_id: &TenantId,
|
||||
timeline_id: &TimelineId,
|
||||
shard: ShardIndex,
|
||||
generation: Generation,
|
||||
index_part: &'a IndexPart,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -42,7 +44,7 @@ pub(super) async fn upload_index_part<'a>(
|
||||
let index_part_size = index_part_bytes.len();
|
||||
let index_part_bytes = tokio::io::BufReader::new(std::io::Cursor::new(index_part_bytes));
|
||||
|
||||
let remote_path = remote_index_path(tenant_id, timeline_id, generation);
|
||||
let remote_path = remote_index_path(tenant_id, timeline_id, shard, generation);
|
||||
storage
|
||||
.upload_storage_object(Box::new(index_part_bytes), index_part_size, &remote_path)
|
||||
.await
|
||||
|
||||
@@ -3,6 +3,7 @@ use camino::{Utf8Path, Utf8PathBuf};
|
||||
use pageserver_api::models::{
|
||||
HistoricLayerInfo, LayerAccessKind, LayerResidenceEventReason, LayerResidenceStatus,
|
||||
};
|
||||
use pageserver_api::shard::ShardIndex;
|
||||
use std::ops::Range;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Weak};
|
||||
@@ -96,6 +97,7 @@ impl Layer {
|
||||
desc,
|
||||
None,
|
||||
metadata.generation,
|
||||
metadata.shard,
|
||||
)));
|
||||
|
||||
debug_assert!(owner.0.needs_download_blocking().unwrap().is_some());
|
||||
@@ -136,6 +138,7 @@ impl Layer {
|
||||
desc,
|
||||
Some(inner),
|
||||
metadata.generation,
|
||||
metadata.shard,
|
||||
)
|
||||
}));
|
||||
|
||||
@@ -179,6 +182,7 @@ impl Layer {
|
||||
desc,
|
||||
Some(inner),
|
||||
timeline.generation,
|
||||
timeline.get_shard_index(),
|
||||
)
|
||||
}));
|
||||
|
||||
@@ -426,6 +430,15 @@ struct LayerInner {
|
||||
/// For loaded layers (resident or evicted) this comes from [`LayerFileMetadata::generation`],
|
||||
/// for created layers from [`Timeline::generation`].
|
||||
generation: Generation,
|
||||
|
||||
/// The shard of this Layer.
|
||||
///
|
||||
/// For layers created in this process, this will always be the [`ShardIndex`] of the
|
||||
/// current `ShardIdentity`` (TODO: add link once it's introduced).
|
||||
///
|
||||
/// For loaded layers, this may be some other value if the tenant has undergone
|
||||
/// a shard split since the layer was originally written.
|
||||
shard: ShardIndex,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LayerInner {
|
||||
@@ -459,9 +472,9 @@ impl Drop for LayerInner {
|
||||
|
||||
let path = std::mem::take(&mut self.path);
|
||||
let file_name = self.layer_desc().filename();
|
||||
let gen = self.generation;
|
||||
let file_size = self.layer_desc().file_size;
|
||||
let timeline = self.timeline.clone();
|
||||
let meta = self.metadata();
|
||||
|
||||
crate::task_mgr::BACKGROUND_RUNTIME.spawn_blocking(move || {
|
||||
let _g = span.entered();
|
||||
@@ -489,7 +502,7 @@ impl Drop for LayerInner {
|
||||
timeline.metrics.resident_physical_size_sub(file_size);
|
||||
}
|
||||
if let Some(remote_client) = timeline.remote_client.as_ref() {
|
||||
let res = remote_client.schedule_deletion_of_unlinked(vec![(file_name, gen)]);
|
||||
let res = remote_client.schedule_deletion_of_unlinked(vec![(file_name, meta)]);
|
||||
|
||||
if let Err(e) = res {
|
||||
// test_timeline_deletion_with_files_stuck_in_upload_queue is good at
|
||||
@@ -523,6 +536,7 @@ impl LayerInner {
|
||||
desc: PersistentLayerDesc,
|
||||
downloaded: Option<Arc<DownloadedLayer>>,
|
||||
generation: Generation,
|
||||
shard: ShardIndex,
|
||||
) -> Self {
|
||||
let path = conf
|
||||
.timeline_path(&timeline.tenant_id, &timeline.timeline_id)
|
||||
@@ -550,6 +564,7 @@ impl LayerInner {
|
||||
status: tokio::sync::broadcast::channel(1).0,
|
||||
consecutive_failures: AtomicUsize::new(0),
|
||||
generation,
|
||||
shard,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1077,7 +1092,7 @@ impl LayerInner {
|
||||
}
|
||||
|
||||
fn metadata(&self) -> LayerFileMetadata {
|
||||
LayerFileMetadata::new(self.desc.file_size, self.generation)
|
||||
LayerFileMetadata::new(self.desc.file_size, self.generation, self.shard)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ pub(crate) enum BackgroundLoopKind {
|
||||
Eviction,
|
||||
ConsumptionMetricsCollectMetrics,
|
||||
ConsumptionMetricsSyntheticSizeWorker,
|
||||
InitialLogicalSizeCalculation,
|
||||
}
|
||||
|
||||
impl BackgroundLoopKind {
|
||||
|
||||
@@ -62,6 +62,7 @@ use crate::pgdatadir_mapping::{is_rel_fsm_block_key, is_rel_vm_block_key};
|
||||
use crate::pgdatadir_mapping::{BlockNumber, CalculateLogicalSizeError};
|
||||
use crate::tenant::config::{EvictionPolicy, TenantConfOpt};
|
||||
use pageserver_api::reltag::RelTag;
|
||||
use pageserver_api::shard::ShardIndex;
|
||||
|
||||
use postgres_connection::PgConnectionConfig;
|
||||
use postgres_ffi::to_pg_timestamp;
|
||||
@@ -1597,6 +1598,7 @@ impl Timeline {
|
||||
|
||||
// Copy to move into the task we're about to spawn
|
||||
let generation = self.generation;
|
||||
let shard = self.get_shard_index();
|
||||
let this = self.myself.upgrade().expect("&self method holds the arc");
|
||||
|
||||
let (loaded_layers, needs_cleanup, total_physical_size) = tokio::task::spawn_blocking({
|
||||
@@ -1645,6 +1647,7 @@ impl Timeline {
|
||||
index_part.as_ref(),
|
||||
disk_consistent_lsn,
|
||||
generation,
|
||||
shard,
|
||||
);
|
||||
|
||||
let mut loaded_layers = Vec::new();
|
||||
@@ -1822,6 +1825,29 @@ impl Timeline {
|
||||
// delay will be terminated by a timeout regardless.
|
||||
let _completion = { self_clone.initial_logical_size_attempt.lock().expect("unexpected initial_logical_size_attempt poisoned").take() };
|
||||
|
||||
// In prod, initial logical size calucalation is spawned either by
|
||||
// WalReceiverConnectionHandler if the timeline is active according to storage broker,
|
||||
// or by the first consumption metrics worker (MetricsCollection).
|
||||
// The latter runs every `metric_collection_interval` and checkpoints to disk, i.e.,
|
||||
// if pageserver gets restarted, the consumption metrics worker will resume waiting
|
||||
// for the correct remaining time, as if the pageserver had not been restarted.
|
||||
//
|
||||
// FIXME: with the current code, walreceiver requests would also hit this semaphore
|
||||
// and get queued behind other background operations. That's bad because walreceiver_connection
|
||||
// will push the not-precise value as `current_timeline_size` in the `PageserverFeedback`
|
||||
// while this calculation is stuck.
|
||||
// We need some way to priority-boost the initial size calculation if walreceiver is asking.
|
||||
// Or, should we maybe revisit the use of logical size in `PageserverFeedback`?
|
||||
// It seems broken the way it is.
|
||||
//
|
||||
// Example query to show different causes of initial size calculation spawning:
|
||||
//
|
||||
// https://neonprod.grafana.net/explore?panes=%7B%22wSx%22:%7B%22datasource%22:%22grafanacloud-logs%22,%22queries%22:%5B%7B%22refId%22:%22A%22,%22expr%22:%22sum%20by%20%28task_kind%29%20%28count_over_time%28%7Bneon_service%3D%5C%22pageserver%5C%22,%20neon_region%3D%5C%22us-west-2%5C%22%7D%20%7C%3D%20%60logical%20size%20computation%20from%20context%20of%20task%20kind%60%20%7C%20regexp%20%60logical%20size%20computation%20from%20context%20of%20task%20kind%20%28%3FP%3Ctask_kind%3E.%2A%29%60%20%5B1m%5D%29%29%22,%22queryType%22:%22range%22,%22datasource%22:%7B%22type%22:%22loki%22,%22uid%22:%22grafanacloud-logs%22%7D,%22editorMode%22:%22code%22,%22step%22:%221m%22%7D%5D,%22range%22:%7B%22from%22:%221700637500615%22,%22to%22:%221700639648743%22%7D%7D%7D&schemaVersion=1&orgId=1
|
||||
let _permit = match crate::tenant::tasks::concurrent_background_tasks_rate_limit(BackgroundLoopKind::InitialLogicalSizeCalculation,&background_ctx, &cancel).await {
|
||||
Ok(permit) => permit,
|
||||
Err(RateLimitError::Cancelled) => return Ok(()),
|
||||
};
|
||||
|
||||
let calculated_size = match self_clone
|
||||
.logical_size_calculation_task(lsn, LogicalSizeCalculationCause::Initial, &background_ctx)
|
||||
.await
|
||||
@@ -4364,6 +4390,11 @@ impl Timeline {
|
||||
resident_layers,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_shard_index(&self) -> ShardIndex {
|
||||
// TODO: carry this on the struct
|
||||
ShardIndex::unsharded()
|
||||
}
|
||||
}
|
||||
|
||||
type TraversalPathItem = (
|
||||
|
||||
@@ -13,6 +13,7 @@ use crate::{
|
||||
};
|
||||
use anyhow::Context;
|
||||
use camino::Utf8Path;
|
||||
use pageserver_api::shard::ShardIndex;
|
||||
use std::{collections::HashMap, str::FromStr};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
@@ -107,6 +108,7 @@ pub(super) fn reconcile(
|
||||
index_part: Option<&IndexPart>,
|
||||
disk_consistent_lsn: Lsn,
|
||||
generation: Generation,
|
||||
shard: ShardIndex,
|
||||
) -> Vec<(LayerFileName, Result<Decision, DismissedLayer>)> {
|
||||
use Decision::*;
|
||||
|
||||
@@ -118,10 +120,13 @@ pub(super) fn reconcile(
|
||||
.map(|(name, file_size)| {
|
||||
(
|
||||
name,
|
||||
// The generation here will be corrected to match IndexPart in the merge below, unless
|
||||
// The generation and shard here will be corrected to match IndexPart in the merge below, unless
|
||||
// it is not in IndexPart, in which case using our current generation makes sense
|
||||
// because it will be uploaded in this generation.
|
||||
(Some(LayerFileMetadata::new(file_size, generation)), None),
|
||||
(
|
||||
Some(LayerFileMetadata::new(file_size, generation, shard)),
|
||||
None,
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect::<Collected>();
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use super::storage_layer::LayerFileName;
|
||||
use super::storage_layer::ResidentLayer;
|
||||
use super::Generation;
|
||||
use crate::tenant::metadata::TimelineMetadata;
|
||||
use crate::tenant::remote_timeline_client::index::IndexPart;
|
||||
use crate::tenant::remote_timeline_client::index::LayerFileMetadata;
|
||||
@@ -15,6 +14,9 @@ use utils::lsn::AtomicLsn;
|
||||
use std::sync::atomic::AtomicU32;
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
use utils::generation::Generation;
|
||||
|
||||
// clippy warns that Uninitialized is much smaller than Initialized, which wastes
|
||||
// memory for Uninitialized variants. Doesn't matter in practice, there are not
|
||||
// that many upload queues in a running pageserver, and most of them are initialized
|
||||
@@ -232,7 +234,7 @@ pub(crate) struct UploadTask {
|
||||
/// for timeline deletion, which skips this queue and goes directly to DeletionQueue.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Delete {
|
||||
pub(crate) layers: Vec<(LayerFileName, Generation)>,
|
||||
pub(crate) layers: Vec<(LayerFileName, LayerFileMetadata)>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -21,6 +21,7 @@
#include "storage/buf_internals.h"
#include "storage/lwlock.h"
#include "storage/ipc.h"
#include "storage/pg_shmem.h"
#include "c.h"
#include "postmaster/interrupt.h"

@@ -87,6 +88,12 @@ bool (*old_redo_read_buffer_filter) (XLogReaderState *record, uint8 block_id) =
static bool pageserver_flush(void);
static void pageserver_disconnect(void);

static bool
PagestoreShmemIsValid()
{
return pagestore_shared && UsedShmemSegAddr;
}

static bool
CheckPageserverConnstring(char **newval, void **extra, GucSource source)
{
@@ -96,7 +103,7 @@ CheckPageserverConnstring(char **newval, void **extra, GucSource source)
static void
AssignPageserverConnstring(const char *newval, void *extra)
{
if(!pagestore_shared)
if(!PagestoreShmemIsValid())
return;
LWLockAcquire(pagestore_shared->lock, LW_EXCLUSIVE);
strlcpy(pagestore_shared->pageserver_connstring, newval, MAX_PAGESERVER_CONNSTRING_SIZE);
@@ -107,7 +114,7 @@ AssignPageserverConnstring(const char *newval, void *extra)
static bool
CheckConnstringUpdated()
{
if(!pagestore_shared)
if(!PagestoreShmemIsValid())
return false;
return pagestore_local_counter < pg_atomic_read_u64(&pagestore_shared->update_counter);
}
@@ -115,7 +122,7 @@ CheckConnstringUpdated()
static void
ReloadConnstring()
{
if(!pagestore_shared)
if(!PagestoreShmemIsValid())
return;
LWLockAcquire(pagestore_shared->lock, LW_SHARED);
strlcpy(local_pageserver_connstring, pagestore_shared->pageserver_connstring, sizeof(local_pageserver_connstring));
|
||||
@@ -2,3 +2,4 @@
|
||||
comment = 'cloud storage for PostgreSQL'
|
||||
default_version = '1.1'
|
||||
module_pathname = '$libdir/neon'
|
||||
relocatable = true
|
||||
|
||||
@@ -76,3 +76,4 @@ tokio-util.workspace = true
|
||||
rcgen.workspace = true
|
||||
rstest.workspace = true
|
||||
tokio-postgres-rustls.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
|
||||
@@ -6,6 +6,7 @@ pub use link::LinkAuthError;
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
|
||||
use crate::proxy::{handle_try_wake, retry_after, LatencyTimer};
|
||||
use crate::stream::Stream;
|
||||
use crate::{
|
||||
auth::{self, ClientCredentials},
|
||||
config::AuthenticationConfig,
|
||||
@@ -131,7 +132,7 @@ async fn auth_quirks_creds(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
@@ -165,7 +166,7 @@ async fn auth_quirks(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
@@ -241,7 +242,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
pub async fn authenticate(
|
||||
&mut self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::{
|
||||
console::{self, AuthInfo, ConsoleReqExtra},
|
||||
proxy::LatencyTimer,
|
||||
sasl, scram,
|
||||
stream::PqStream,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
@@ -15,7 +15,7 @@ pub(super) async fn authenticate(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
|
||||
@@ -2,7 +2,7 @@ use super::{AuthSuccess, ComputeCredentials};
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
proxy::LatencyTimer,
|
||||
stream,
|
||||
stream::{self, Stream},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
@@ -12,7 +12,7 @@ use tracing::{info, warn};
|
||||
/// These properties are benefical for serverless JS workers, so we
|
||||
/// use this mechanism for websocket connections.
|
||||
pub async fn cleartext_hack(
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
warn!("cleartext auth flow override is enabled, proceeding");
|
||||
@@ -37,7 +37,7 @@ pub async fn cleartext_hack(
|
||||
/// Very similar to [`cleartext_hack`], but there's a specific password format.
|
||||
pub async fn password_hack(
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
warn!("project not specified, resorting to the password hack auth flow");
|
||||
|
||||
@@ -1,16 +1,21 @@
//! Main authentication flow.

use super::{AuthErrorImpl, PasswordHackPayload};
use crate::{sasl, scram, stream::PqStream};
use crate::{
config::TlsServerEndPoint,
sasl, scram,
stream::{PqStream, Stream},
};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;

/// Every authentication selector is supposed to implement this trait.
pub trait AuthMethod {
/// Any authentication selector should provide initial backend message
/// containing auth method name and parameters, e.g. md5 salt.
fn first_message(&self) -> BeMessage<'_>;
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}

/// Initial state of [`AuthFlow`].
@@ -21,8 +26,14 @@ pub struct Scram<'a>(pub &'a scram::ServerSecret);

impl AuthMethod for Scram<'_> {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
if channel_binding {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
} else {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
scram::METHODS_WITHOUT_PLUS,
))
}
}
}
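Editor's note: the effect of the change above is that a certificate that cannot supply a tls-server-end-point binding makes the proxy stop advertising the `-PLUS` SCRAM mechanism instead of failing the handshake. A rough sketch of what the two advertised lists plausibly contain; the actual contents of `scram::METHODS` and `scram::METHODS_WITHOUT_PLUS` are an assumption here, not taken from this diff:

// Assumed mechanism lists, for illustration only.
pub const METHODS: &[&str] = &["SCRAM-SHA-256-PLUS", "SCRAM-SHA-256"];
pub const METHODS_WITHOUT_PLUS: &[&str] = &["SCRAM-SHA-256"];

/// Pick the list to advertise based on channel binding availability.
fn advertised_methods(channel_binding: bool) -> &'static [&'static str] {
    if channel_binding {
        METHODS
    } else {
        METHODS_WITHOUT_PLUS
    }
}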
|
||||
@@ -32,7 +43,7 @@ pub struct PasswordHack;
|
||||
|
||||
impl AuthMethod for PasswordHack {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
@@ -43,37 +54,44 @@ pub struct CleartextPassword;
|
||||
|
||||
impl AuthMethod for CleartextPassword {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub struct AuthFlow<'a, Stream, State> {
|
||||
pub struct AuthFlow<'a, S, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream>,
|
||||
stream: &'a mut PqStream<Stream<S>>,
|
||||
/// State might contain ancillary data (see [`Self::begin`]).
|
||||
state: State,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
}
|
||||
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub fn new(stream: &'a mut PqStream<S>) -> Self {
|
||||
pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
|
||||
let tls_server_end_point = stream.get_ref().tls_server_end_point();
|
||||
|
||||
Self {
|
||||
stream,
|
||||
state: Begin,
|
||||
tls_server_end_point,
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to the next step by sending auth method's name & params to client.
|
||||
pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
|
||||
self.stream.write_message(&method.first_message()).await?;
|
||||
self.stream
|
||||
.write_message(&method.first_message(self.tls_server_end_point.supported()))
|
||||
.await?;
|
||||
|
||||
Ok(AuthFlow {
|
||||
stream: self.stream,
|
||||
state: method,
|
||||
tls_server_end_point: self.tls_server_end_point,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -123,9 +141,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
return Err(super::AuthError::bad_auth_method(sasl.method));
|
||||
}
|
||||
|
||||
info!("client chooses {}", sasl.method);
|
||||
|
||||
let secret = self.state.0;
|
||||
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
|
||||
.authenticate(scram::Exchange::new(secret, rand::random, None))
|
||||
.authenticate(scram::Exchange::new(
|
||||
secret,
|
||||
rand::random,
|
||||
self.tls_server_end_point,
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(outcome)
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use futures::future::Either;
|
||||
use itertools::Itertools;
|
||||
use proxy::config::TlsServerEndPoint;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use anyhow::{anyhow, bail, ensure, Context};
|
||||
@@ -65,7 +67,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
|
||||
|
||||
// Configure TLS
|
||||
let tls_config: Arc<rustls::ServerConfig> = match (
|
||||
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
|
||||
args.get_one::<String>("tls-key"),
|
||||
args.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
@@ -89,16 +91,22 @@ async fn main() -> anyhow::Result<()> {
|
||||
))?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
.collect_vec()
|
||||
};
|
||||
|
||||
rustls::ServerConfig::builder()
|
||||
// needed for channel bindings
|
||||
let first_cert = cert_chain.first().context("missing certificate")?;
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
|
||||
let tls_config = rustls::ServerConfig::builder()
|
||||
.with_safe_default_cipher_suites()
|
||||
.with_safe_default_kx_groups()
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.into()
|
||||
.into();
|
||||
|
||||
(tls_config, tls_server_end_point)
|
||||
}
|
||||
_ => bail!("tls-key and tls-cert must be specified"),
|
||||
};
|
||||
@@ -113,6 +121,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let main = tokio::spawn(task_main(
|
||||
Arc::new(destination),
|
||||
tls_config,
|
||||
tls_server_end_point,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
@@ -134,6 +143,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
async fn task_main(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -159,7 +169,7 @@ async fn task_main(
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
info!(%peer_addr, "serving");
|
||||
handle_client(dest_suffix, tls_config, socket).await
|
||||
handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
@@ -207,6 +217,7 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
) -> anyhow::Result<Stream<S>> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
|
||||
|
||||
@@ -231,7 +242,11 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
Ok(raw.upgrade(tls_config).await?)
|
||||
|
||||
Ok(Stream::Tls {
|
||||
tls: Box::new(raw.upgrade(tls_config).await?),
|
||||
tls_server_end_point,
|
||||
})
|
||||
}
|
||||
unexpected => {
|
||||
info!(
|
||||
@@ -246,9 +261,10 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
async fn handle_client(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
let tls_stream = ssl_handshake(stream, tls_config).await?;
|
||||
let tls_stream = ssl_handshake(stream, tls_config, tls_server_end_point).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
use crate::auth;
|
||||
use anyhow::{bail, ensure, Context, Ok};
|
||||
use rustls::sign;
|
||||
use rustls::{sign, Certificate, PrivateKey};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
use x509_parser::oid_registry;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
@@ -27,6 +30,7 @@ pub struct MetricCollectionConfig {
|
||||
pub struct TlsConfig {
|
||||
pub config: Arc<rustls::ServerConfig>,
|
||||
pub common_names: Option<HashSet<String>>,
|
||||
pub cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
pub struct HttpConfig {
|
||||
@@ -52,7 +56,7 @@ pub fn configure_tls(
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
|
||||
// add default certificate
|
||||
cert_resolver.add_cert(key_path, cert_path, true)?;
|
||||
cert_resolver.add_cert_path(key_path, cert_path, true)?;
|
||||
|
||||
// add extra certificates
|
||||
if let Some(certs_dir) = certs_dir {
|
||||
@@ -64,7 +68,7 @@ pub fn configure_tls(
|
||||
let key_path = path.join("tls.key");
|
||||
let cert_path = path.join("tls.crt");
|
||||
if key_path.exists() && cert_path.exists() {
|
||||
cert_resolver.add_cert(
|
||||
cert_resolver.add_cert_path(
|
||||
&key_path.to_string_lossy(),
|
||||
&cert_path.to_string_lossy(),
|
||||
false,
|
||||
@@ -76,35 +80,97 @@ pub fn configure_tls(
|
||||
|
||||
let common_names = cert_resolver.get_common_names();
|
||||
|
||||
let cert_resolver = Arc::new(cert_resolver);
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_default_cipher_suites()
|
||||
.with_safe_default_kx_groups()
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(Arc::new(cert_resolver))
|
||||
.with_cert_resolver(cert_resolver.clone())
|
||||
.into();
|
||||
|
||||
Ok(TlsConfig {
|
||||
config,
|
||||
common_names: Some(common_names),
|
||||
cert_resolver,
|
||||
})
|
||||
}
|
||||
|
||||
struct CertResolver {
certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
default: Option<Arc<rustls::sign::CertifiedKey>>,
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function is neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to this channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
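Editor's note: for both variants above the binding data comes from the server's leaf certificate. A minimal sketch of the SHA-256 case using the `sha2` crate already imported in this file (simplified: the real constructor first checks the certificate's signature algorithm before settling on SHA-256):

use sha2::{Digest, Sha256};

/// Compute the tls-server-end-point binding for a DER-encoded leaf certificate,
/// assuming RFC 5929 maps its signature algorithm to SHA-256.
fn tls_server_end_point_sha256(cert_der: &[u8]) -> [u8; 32] {
    Sha256::new().chain_update(cert_der).finalize().into()
}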
|
||||
|
||||
impl CertResolver {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
certs: HashMap::new(),
|
||||
default: None,
|
||||
impl TlsServerEndPoint {
|
||||
pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
|
||||
let sha256_oids = [
|
||||
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
|
||||
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
|
||||
oid_registry::OID_PKCS1_SHA256WITHRSA,
|
||||
];
|
||||
|
||||
let pem = x509_parser::parse_x509_certificate(&cert.0)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
info!(subject = %pem.subject, "parsing TLS certificate");
|
||||
|
||||
let reg = oid_registry::OidRegistry::default().with_all_crypto();
|
||||
let oid = pem.signature_algorithm.oid();
|
||||
let alg = reg.get(oid);
|
||||
if sha256_oids.contains(oid) {
|
||||
let tls_server_end_point: [u8; 32] =
|
||||
Sha256::new().chain_update(&cert.0).finalize().into();
|
||||
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
|
||||
Ok(Self::Sha256(tls_server_end_point))
|
||||
} else {
|
||||
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
|
||||
Ok(Self::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
fn add_cert(
|
||||
pub fn supported(&self) -> bool {
|
||||
!matches!(self, TlsServerEndPoint::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CertResolver {
|
||||
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn add_cert_path(
|
||||
&mut self,
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
@@ -120,57 +186,65 @@ impl CertResolver {
|
||||
keys.pop().map(rustls::PrivateKey).unwrap()
|
||||
};
|
||||
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.context(format!(
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
)
|
||||
})?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
};
|
||||
|
||||
let common_name = {
|
||||
let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
|
||||
.context(format!(
|
||||
"Failed to parse PEM object from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
.1;
|
||||
let common_name = pem.parse_x509()?.subject().to_string();
|
||||
self.add_cert(priv_key, cert_chain, is_default)
|
||||
}
|
||||
|
||||
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
|
||||
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
|
||||
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
|
||||
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
|
||||
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
|
||||
// of cutting off '*.' parts.
|
||||
if common_name.starts_with("CN=*.") {
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
} else {
|
||||
common_name.strip_prefix("CN=").map(|s| s.to_string())
|
||||
}
|
||||
pub fn add_cert(
|
||||
&mut self,
|
||||
priv_key: PrivateKey,
|
||||
cert_chain: Vec<Certificate>,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let first_cert = &cert_chain[0];
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
let pem = x509_parser::parse_x509_certificate(&first_cert.0)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
let common_name = pem.subject().to_string();
|
||||
|
||||
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
|
||||
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
|
||||
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
|
||||
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
|
||||
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
|
||||
// of cutting off '*.' parts.
|
||||
let common_name = if common_name.starts_with("CN=*.") {
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
} else {
|
||||
common_name.strip_prefix("CN=").map(|s| s.to_string())
|
||||
}
|
||||
.context(format!(
|
||||
"Failed to parse common name from certificate at '{cert_path}'."
|
||||
))?;
|
||||
.context("Failed to parse common name from certificate")?;
|
||||
|
||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||
|
||||
if is_default {
|
||||
self.default = Some(cert.clone());
|
||||
self.default = Some((cert.clone(), tls_server_end_point));
|
||||
}
|
||||
|
||||
self.certs.insert(common_name, cert);
|
||||
self.certs.insert(common_name, (cert, tls_server_end_point));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_common_names(&self) -> HashSet<String> {
|
||||
pub fn get_common_names(&self) -> HashSet<String> {
|
||||
self.certs.keys().map(|s| s.to_string()).collect()
|
||||
}
|
||||
}
|
||||
@@ -178,15 +252,24 @@ impl CertResolver {
|
||||
impl rustls::server::ResolvesServerCert for CertResolver {
|
||||
fn resolve(
|
||||
&self,
|
||||
_client_hello: rustls::server::ClientHello,
|
||||
client_hello: rustls::server::ClientHello,
|
||||
) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
self.resolve(client_hello.server_name()).map(|x| x.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn resolve(
|
||||
&self,
|
||||
server_name: Option<&str>,
|
||||
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
|
||||
// loop here and cut off more and more subdomains until we find
|
||||
// a match to get a proper wildcard support. OTOH, we now do not
|
||||
// use nested domains, so keep this simple for now.
|
||||
//
|
||||
// With the current coding foo.com will match *.foo.com and that
|
||||
// repeats behavior of the old code.
|
||||
if let Some(mut sni_name) = _client_hello.server_name() {
|
||||
if let Some(mut sni_name) = server_name {
|
||||
loop {
|
||||
if let Some(cert) = self.certs.get(sni_name) {
|
||||
return Some(cert.clone());
|
||||
|
||||
@@ -470,7 +470,17 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?);
|
||||
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
|
||||
|
||||
let (_, tls_server_end_point) = tls
|
||||
.cert_resolver
|
||||
.resolve(tls_stream.get_ref().1.server_name())
|
||||
.context("missing certificate")?;
|
||||
|
||||
stream = PqStream::new(Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => bail!(ERR_PROTO_VIOLATION),
|
||||
@@ -875,7 +885,7 @@ pub async fn proxy_pass(
|
||||
/// Thin connection context.
|
||||
struct Client<'a, S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
stream: PqStream<S>,
|
||||
stream: PqStream<Stream<S>>,
|
||||
/// Client credentials that we care about.
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
@@ -889,7 +899,7 @@ struct Client<'a, S> {
|
||||
impl<'a, S> Client<'a, S> {
|
||||
/// Construct a new connection context.
|
||||
fn new(
|
||||
stream: PqStream<S>,
|
||||
stream: PqStream<Stream<S>>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
params: &'a StartupMessageParams,
|
||||
session_id: uuid::Uuid,
|
||||
|
||||
@@ -1,19 +1,23 @@
//! A group of high-level tests for connection establishing logic and auth.
//!

mod mitm;

use super::*;
use crate::auth::backend::TestBackend;
use crate::auth::ClientCredentials;
use crate::config::CertResolver;
use crate::console::{CachedNodeInfo, NodeInfo};
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
use rstest::rstest;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::MakeRustlsConnect;
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};

/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
hostname: &str,
common_name: &str,
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
let ca = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::default();
@@ -21,7 +25,15 @@ fn generate_certs(
params
})?;

let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?;
let cert = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
params.distinguished_name = rcgen::DistinguishedName::new();
params
.distinguished_name
.push(rcgen::DnType::CommonName, common_name);
params
})?;

Ok((
rustls::Certificate(ca.serialize_der()?),
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
@@ -37,7 +49,14 @@ struct ClientConfig<'a> {
impl ClientConfig<'_> {
fn make_tls_connect<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
self,
) -> anyhow::Result<impl tokio_postgres::tls::TlsConnect<S>> {
) -> anyhow::Result<
impl tokio_postgres::tls::TlsConnect<
S,
Error = impl std::fmt::Debug,
Future = impl Send,
Stream = RustlsStream<S>,
>,
> {
let mut mk = MakeRustlsConnect::new(self.config);
let tls = MakeTlsConnect::<S>::make_tls_connect(&mut mk, self.hostname)?;
Ok(tls)
@@ -49,20 +68,24 @@ fn generate_tls_config<'a>(
hostname: &'a str,
common_name: &'a str,
) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
let (ca, cert, key) = generate_certs(hostname)?;
let (ca, cert, key) = generate_certs(hostname, common_name)?;

let tls_config = {
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert], key)?
.with_single_cert(vec![cert.clone()], key.clone())?
.into();

let common_names = Some([common_name.to_owned()].iter().cloned().collect());
let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert(key, vec![cert], true)?;

let common_names = Some(cert_resolver.get_common_names());

TlsConfig {
config,
common_names,
cert_resolver: Arc::new(cert_resolver),
}
};

@@ -253,6 +276,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
));

let (_client, _conn) = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Require)
.user("user")
.dbname("db")
.password(password)
@@ -263,6 +287,30 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
proxy.await?
}

#[tokio::test]
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);

let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));

let (_client, _conn) = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await?;

proxy.await?
}

#[tokio::test]
async fn scram_auth_mock() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
257 proxy/src/proxy/tests/mitm.rs Normal file
@@ -0,0 +1,257 @@
//! Man-in-the-middle tests
//!
//! Channel binding should prevent a proxy server
//! - that has access to create valid certificates -
//! from controlling the TLS connection.

use std::fmt::Debug;

use super::*;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_protocol::message::frontend;
use tokio::io::{AsyncReadExt, DuplexStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::TlsConnect;
use tokio_util::codec::{Decoder, Encoder};

enum Intercept {
None,
Methods,
SASLResponse,
}

async fn proxy_mitm(
intercept: Intercept,
) -> (DuplexStream, DuplexStream, ClientConfig<'static>, TlsConfig) {
let (end_server1, client1) = tokio::io::duplex(1024);
let (server2, end_client2) = tokio::io::duplex(1024);

let (client_config1, server_config1) =
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
let (client_config2, server_config2) =
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();

tokio::spawn(async move {
// begin handshake with end_server
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
// process handshake with end_client
let (end_client, startup) =
handshake(client1, Some(&server_config1), &CancelMap::default())
.await
.unwrap()
.unwrap();

let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
let (end_client, buf) = end_client.framed.into_inner();
assert!(buf.is_empty());
let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);

// give the end_server the startup parameters
let mut buf = BytesMut::new();
frontend::startup_message(startup.iter(), &mut buf).unwrap();
end_server.send(buf.freeze()).await.unwrap();

// proxy messages between end_client and end_server
loop {
tokio::select! {
message = end_server.next() => {
match message {
Some(Ok(message)) => {
// intercept SASL and return only SCRAM-SHA-256 ;)
if matches!(intercept, Intercept::Methods) && message.starts_with(b"R") && message[5..].starts_with(&[0,0,0,10]) {
end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap();
continue;
}
end_client.send(message).await.unwrap()
}
_ => break,
}
}
message = end_client.next() => {
match message {
Some(Ok(message)) => {
// intercept SASL response and return SCRAM-SHA-256 with no channel binding ;)
if matches!(intercept, Intercept::SASLResponse) && message.starts_with(b"p") && message[5..].starts_with(b"SCRAM-SHA-256-PLUS\0") {
let sasl_message = &message[1+4+19+4..];
let mut new_message = b"n,,".to_vec();
new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());

let mut buf = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();

end_server.send(buf.freeze()).await.unwrap();
continue;
}
end_server.send(message).await.unwrap()
}
_ => break,
}
}
else => { break }
}
}
});

(end_server1, end_client2, client_config1, server_config2)
}

/// taken from tokio-postgres
pub async fn connect_tls<S, T>(mut stream: S, tls: T) -> T::Stream
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
T::Error: Debug,
{
let mut buf = BytesMut::new();
frontend::ssl_request(&mut buf);
stream.write_all(&buf).await.unwrap();

let mut buf = [0];
stream.read_exact(&mut buf).await.unwrap();

if buf[0] != b'S' {
panic!("ssl not supported by server");
}

tls.connect(stream).await.unwrap()
}

struct PgFrame;
impl Decoder for PgFrame {
type Item = Bytes;
type Error = io::Error;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < 5 {
src.reserve(5 - src.len());
return Ok(None);
}
let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
if src.len() < len {
src.reserve(len - src.len());
return Ok(None);
}
Ok(Some(src.split_to(len).freeze()))
}
}
impl Encoder<Bytes> for PgFrame {
type Error = io::Error;

fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
dst.extend_from_slice(&item);
Ok(())
}
}

/// If the client doesn't support channel bindings, it can be exploited.
#[tokio::test]
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));

let _client_err = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await?;

proxy.await?
}

/// If the client chooses SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
connect_failure(
Intercept::None,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}

/// If the MITM pretends like SCRAM-PLUS isn't available, but the client supports it, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
connect_failure(
Intercept::Methods,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}

/// If the MITM pretends like the client doesn't support channel bindings, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> {
connect_failure(
Intercept::SASLResponse,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}

/// If the client chooses SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
connect_failure(
Intercept::None,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}

/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
connect_failure(
Intercept::Methods,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}

/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> {
connect_failure(
Intercept::SASLResponse,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}

async fn connect_failure(
intercept: Intercept,
channel_binding: tokio_postgres::config::ChannelBinding,
) -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));

let _client_err = tokio_postgres::Config::new()
.channel_binding(channel_binding)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await
.err()
.context("client shouldn't be able to connect")?;

let _server_err = proxy
.await?
.err()
.context("server shouldn't accept client")?;

Ok(())
}
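The PgFrame codec above splits the byte stream on PostgreSQL message boundaries: each frame is a one-byte type tag followed by a four-byte big-endian length that counts the length field itself but not the tag, so the whole frame is len + 1 bytes. A worked example of that arithmetic (illustrative only), using the AuthenticationSASL message the interceptor injects:

// b"R" ++ [0, 0, 0, 23] ++ [0, 0, 0, 10] ++ b"SCRAM-SHA-256\0\0"
// The length field (23) covers itself, the 4-byte auth code and the
// 15-byte mechanism list, so the full frame is 1 + 23 = 24 bytes.
fn frame_len(header: &[u8; 5]) -> usize {
    u32::from_be_bytes(header[1..5].try_into().unwrap()) as usize + 1
}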
@@ -36,9 +36,9 @@ impl<'a> ChannelBinding<&'a str> {

impl<T: std::fmt::Display> ChannelBinding<T> {
/// Encode channel binding data as base64 for subsequent checks.
pub fn encode<E>(
pub fn encode<'a, E>(
&self,
get_cbind_data: impl FnOnce(&T) -> Result<String, E>,
get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
) -> Result<std::borrow::Cow<'static, str>, E> {
use ChannelBinding::*;
Ok(match self {
@@ -51,12 +51,11 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
"eSws".into()
}
Required(mode) => {
let msg = format!(
"p={mode},,{data}",
mode = mode,
data = get_cbind_data(mode)?
);
base64::encode(msg).into()
use std::io::Write;
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,",).unwrap();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
base64::encode(&cbind_input).into()
}
})
}
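For the Required mode, the value produced here is base64 of the SCRAM cbind-input: the GS2 header "p=tls-server-end-point,," followed by the raw certificate digest. This is what a channel-binding-aware client sends as the c= attribute of its client-final-message, while clients that skip channel binding send the fixed strings "biws" (base64 of "n,,") or "eSws" (base64 of "y,,") seen above. A minimal sketch of the same construction, illustrative and using the base64 crate as the surrounding code does:

fn expected_cbind(cert_digest: &[u8]) -> String {
    // base64( "p=tls-server-end-point,," || cert_digest )
    let mut cbind_input = b"p=tls-server-end-point,,".to_vec();
    cbind_input.extend_from_slice(cert_digest);
    base64::encode(cbind_input)
}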
@@ -77,7 +76,7 @@ mod tests {
];

for (cb, input) in cases {
assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input);
assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
}

Ok(())
@@ -22,9 +22,12 @@ pub use secret::ServerSecret;
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};

// TODO: add SCRAM-SHA-256-PLUS
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";

/// A list of supported SCRAM methods.
pub const METHODS: &[&str] = &["SCRAM-SHA-256"];
pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];

/// Decode base64 into array without any heap allocations
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
@@ -80,7 +83,11 @@ mod tests {
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
let mut exchange = Exchange::new(&secret, || NONCE, None);
let mut exchange = Exchange::new(
&secret,
|| NONCE,
crate::config::TlsServerEndPoint::Undefined,
);

let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
@@ -5,9 +5,11 @@ use super::messages::{
};
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use crate::config;
use crate::sasl::{self, ChannelBinding, Error as SaslError};

/// The only channel binding mode we currently support.
#[derive(Debug)]
struct TlsServerEndPoint;

impl std::fmt::Display for TlsServerEndPoint {
@@ -43,20 +45,20 @@ pub struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
cert_digest: Option<&'a [u8]>,
tls_server_end_point: config::TlsServerEndPoint,
}

impl<'a> Exchange<'a> {
pub fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
cert_digest: Option<&'a [u8]>,
tls_server_end_point: config::TlsServerEndPoint,
) -> Self {
Self {
state: ExchangeState::Initial,
secret,
nonce,
cert_digest,
tls_server_end_point,
}
}
}
@@ -71,6 +73,14 @@ impl sasl::Mechanism for Exchange<'_> {
let client_first_message = ClientFirstMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;

// If the flag is set to "y" and the server supports channel
// binding, the server MUST fail authentication
if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
&& self.tls_server_end_point.supported()
{
return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
}

let server_first_message = client_first_message.build_server_first_message(
&(self.nonce)(),
&self.secret.salt_base64,
@@ -94,10 +104,11 @@ impl sasl::Mechanism for Exchange<'_> {
let client_final_message = ClientFinalMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;

let channel_binding = cbind_flag.encode(|_| {
self.cert_digest
.map(base64::encode)
.ok_or(SaslError::ChannelBindingFailed("no cert digest provided"))
let channel_binding = cbind_flag.encode(|_| match &self.tls_server_end_point {
config::TlsServerEndPoint::Sha256(x) => Ok(x),
config::TlsServerEndPoint::Undefined => {
Err(SaslError::ChannelBindingFailed("no cert digest provided"))
}
})?;

// This might've been caused by a MITM attack
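The comparison that the "MITM attack" comment refers to (completed a few lines below this hunk) is what the mitm tests exercise: the server recomputes the expected c= value from its own tls_server_end_point and rejects the exchange if the client bound to a different certificate. A minimal sketch of that check, with hypothetical names rather than the repository's exact code:

fn verify_channel_binding(expected: &str, from_client_final: &str) -> Result<(), &'static str> {
    // A TLS-terminating MITM presents a different certificate, so the
    // digests, and therefore the base64-encoded c= values, will not match.
    if expected != from_client_final {
        return Err("channel binding check failed; possible MITM");
    }
    Ok(())
}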
@@ -1,7 +1,8 @@
use crate::config::TlsServerEndPoint;
use crate::error::UserFacingError;
use anyhow::bail;
use bytes::BytesMut;
use pin_project_lite::pin_project;

use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
@@ -17,7 +18,7 @@ use tokio_rustls::server::TlsStream;
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
framed: Framed<S>,
pub(crate) framed: Framed<S>,
}

impl<S> PqStream<S> {
@@ -118,19 +119,21 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
}
}

pin_project! {
/// Wrapper for upgrading raw streams into secure streams.
/// NOTE: it should be possible to decompose this object as necessary.
#[project = StreamProj]
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { #[pin] raw: S },
/// Wrapper for upgrading raw streams into secure streams.
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { raw: S },
Tls {
/// We box [`TlsStream`] since it can be quite large.
Tls { #[pin] tls: Box<TlsStream<S>> },
}
tls: Box<TlsStream<S>>,
/// Channel binding parameter
tls_server_end_point: TlsServerEndPoint,
},
}

impl<S: Unpin> Unpin for Stream<S> {}

impl<S> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
@@ -141,7 +144,17 @@ impl<S> Stream<S> {
pub fn sni_hostname(&self) -> Option<&str> {
match self {
Stream::Raw { .. } => None,
Stream::Tls { tls } => tls.get_ref().1.server_name(),
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
}
}

pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
match self {
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
Stream::Tls {
tls_server_end_point,
..
} => *tls_server_end_point,
}
}
}
@@ -158,12 +171,9 @@ pub enum StreamUpgradeError {

impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<Self, StreamUpgradeError> {
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
match self {
Stream::Raw { raw } => {
let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
Ok(Stream::Tls { tls })
}
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?),
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
}
}
@@ -171,50 +181,46 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
fn poll_read(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_read(context, buf),
Tls { tls } => tls.poll_read(context, buf),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
}
}
}

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
fn poll_write(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_write(context, buf),
Tls { tls } => tls.poll_write(context, buf),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
}
}

fn poll_flush(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_flush(context),
Tls { tls } => tls.poll_flush(context),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
}
}

fn poll_shutdown(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_shutdown(context),
Tls { tls } => tls.poll_shutdown(context),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
}
}
}
@@ -434,8 +434,6 @@ class NeonEnvBuilder:

# Pageserver remote storage
self.pageserver_remote_storage = pageserver_remote_storage
# Extensions remote storage
self.ext_remote_storage: Optional[S3Storage] = None
# Safekeepers remote storage
self.sk_remote_storage: Optional[RemoteStorage] = None

@@ -534,24 +532,6 @@ class NeonEnvBuilder:
)
self.pageserver_remote_storage = ret

def enable_extensions_remote_storage(self, kind: RemoteStorageKind):
assert self.ext_remote_storage is None, "already configured extensions remote storage"

# there is an assumption that REAL_S3 for extensions is never
# cleaned up these are also special in that they have a hardcoded
# bucket and region, which is most likely the same as our normal
ext = self._configure_and_create_remote_storage(
kind,
RemoteStorageUser.EXTENSIONS,
bucket_name="neon-dev-extensions-eu-central-1",
bucket_region="eu-central-1",
)
assert isinstance(
ext, S3Storage
), "unsure why, but only MOCK_S3 and REAL_S3 are currently supported for extensions"
ext.cleanup = False
self.ext_remote_storage = ext

def enable_safekeeper_remote_storage(self, kind: RemoteStorageKind):
assert self.sk_remote_storage is None, "sk_remote_storage already configured"

@@ -608,8 +588,7 @@ class NeonEnvBuilder:
directory_to_clean.rmdir()

def cleanup_remote_storage(self):
# extensions are currently not cleaned up, disabled when creating
for x in [self.pageserver_remote_storage, self.ext_remote_storage, self.sk_remote_storage]:
for x in [self.pageserver_remote_storage, self.sk_remote_storage]:
if isinstance(x, S3Storage):
x.do_cleanup()

@@ -713,7 +692,6 @@ class NeonEnv:
self.pageservers: List[NeonPageserver] = []
self.broker = config.broker
self.pageserver_remote_storage = config.pageserver_remote_storage
self.ext_remote_storage = config.ext_remote_storage
self.safekeepers_remote_storage = config.sk_remote_storage
self.pg_version = config.pg_version
# Binary path for pageserver, safekeeper, etc
@@ -1469,12 +1447,7 @@ class NeonCli(AbstractNeonCli):
if pageserver_id is not None:
args.extend(["--pageserver-id", str(pageserver_id)])

storage = self.env.ext_remote_storage
s3_env_vars = None
if isinstance(storage, S3Storage):
s3_env_vars = storage.access_env_vars()

res = self.raw_cli(args, extra_env_vars=s3_env_vars)
res = self.raw_cli(args)
res.check_returncode()
return res

@@ -1599,7 +1572,7 @@ class NeonAttachmentService:
self.running = False
return self

def attach_hook(self, tenant_id: TenantId, pageserver_id: int) -> int:
def attach_hook_issue(self, tenant_id: TenantId, pageserver_id: int) -> int:
response = requests.post(
f"{self.env.control_plane_api}/attach-hook",
json={"tenant_id": str(tenant_id), "node_id": pageserver_id},
@@ -1609,6 +1582,13 @@ class NeonAttachmentService:
assert isinstance(gen, int)
return gen

def attach_hook_drop(self, tenant_id: TenantId):
response = requests.post(
f"{self.env.control_plane_api}/attach-hook",
json={"tenant_id": str(tenant_id), "node_id": None},
)
response.raise_for_status()

def __enter__(self) -> "NeonAttachmentService":
return self

@@ -1808,13 +1788,20 @@ class NeonPageserver(PgProtocol):
to call into the pageserver HTTP client.
"""
if self.env.attachment_service is not None:
generation = self.env.attachment_service.attach_hook(tenant_id, self.id)
generation = self.env.attachment_service.attach_hook_issue(tenant_id, self.id)
else:
generation = None

client = self.http_client()
return client.tenant_attach(tenant_id, config, config_null, generation=generation)

def tenant_detach(self, tenant_id: TenantId):
if self.env.attachment_service is not None:
self.env.attachment_service.attach_hook_drop(tenant_id)

client = self.http_client()
return client.tenant_detach(tenant_id)

def append_pageserver_param_overrides(
params_to_update: List[str],
@@ -2582,6 +2569,17 @@ class Endpoint(PgProtocol):
with open(config_path, "w") as file:
json.dump(dict(data_dict, **kwargs), file, indent=4)

# Mock the extension part of spec passed from control plane for local testing
# endpooint.rs adds content of this file as a part of the spec.json
def create_remote_extension_spec(self, spec: dict[str, Any]):
"""Create a remote extension spec file for the endpoint."""
remote_extensions_spec_path = os.path.join(
self.endpoint_path(), "remote_extensions_spec.json"
)

with open(remote_extensions_spec_path, "w") as file:
json.dump(spec, file, indent=4)

def stop(self) -> "Endpoint":
"""
Stop the Postgres instance if it's running.

@@ -411,7 +411,6 @@ def check_neon_works(
config.initial_tenant = snapshot_config["default_tenant_id"]
config.pg_distrib_dir = pg_distrib_dir
config.remote_storage = None
config.ext_remote_storage = None
config.sk_remote_storage = None

# Use the "target" binaries to launch the storage nodes
@@ -1,316 +1,137 @@
import os
import shutil
import threading
from contextlib import closing
from pathlib import Path
from typing import Any, Dict

import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnvBuilder,
)
from fixtures.pg_version import PgVersion, skip_on_postgres
from fixtures.remote_storage import (
RemoteStorageKind,
S3Storage,
available_s3_storages,
)
from fixtures.pg_version import PgVersion
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response

# Cleaning up downloaded files is important for local tests
# or else one test could reuse the files from another test or another test run
def cleanup(pg_version):
PGDIR = Path(f"pg_install/v{pg_version}")
# use neon_env_builder_local fixture to override the default neon_env_builder fixture
# and use a test-specific pg_install instead of shared one
@pytest.fixture(scope="function")
def neon_env_builder_local(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
pg_distrib_dir: Path,
pg_version: PgVersion,
) -> NeonEnvBuilder:
test_local_pginstall = test_output_dir / "pg_install"
log.info(f"copy {pg_distrib_dir} to {test_local_pginstall}")
shutil.copytree(
pg_distrib_dir / pg_version.v_prefixed, test_local_pginstall / pg_version.v_prefixed
)

LIB_DIR = PGDIR / Path("lib/postgresql")
cleanup_lib_globs = ["anon*", "postgis*", "pg_buffercache*"]
cleanup_lib_glob_paths = [LIB_DIR.glob(x) for x in cleanup_lib_globs]
neon_env_builder.pg_distrib_dir = test_local_pginstall
log.info(f"local neon_env_builder.pg_distrib_dir: {neon_env_builder.pg_distrib_dir}")

SHARE_DIR = PGDIR / Path("share/postgresql/extension")
cleanup_ext_globs = [
"anon*",
"address_standardizer*",
"postgis*",
"pageinspect*",
"pg_buffercache*",
"pgrouting*",
]
cleanup_ext_glob_paths = [SHARE_DIR.glob(x) for x in cleanup_ext_globs]

all_glob_paths = cleanup_lib_glob_paths + cleanup_ext_glob_paths
all_cleanup_files = []
for file_glob in all_glob_paths:
for file in file_glob:
all_cleanup_files.append(file)

for file in all_cleanup_files:
try:
os.remove(file)
log.info(f"removed file {file}")
except Exception as err:
log.info(
f"skipping remove of file {file} because it doesn't exist.\
this may be expected or unexpected depending on the test {err}"
)

cleanup_folders = [SHARE_DIR / Path("anon"), PGDIR / Path("download_extensions")]
for folder in cleanup_folders:
try:
shutil.rmtree(folder)
log.info(f"removed folder {folder}")
except Exception as err:
log.info(
f"skipping remove of folder {folder} because it doesn't exist.\
this may be expected or unexpected depending on the test {err}"
)
return neon_env_builder

def upload_files(env):
log.info("Uploading test files to mock bucket")
os.chdir("test_runner/regress/data/extension_test")
for path in os.walk("."):
prefix, _, files = path
for file in files:
# the [2:] is to remove the leading "./"
full_path = os.path.join(prefix, file)[2:]

with open(full_path, "rb") as f:
log.info(f"UPLOAD {full_path} to ext/{full_path}")
assert isinstance(env.pageserver_remote_storage, S3Storage)
env.pageserver_remote_storage.client.upload_fileobj(
f,
env.ext_remote_storage.bucket_name,
f"ext/{full_path}",
)
os.chdir("../../../..")

# Test downloading remote extension.
@skip_on_postgres(PgVersion.V16, reason="TODO: PG16 extension building")
@pytest.mark.parametrize("remote_storage_kind", available_s3_storages())
@pytest.mark.skip(reason="https://github.com/neondatabase/neon/issues/4949")
def test_remote_extensions(
neon_env_builder: NeonEnvBuilder,
remote_storage_kind: RemoteStorageKind,
pg_version: PgVersion,
httpserver: HTTPServer,
neon_env_builder_local: NeonEnvBuilder,
httpserver_listen_address,
pg_version,
):
neon_env_builder.enable_extensions_remote_storage(remote_storage_kind)
env = neon_env_builder.init_start()
tenant_id, _ = env.neon_cli.create_tenant()
env.neon_cli.create_timeline("test_remote_extensions", tenant_id=tenant_id)
if pg_version == PgVersion.V16:
pytest.skip("TODO: PG16 extension building")

assert env.ext_remote_storage is not None  # satisfy mypy
# setup mock http server
# that expects request for anon.tar.zst
# and returns the requested file
(host, port) = httpserver_listen_address
extensions_endpoint = f"http://{host}:{port}/pg-ext-s3-gateway"

# For MOCK_S3 we upload test files.
# For REAL_S3 we use the files already in the bucket
if remote_storage_kind == RemoteStorageKind.MOCK_S3:
upload_files(env)
build_tag = os.environ.get("BUILD_TAG", "latest")
archive_path = f"{build_tag}/v{pg_version}/extensions/anon.tar.zst"

# Start a compute node and check that it can download the extensions
# and use them to CREATE EXTENSION and LOAD
endpoint = env.endpoints.create_start(
def endpoint_handler_build_tag(request: Request) -> Response:
log.info(f"request: {request}")

file_name = "anon.tar.zst"
file_path = f"test_runner/regress/data/extension_test/5670669815/v{pg_version}/extensions/anon.tar.zst"
file_size = os.path.getsize(file_path)
fh = open(file_path, "rb")

return Response(
fh,
mimetype="application/octet-stream",
headers=[
("Content-Length", str(file_size)),
("Content-Disposition", 'attachment; filename="%s"' % file_name),
],
direct_passthrough=True,
)

httpserver.expect_request(
f"/pg-ext-s3-gateway/{archive_path}", method="GET"
).respond_with_handler(endpoint_handler_build_tag)

# Start a compute node with remote_extension spec
# and check that it can download the extensions and use them to CREATE EXTENSION.
env = neon_env_builder_local.init_start()
env.neon_cli.create_branch("test_remote_extensions")
endpoint = env.endpoints.create(
"test_remote_extensions",
tenant_id=tenant_id,
remote_ext_config=env.ext_remote_storage.to_string(),
# config_lines=["log_min_messages=debug3"],
config_lines=["log_min_messages=debug3"],
)

# mock remote_extensions spec
spec: Dict[str, Any] = {
"library_index": {
"anon": "anon",
},
"extension_data": {
"anon": {
"archive_path": "",
"control_data": {
"anon.control": "# PostgreSQL Anonymizer (anon) extension\ncomment = 'Data anonymization tools'\ndefault_version = '1.1.0'\ndirectory='extension/anon'\nrelocatable = false\nrequires = 'pgcrypto'\nsuperuser = false\nmodule_pathname = '$libdir/anon'\ntrusted = true\n"
},
},
},
}
spec["extension_data"]["anon"]["archive_path"] = archive_path

endpoint.create_remote_extension_spec(spec)

endpoint.start(
remote_ext_config=extensions_endpoint,
)

# this is expected to fail if there's no pgcrypto extension, that's ok
# we just want to check that the extension was downloaded
try:
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
# Check that appropriate control files were downloaded
cur.execute("SELECT * FROM pg_available_extensions")
all_extensions = [x[0] for x in cur.fetchall()]
log.info(all_extensions)
assert "anon" in all_extensions
# Check that appropriate files were downloaded
cur.execute("CREATE EXTENSION anon")
res = [x[0] for x in cur.fetchall()]
log.info(res)
except Exception as err:
assert "pgcrypto" in str(err), f"unexpected error creating anon extension {err}"

# postgis is on real s3 but not mock s3.
# it's kind of a big file, would rather not upload to github
if remote_storage_kind == RemoteStorageKind.REAL_S3:
assert "postgis" in all_extensions
# this may fail locally if dependency is missing
# we don't really care about the error,
# we just want to make sure it downloaded
try:
cur.execute("CREATE EXTENSION postgis")
except Exception as err:
log.info(f"(expected) error creating postgis extension: {err}")
# we do not check the error, so this is basically a NO-OP
# however checking the log you can make sure that it worked
# and also get valuable information about how long loading the extension took

# this is expected to fail on my computer because I don't have the pgcrypto extension
try:
cur.execute("CREATE EXTENSION anon")
except Exception as err:
log.info("error creating anon extension")
assert "pgcrypto" in str(err), "unexpected error creating anon extension"
finally:
cleanup(pg_version)
httpserver.check()

# Test downloading remote library.
@skip_on_postgres(PgVersion.V16, reason="TODO: PG16 extension building")
@pytest.mark.parametrize("remote_storage_kind", available_s3_storages())
@pytest.mark.skip(reason="https://github.com/neondatabase/neon/issues/4949")
def test_remote_library(
neon_env_builder: NeonEnvBuilder,
remote_storage_kind: RemoteStorageKind,
pg_version: PgVersion,
):
neon_env_builder.enable_extensions_remote_storage(remote_storage_kind)
env = neon_env_builder.init_start()
tenant_id, _ = env.neon_cli.create_tenant()
env.neon_cli.create_timeline("test_remote_library", tenant_id=tenant_id)

assert env.ext_remote_storage is not None  # satisfy mypy

# For MOCK_S3 we upload test files.
# For REAL_S3 we use the files already in the bucket
if remote_storage_kind == RemoteStorageKind.MOCK_S3:
upload_files(env)

# and use them to run LOAD library
endpoint = env.endpoints.create_start(
"test_remote_library",
tenant_id=tenant_id,
remote_ext_config=env.ext_remote_storage.to_string(),
# config_lines=["log_min_messages=debug3"],
)
try:
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
# try to load library
try:
cur.execute("LOAD 'anon'")
except Exception as err:
log.info(f"error loading anon library: {err}")
raise AssertionError("unexpected error loading anon library") from err

# test library which name is different from extension name
# this may fail locally if dependency is missing
# however, it does successfully download the postgis archive
if remote_storage_kind == RemoteStorageKind.REAL_S3:
try:
cur.execute("LOAD 'postgis_topology-3'")
except Exception as err:
log.info("error loading postgis_topology-3")
assert "No such file or directory" in str(
err
), "unexpected error loading postgis_topology-3"
finally:
cleanup(pg_version)

# Here we test a complex extension
# which has multiple extensions in one archive
# using postgis as an example
# @pytest.mark.skipif(
#     RemoteStorageKind.REAL_S3 not in available_s3_storages(),
#     reason="skipping test because real s3 not enabled",
# )
@skip_on_postgres(PgVersion.V16, reason="TODO: PG16 extension building")
@pytest.mark.skip(reason="https://github.com/neondatabase/neon/issues/4949")
def test_multiple_extensions_one_archive(
neon_env_builder: NeonEnvBuilder,
pg_version: PgVersion,
):
neon_env_builder.enable_extensions_remote_storage(RemoteStorageKind.REAL_S3)
env = neon_env_builder.init_start()
tenant_id, _ = env.neon_cli.create_tenant()
env.neon_cli.create_timeline("test_multiple_extensions_one_archive", tenant_id=tenant_id)

assert env.ext_remote_storage is not None  # satisfy mypy

endpoint = env.endpoints.create_start(
"test_multiple_extensions_one_archive",
tenant_id=tenant_id,
remote_ext_config=env.ext_remote_storage.to_string(),
)
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
cur.execute("CREATE EXTENSION address_standardizer;")
cur.execute("CREATE EXTENSION address_standardizer_data_us;")
# execute query to ensure that it works
cur.execute(
"SELECT house_num, name, suftype, city, country, state, unit \
FROM standardize_address('us_lex', 'us_gaz', 'us_rules', \
'One Rust Place, Boston, MA 02109');"
)
res = cur.fetchall()
log.info(res)
assert len(res) > 0

cleanup(pg_version)

# Test that extension is downloaded after endpoint restart,
# when the library is used in the query.
# TODO
# 1. Test downloading remote library.
#
# 2. Test a complex extension, which has multiple extensions in one archive
# using postgis as an example
#
# 3.Test that extension is downloaded after endpoint restart,
# when the library is used in the query.
# Run the test with mutliple simultaneous connections to an endpoint.
# to ensure that the extension is downloaded only once.
#
@pytest.mark.skip(reason="https://github.com/neondatabase/neon/issues/4949")
def test_extension_download_after_restart(
neon_env_builder: NeonEnvBuilder,
pg_version: PgVersion,
):
# TODO: PG15 + PG16 extension building
if "v14" not in pg_version:  # test set only has extension built for v14
return None

neon_env_builder.enable_extensions_remote_storage(RemoteStorageKind.MOCK_S3)
env = neon_env_builder.init_start()
tenant_id, _ = env.neon_cli.create_tenant()
env.neon_cli.create_timeline("test_extension_download_after_restart", tenant_id=tenant_id)

assert env.ext_remote_storage is not None  # satisfy mypy

# For MOCK_S3 we upload test files.
upload_files(env)

endpoint = env.endpoints.create_start(
"test_extension_download_after_restart",
tenant_id=tenant_id,
remote_ext_config=env.ext_remote_storage.to_string(),
config_lines=["log_min_messages=debug3"],
)
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
cur.execute("CREATE extension pg_buffercache;")
cur.execute("SELECT * from pg_buffercache;")
res = cur.fetchall()
assert len(res) > 0
log.info(res)

# shutdown compute node
endpoint.stop()
# remove extension files locally
cleanup(pg_version)

# spin up compute node again (there are no extension files available, because compute is stateless)
endpoint = env.endpoints.create_start(
"test_extension_download_after_restart",
tenant_id=tenant_id,
remote_ext_config=env.ext_remote_storage.to_string(),
config_lines=["log_min_messages=debug3"],
)

# connect to compute node and run the query
# that will trigger the download of the extension
def run_query(endpoint, thread_id: int):
log.info("thread_id {%d} starting", thread_id)
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
cur.execute("SELECT * from pg_buffercache;")
res = cur.fetchall()
assert len(res) > 0
log.info("thread_id {%d}, res = %s", thread_id, res)

threads = [threading.Thread(target=run_query, args=(endpoint, i)) for i in range(2)]

for thread in threads:
thread.start()
for thread in threads:
thread.join()

cleanup(pg_version)
# 4. Test that private extensions are only downloaded when they are present in the spec.
#
@@ -282,7 +282,7 @@ def test_deferred_deletion(neon_env_builder: NeonEnvBuilder):

# Now advance the generation in the control plane: subsequent validations
# from the running pageserver will fail. No more deletions should happen.
env.attachment_service.attach_hook(env.initial_tenant, some_other_pageserver)
env.attachment_service.attach_hook_issue(env.initial_tenant, some_other_pageserver)
generate_uploads_and_deletions(env, init=False)

assert_deletion_queue(ps_http, lambda n: n > 0)
@@ -397,7 +397,7 @@ def test_deletion_queue_recovery(
if keep_attachment == KeepAttachment.LOSE:
some_other_pageserver = 101010
assert env.attachment_service is not None
env.attachment_service.attach_hook(env.initial_tenant, some_other_pageserver)
env.attachment_service.attach_hook_issue(env.initial_tenant, some_other_pageserver)

env.pageserver.start()

@@ -336,10 +336,15 @@ def test_live_reconfig_get_evictions_low_residence_duration_metric_threshold(
):
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)

env = neon_env_builder.init_start()
env = neon_env_builder.init_start(
initial_tenant_conf={
# disable compaction so that it will not download the layer for repartitioning
"compaction_period": "0s"
}
)
assert isinstance(env.pageserver_remote_storage, LocalFsStorage)

(tenant_id, timeline_id) = env.neon_cli.create_tenant()
(tenant_id, timeline_id) = env.initial_tenant, env.initial_timeline
ps_http = env.pageserver.http_client()

def get_metric():