Mirror of https://github.com/neondatabase/neon.git, synced 2026-01-18 02:42:56 +00:00

Compare commits (1 commit): release-48...RemoteExte
Commit: 02f8650111

Cargo.lock (generated, 18 lines changed)
@@ -1329,6 +1329,8 @@ dependencies = [
 "clap",
 "comfy-table",
 "compute_api",
+"diesel",
+"diesel_migrations",
 "futures",
 "git-version",
 "hex",
@@ -2247,11 +2249,11 @@ dependencies = [

 [[package]]
 name = "hashlink"
-version = "0.8.4"
+version = "0.8.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7"
+checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa"
 dependencies = [
-"hashbrown 0.14.0",
+"hashbrown 0.13.2",
 ]

 [[package]]
@@ -3936,7 +3938,6 @@ dependencies = [
"pin-project-lite",
"postgres-protocol",
"rand 0.8.5",
"serde",
"thiserror",
"tokio",
"tracing",
@@ -4078,7 +4079,6 @@ dependencies = [
"clap",
"consumption_metrics",
"dashmap",
"env_logger",
"futures",
"git-version",
"hashbrown 0.13.2",
@@ -4126,7 +4126,6 @@ dependencies = [
"serde",
"serde_json",
"sha2",
"smallvec",
"smol_str",
"socket2 0.5.5",
"sync_wrapper",
@@ -4145,7 +4144,6 @@ dependencies = [
"tracing-subscriber",
"tracing-utils",
"url",
"urlencoding",
"utils",
"uuid",
"walkdir",
@@ -5741,7 +5739,7 @@ dependencies = [
 [[package]]
 name = "tokio-epoll-uring"
 version = "0.1.0"
-source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc"
+source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#0e1af4ccddf2f01805cfc9eaefa97ee13c04b52d"
 dependencies = [
 "futures",
 "nix 0.26.4",
@@ -6266,7 +6264,7 @@ dependencies = [
 [[package]]
 name = "uring-common"
 version = "0.1.0"
-source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc"
+source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#0e1af4ccddf2f01805cfc9eaefa97ee13c04b52d"
 dependencies = [
 "io-uring",
 "libc",
@@ -6833,6 +6831,8 @@ dependencies = [
 "clap",
 "clap_builder",
 "crossbeam-utils",
+"diesel",
+"diesel_derives",
 "either",
 "fail",
 "futures-channel",
@@ -80,7 +80,7 @@ futures-core = "0.3"
 futures-util = "0.3"
 git-version = "0.3"
 hashbrown = "0.13"
-hashlink = "0.8.4"
+hashlink = "0.8.1"
 hdrhistogram = "7.5.2"
 hex = "0.4"
 hex-literal = "0.4"

@@ -171,7 +171,6 @@ tracing-opentelemetry = "0.20.0"
tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
twox-hash = { version = "1.6.3", default-features = false }
url = "2.2"
urlencoding = "2.1"
uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"
webpki-roots = "0.25"
@@ -100,11 +100,6 @@ RUN mkdir -p /data/.neon/ && chown -R neon:neon /data/.neon/ \
     -c "listen_pg_addr='0.0.0.0:6400'" \
     -c "listen_http_addr='0.0.0.0:9898'"

-# When running a binary that links with libpq, default to using our most recent postgres version. Binaries
-# that want a particular postgres version will select it explicitly: this is just a default.
-ENV LD_LIBRARY_PATH /usr/local/v16/lib

 VOLUME ["/data"]
 USER neon
 EXPOSE 6400

@@ -135,7 +135,7 @@ WORKDIR /home/nonroot

 # Rust
 # Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`)
-ENV RUSTC_VERSION=1.76.0
+ENV RUSTC_VERSION=1.75.0
 ENV RUSTUP_HOME="/home/nonroot/.rustup"
 ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
 RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \

@@ -639,8 +639,8 @@ FROM build-deps AS pg-anon-pg-build
 COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/

 ENV PATH "/usr/local/pgsql/bin/:$PATH"
-RUN wget https://github.com/neondatabase/postgresql_anonymizer/archive/refs/tags/neon_1.1.1.tar.gz -O pg_anon.tar.gz && \
-    echo "321ea8d5c1648880aafde850a2c576e4a9e7b9933a34ce272efc839328999fa9 pg_anon.tar.gz" | sha256sum --check && \
+RUN wget https://gitlab.com/dalibo/postgresql_anonymizer/-/archive/1.1.0/postgresql_anonymizer-1.1.0.tar.gz -O pg_anon.tar.gz && \
+    echo "08b09d2ff9b962f96c60db7e6f8e79cf7253eb8772516998fc35ece08633d3ad pg_anon.tar.gz" | sha256sum --check && \
     mkdir pg_anon-src && cd pg_anon-src && tar xvzf ../pg_anon.tar.gz --strip-components=1 -C . && \
     find /usr/local/pgsql -type f | sed 's|^/usr/local/pgsql/||' > /before.txt &&\
     make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \

@@ -809,7 +809,6 @@ COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-semver-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-embedding-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=wal2json-pg-build /usr/local/pgsql /usr/local/pgsql
COPY --from=pg-anon-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY pgxn/ pgxn/

RUN make -j $(getconf _NPROCESSORS_ONLN) \
NOTICE (2 lines changed)

@@ -1,5 +1,5 @@
 Neon
-Copyright 2022 - 2024 Neon Inc.
+Copyright 2022 Neon Inc.

 The PostgreSQL submodules in vendor/ are licensed under the PostgreSQL license.
 See vendor/postgres-vX/COPYRIGHT for details.
@@ -765,12 +765,7 @@ impl ComputeNode {
     handle_roles(spec, &mut client)?;
     handle_databases(spec, &mut client)?;
     handle_role_deletions(spec, connstr.as_str(), &mut client)?;
-    handle_grants(
-        spec,
-        &mut client,
-        connstr.as_str(),
-        self.has_feature(ComputeFeature::AnonExtension),
-    )?;
+    handle_grants(spec, &mut client, connstr.as_str())?;
     handle_extensions(spec, &mut client)?;
     handle_extension_neon(&mut client)?;
     create_availability_check_data(&mut client)?;

@@ -844,12 +839,7 @@ impl ComputeNode {
     handle_roles(&spec, &mut client)?;
     handle_databases(&spec, &mut client)?;
     handle_role_deletions(&spec, self.connstr.as_str(), &mut client)?;
-    handle_grants(
-        &spec,
-        &mut client,
-        self.connstr.as_str(),
-        self.has_feature(ComputeFeature::AnonExtension),
-    )?;
+    handle_grants(&spec, &mut client, self.connstr.as_str())?;
     handle_extensions(&spec, &mut client)?;
     handle_extension_neon(&mut client)?;
     // We can skip handle_migrations here because a new migration can only appear
@@ -1245,10 +1235,19 @@ LIMIT 100",

     info!("Downloading to shared preload libraries: {:?}", &libs_vec);

+    let build_tag_str = if spec
+        .features
+        .contains(&ComputeFeature::RemoteExtensionsUseLatest)
+    {
+        "latest"
+    } else {
+        &self.build_tag
+    };
+
     let mut download_tasks = Vec::new();
     for library in &libs_vec {
         let (ext_name, ext_path) =
-            remote_extensions.get_ext(library, true, &self.build_tag, &self.pgversion)?;
+            remote_extensions.get_ext(library, true, build_tag_str, &self.pgversion)?;
         download_tasks.push(self.download_extension(ext_name, ext_path));
     }
     let results = join_all(download_tasks).await;
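The hunk above (and the matching one in the HTTP routes below) derives the extension build tag from a per-compute feature flag. As a standalone illustration, here is a minimal sketch of just that selection logic; `Spec` is a stand-in for the real compute spec type, and only the `ComputeFeature::RemoteExtensionsUseLatest` name is taken from the diff:

```rust
// Hedged sketch of the feature-gated tag selection shown above.
#[derive(PartialEq)]
enum ComputeFeature {
    RemoteExtensionsUseLatest,
}

struct Spec {
    features: Vec<ComputeFeature>,
}

fn select_build_tag<'a>(spec: &Spec, build_tag: &'a str) -> &'a str {
    if spec.features.contains(&ComputeFeature::RemoteExtensionsUseLatest) {
        // Opt this compute into the most recently published extension builds.
        "latest"
    } else {
        // Default: pin remote extensions to the build tag this compute shipped with.
        build_tag
    }
}
```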
@@ -8,6 +8,7 @@ use std::thread;
 use crate::compute::{ComputeNode, ComputeState, ParsedSpec};
 use compute_api::requests::ConfigurationRequest;
 use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIError};
+use compute_api::spec::ComputeFeature;

 use anyhow::Result;
 use hyper::service::{make_service_fn, service_fn};

@@ -171,12 +172,16 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
             }
         };

-        remote_extensions.get_ext(
-            &filename,
-            is_library,
-            &compute.build_tag,
-            &compute.pgversion,
-        )
+        let build_tag_str = if spec
+            .features
+            .contains(&ComputeFeature::RemoteExtensionsUseLatest)
+        {
+            "latest"
+        } else {
+            &compute.build_tag
+        };
+
+        remote_extensions.get_ext(&filename, is_library, build_tag_str, &compute.pgversion)
     };

     match ext {
@@ -264,10 +264,9 @@ pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> {
     // case we miss some events for some reason. Not strictly necessary, but
     // better safe than sorry.
     let (tx, rx) = std::sync::mpsc::channel();
-    let watcher_res = notify::recommended_watcher(move |res| {
+    let (mut watcher, rx): (Box<dyn Watcher>, _) = match notify::recommended_watcher(move |res| {
         let _ = tx.send(res);
-    });
-    let (mut watcher, rx): (Box<dyn Watcher>, _) = match watcher_res {
+    }) {
         Ok(watcher) => (Box::new(watcher), rx),
         Err(e) => {
             match e.kind {
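The refactor above only inlines the watcher construction into the `match`; the wiring itself is unchanged. A self-contained sketch of that wiring, assuming the `notify` crate's v6 API (`recommended_watcher` taking an event callback):

```rust
// Hedged sketch of the pattern in the hunk above: build a recommended
// watcher that forwards filesystem events into an mpsc channel, and erase
// its concrete type behind Box<dyn Watcher>.
use std::sync::mpsc::{channel, Receiver};

use notify::{Event, Watcher};

fn make_watcher() -> notify::Result<(Box<dyn Watcher>, Receiver<notify::Result<Event>>)> {
    let (tx, rx) = channel();
    let watcher = notify::recommended_watcher(move |res| {
        // The receiver may already be dropped; ignoring send errors is fine here.
        let _ = tx.send(res);
    })?;
    Ok((Box::new(watcher), rx))
}
```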
@@ -581,12 +581,7 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
 /// Grant CREATE ON DATABASE to the database owner and do some other alters and grants
 /// to allow users creating trusted extensions and re-creating `public` schema, for example.
 #[instrument(skip_all)]
-pub fn handle_grants(
-    spec: &ComputeSpec,
-    client: &mut Client,
-    connstr: &str,
-    enable_anon_extension: bool,
-) -> Result<()> {
+pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) -> Result<()> {
     info!("modifying database permissions");
     let existing_dbs = get_existing_dbs(client)?;
@@ -683,11 +678,6 @@ pub fn handle_grants(
             inlinify(&grant_query)
         );
         db_client.simple_query(&grant_query)?;
-
-        // it is important to run this after all grants
-        if enable_anon_extension {
-            handle_extension_anon(spec, &db.owner, &mut db_client, false)?;
-        }
     }

     Ok(())
@@ -776,7 +766,6 @@ BEGIN
     END IF;
 END
 $$;"#,
-        "GRANT pg_monitor TO neon_superuser WITH ADMIN OPTION",
     ];

     let mut query = "CREATE SCHEMA IF NOT EXISTS neon_migration";
@@ -820,125 +809,5 @@ $$;"#,
         "Ran {} migrations",
         (migrations.len() - starting_migration_id)
     );

     Ok(())
 }
-
-/// Connect to the database as superuser and pre-create anon extension
-/// if it is present in shared_preload_libraries
-#[instrument(skip_all)]
-pub fn handle_extension_anon(
-    spec: &ComputeSpec,
-    db_owner: &str,
-    db_client: &mut Client,
-    grants_only: bool,
-) -> Result<()> {
-    info!("handle extension anon");
-
-    if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
-        if libs.contains("anon") {
-            if !grants_only {
-                // check if extension is already initialized using anon.is_initialized()
-                let query = "SELECT anon.is_initialized()";
-                match db_client.query(query, &[]) {
-                    Ok(rows) => {
-                        if !rows.is_empty() {
-                            let is_initialized: bool = rows[0].get(0);
-                            if is_initialized {
-                                info!("anon extension is already initialized");
-                                return Ok(());
-                            }
-                        }
-                    }
-                    Err(e) => {
-                        warn!(
-                            "anon extension is_installed check failed with expected error: {}",
-                            e
-                        );
-                    }
-                };
-
-                // Create anon extension if this compute needs it
-                // Users cannot create it themselves, because superuser is required.
-                let mut query = "CREATE EXTENSION IF NOT EXISTS anon CASCADE";
-                info!("creating anon extension with query: {}", query);
-                match db_client.query(query, &[]) {
-                    Ok(_) => {}
-                    Err(e) => {
-                        error!("anon extension creation failed with error: {}", e);
-                        return Ok(());
-                    }
-                }
-
-                // check that extension is installed
-                query = "SELECT extname FROM pg_extension WHERE extname = 'anon'";
-                let rows = db_client.query(query, &[])?;
-                if rows.is_empty() {
-                    error!("anon extension is not installed");
-                    return Ok(());
-                }
-
-                // Initialize anon extension
-                // This also requires superuser privileges, so users cannot do it themselves.
-                query = "SELECT anon.init()";
-                match db_client.query(query, &[]) {
-                    Ok(_) => {}
-                    Err(e) => {
-                        error!("anon.init() failed with error: {}", e);
-                        return Ok(());
-                    }
-                }
-            }
-
-            // check that extension is installed, if not bail early
-            let query = "SELECT extname FROM pg_extension WHERE extname = 'anon'";
-            match db_client.query(query, &[]) {
-                Ok(rows) => {
-                    if rows.is_empty() {
-                        error!("anon extension is not installed");
-                        return Ok(());
-                    }
-                }
-                Err(e) => {
-                    error!("anon extension check failed with error: {}", e);
-                    return Ok(());
-                }
-            };
-
-            let query = format!("GRANT ALL ON SCHEMA anon TO {}", db_owner);
-            info!("granting anon extension permissions with query: {}", query);
-            db_client.simple_query(&query)?;
-
-            // Grant permissions to db_owner to use anon extension functions
-            let query = format!("GRANT ALL ON ALL FUNCTIONS IN SCHEMA anon TO {}", db_owner);
-            info!("granting anon extension permissions with query: {}", query);
-            db_client.simple_query(&query)?;
-
-            // This is needed, because some functions are defined as SECURITY DEFINER.
-            // In Postgres SECURITY DEFINER functions are executed with the privileges
-            // of the owner.
-            // In anon extension this is needed to access some GUCs, which are only accessible to
-            // superuser. But we've patched postgres to allow db_owner to access them as well.
-            // So we need to change owner of these functions to db_owner.
-            let query = format!("
-                SELECT 'ALTER FUNCTION '||nsp.nspname||'.'||p.proname||'('||pg_get_function_identity_arguments(p.oid)||') OWNER TO {};'
-                from pg_proc p
-                join pg_namespace nsp ON p.pronamespace = nsp.oid
-                where nsp.nspname = 'anon';", db_owner);
-
-            info!("change anon extension functions owner to db owner");
-            db_client.simple_query(&query)?;
-
-            // affects views as well
-            let query = format!("GRANT ALL ON ALL TABLES IN SCHEMA anon TO {}", db_owner);
-            info!("granting anon extension permissions with query: {}", query);
-            db_client.simple_query(&query)?;
-
-            let query = format!("GRANT ALL ON ALL SEQUENCES IN SCHEMA anon TO {}", db_owner);
-            info!("granting anon extension permissions with query: {}", query);
-            db_client.simple_query(&query)?;
-        }
-    }
-
-    Ok(())
-}
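One detail worth noting in the removed code: the ownership-transfer query only *generates* `ALTER FUNCTION ... OWNER TO ...` statements (one per row of the SELECT), and as written it appears to send just the generator itself via `simple_query`. A hedged sketch of the full generate-then-execute pattern, using the synchronous `postgres` client purely for illustration:

```rust
// Sketch: run the generator query, then execute each ALTER FUNCTION
// statement it produced. Names follow the removed code; this is not a
// claim about what the original implementation did beyond what the diff shows.
fn reassign_anon_functions(client: &mut postgres::Client, db_owner: &str) -> anyhow::Result<()> {
    let gen_sql = format!(
        "SELECT 'ALTER FUNCTION '||nsp.nspname||'.'||p.proname||'('||pg_get_function_identity_arguments(p.oid)||') OWNER TO {};' \
         FROM pg_proc p JOIN pg_namespace nsp ON p.pronamespace = nsp.oid \
         WHERE nsp.nspname = 'anon'",
        db_owner
    );
    for row in client.query(gen_sql.as_str(), &[])? {
        let stmt: String = row.get(0);
        client.simple_query(&stmt)?; // execute each generated ALTER FUNCTION
    }
    Ok(())
}
```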
@@ -10,6 +10,8 @@ async-trait.workspace = true
 camino.workspace = true
 clap.workspace = true
 comfy-table.workspace = true
+diesel = { version = "2.1.4", features = ["postgres"]}
+diesel_migrations = { version = "2.1.0", features = ["postgres"]}
 futures.workspace = true
 git-version.workspace = true
 nix.workspace = true
@@ -7,7 +7,6 @@ CREATE TABLE tenant_shards (
     generation INTEGER NOT NULL,
     generation_pageserver BIGINT NOT NULL,
     placement_policy VARCHAR NOT NULL,
-    splitting SMALLINT NOT NULL,
     -- config is JSON encoded, opaque to the database.
     config TEXT NOT NULL
 );
@@ -3,8 +3,7 @@ use crate::service::{Service, STARTUP_RECONCILE_TIMEOUT};
 use hyper::{Body, Request, Response};
 use hyper::{StatusCode, Uri};
 use pageserver_api::models::{
-    TenantCreateRequest, TenantLocationConfigRequest, TenantShardSplitRequest,
-    TimelineCreateRequest,
+    TenantCreateRequest, TenantLocationConfigRequest, TimelineCreateRequest,
 };
 use pageserver_api::shard::TenantShardId;
 use pageserver_client::mgmt_api;

@@ -42,7 +41,7 @@ pub struct HttpState {

 impl HttpState {
     pub fn new(service: Arc<crate::service::Service>, auth: Option<Arc<SwappableJwtAuth>>) -> Self {
-        let allowlist_routes = ["/status", "/ready", "/metrics"]
+        let allowlist_routes = ["/status"]
             .iter()
             .map(|v| v.parse().unwrap())
             .collect::<Vec<_>>();
@@ -280,12 +279,6 @@ async fn handle_node_list(req: Request<Body>) -> Result<Response<Body>, ApiError
     json_response(StatusCode::OK, state.service.node_list().await?)
 }

-async fn handle_node_drop(req: Request<Body>) -> Result<Response<Body>, ApiError> {
-    let state = get_state(&req);
-    let node_id: NodeId = parse_request_param(&req, "node_id")?;
-    json_response(StatusCode::OK, state.service.node_drop(node_id).await?)
-}
-
 async fn handle_node_configure(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
     let node_id: NodeId = parse_request_param(&req, "node_id")?;
     let config_req = json_request::<NodeConfigureRequest>(&mut req).await?;
@@ -299,19 +292,6 @@ async fn handle_node_configure(mut req: Request<Body>) -> Result<Response<Body>,
     json_response(StatusCode::OK, state.service.node_configure(config_req)?)
 }

-async fn handle_tenant_shard_split(
-    service: Arc<Service>,
-    mut req: Request<Body>,
-) -> Result<Response<Body>, ApiError> {
-    let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
-    let split_req = json_request::<TenantShardSplitRequest>(&mut req).await?;
-
-    json_response(
-        StatusCode::OK,
-        service.tenant_shard_split(tenant_id, split_req).await?,
-    )
-}
-
 async fn handle_tenant_shard_migrate(
     service: Arc<Service>,
     mut req: Request<Body>,
@@ -326,29 +306,11 @@ async fn handle_tenant_shard_migrate(
     )
 }

-async fn handle_tenant_drop(req: Request<Body>) -> Result<Response<Body>, ApiError> {
-    let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
-    let state = get_state(&req);
-
-    json_response(StatusCode::OK, state.service.tenant_drop(tenant_id).await?)
-}
-
 /// Status endpoint is just used for checking that our HTTP listener is up
 async fn handle_status(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
     json_response(StatusCode::OK, ())
 }

-/// Readiness endpoint indicates when we're done doing startup I/O (e.g. reconciling
-/// with remote pageserver nodes). This is intended for use as a kubernetes readiness probe.
-async fn handle_ready(req: Request<Body>) -> Result<Response<Body>, ApiError> {
-    let state = get_state(&req);
-    if state.service.startup_complete.is_ready() {
-        json_response(StatusCode::OK, ())
-    } else {
-        json_response(StatusCode::SERVICE_UNAVAILABLE, ())
-    }
-}
-
 impl From<ReconcileError> for ApiError {
     fn from(value: ReconcileError) -> Self {
         ApiError::Conflict(format!("Reconciliation error: {}", value))
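The removed `/ready` handler follows the usual readiness-probe contract: 200 once `startup_complete` is ready, 503 before that. A sketch of a client-side poll against it; the base URL, attempt count, and interval are illustrative assumptions:

```rust
// Sketch: poll the readiness endpoint until it reports 200 OK.
async fn wait_until_ready(base_url: &str) -> anyhow::Result<()> {
    let client = reqwest::Client::new();
    for _ in 0..30 {
        match client.get(format!("{base_url}/ready")).send().await {
            Ok(resp) if resp.status() == reqwest::StatusCode::OK => return Ok(()),
            // 503 (or a connection error) means startup reconcile is still running.
            _ => tokio::time::sleep(std::time::Duration::from_secs(1)).await,
        }
    }
    anyhow::bail!("service did not become ready in time")
}
```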
@@ -404,7 +366,6 @@ pub fn make_router(
         .data(Arc::new(HttpState::new(service, auth)))
         // Non-prefixed generic endpoints (status, metrics)
         .get("/status", |r| request_span(r, handle_status))
-        .get("/ready", |r| request_span(r, handle_ready))
         // Upcalls for the pageserver: point the pageserver's `control_plane_api` config to this prefix
         .post("/upcall/v1/re-attach", |r| {
             request_span(r, handle_re_attach)
@@ -415,12 +376,6 @@ pub fn make_router(
             request_span(r, handle_attach_hook)
         })
         .post("/debug/v1/inspect", |r| request_span(r, handle_inspect))
-        .post("/debug/v1/tenant/:tenant_id/drop", |r| {
-            request_span(r, handle_tenant_drop)
-        })
-        .post("/debug/v1/node/:node_id/drop", |r| {
-            request_span(r, handle_node_drop)
-        })
         .get("/control/v1/tenant/:tenant_id/locate", |r| {
             tenant_service_handler(r, handle_tenant_locate)
         })
@@ -436,9 +391,6 @@ pub fn make_router(
         .put("/control/v1/tenant/:tenant_shard_id/migrate", |r| {
             tenant_service_handler(r, handle_tenant_shard_migrate)
         })
-        .put("/control/v1/tenant/:tenant_id/shard_split", |r| {
-            tenant_service_handler(r, handle_tenant_shard_split)
-        })
         // Tenant operations
         // The ^/v1/ endpoints act as a "Virtual Pageserver", enabling shard-naive clients to call into
         // this service to manage tenants that actually consist of many tenant shards, as if they are a single entity.
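The "Virtual Pageserver" comment above is the key design note in this file: shard-naive clients address whole tenants on the `^/v1/` prefix, and the service fans each call out to the constituent tenant shards. A hedged sketch of what such a shard-naive call might look like; the URL shape and payload are assumptions for illustration, not the service's documented API:

```rust
// Sketch only: a shard-naive client creating a timeline through the
// service's virtual-pageserver prefix, as if talking to a single pageserver.
async fn create_timeline_shard_naive(
    base_url: &str,
    tenant_id: &str,
    body: &str, // JSON payload, e.g. {"new_timeline_id": "..."}
) -> anyhow::Result<reqwest::Response> {
    let url = format!("{base_url}/v1/tenant/{tenant_id}/timeline");
    let resp = reqwest::Client::new()
        .post(url)
        .header("Content-Type", "application/json")
        .body(body.to_owned())
        .send()
        .await?;
    Ok(resp)
}
```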
@@ -1,9 +1,7 @@
-pub(crate) mod split_state;
 use std::collections::HashMap;
 use std::str::FromStr;
 use std::time::Duration;

-use self::split_state::SplitState;
 use camino::Utf8Path;
 use camino::Utf8PathBuf;
 use control_plane::attachment_service::{NodeAvailability, NodeSchedulingPolicy};
@@ -260,6 +258,7 @@ impl Persistence {

     /// Ordering: call this _after_ deleting the tenant on pageservers, but _before_ dropping state for
     /// the tenant from memory on this server.
+    #[allow(unused)]
     pub(crate) async fn delete_tenant(&self, del_tenant_id: TenantId) -> DatabaseResult<()> {
         use crate::schema::tenant_shards::dsl::*;
         self.with_conn(move |conn| -> DatabaseResult<()> {
@@ -272,18 +271,6 @@ impl Persistence {
         .await
     }

-    pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
-        use crate::schema::nodes::dsl::*;
-        self.with_conn(move |conn| -> DatabaseResult<()> {
-            diesel::delete(nodes)
-                .filter(node_id.eq(del_node_id.0 as i64))
-                .execute(conn)?;
-
-            Ok(())
-        })
-        .await
-    }
-
     /// When a tenant invokes the /re-attach API, this function is responsible for doing an efficient
     /// batched increment of the generations of all tenants whose generation_pageserver is equal to
     /// the node that called /re-attach.
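The doc comment above describes `/re-attach` as a single batched generation increment. A minimal sketch of what that could look like with diesel; column names follow the `tenant_shards` migration earlier in this diff, and this is not the actual implementation:

```rust
// Hedged sketch: bump the generation of every shard whose
// generation_pageserver matches the re-attaching node, in one statement.
use diesel::prelude::*;

fn reattach_bump(conn: &mut PgConnection, node_id: i64) -> QueryResult<usize> {
    // Assumes a diesel schema module like the one generated for tenant_shards.
    use crate::schema::tenant_shards::dsl::*;
    diesel::update(tenant_shards)
        .filter(generation_pageserver.eq(node_id))
        .set(generation.eq(generation + 1))
        .execute(conn)
}
```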
@@ -376,107 +363,19 @@ impl Persistence {
         Ok(())
     }

-    // When we start shard splitting, we must durably mark the tenant so that
-    // on restart, we know that we must go through recovery.
-    //
-    // We create the child shards here, so that they will be available for increment_generation calls
-    // if some pageserver holding a child shard needs to restart before the overall tenant split is complete.
+    // TODO: when we start shard splitting, we must durably mark the tenant so that
+    // on restart, we know that we must go through recovery (list shards that exist
+    // and pick up where we left off and/or revert to parent shards).
     #[allow(dead_code)]
-    pub(crate) async fn begin_shard_split(
-        &self,
-        old_shard_count: ShardCount,
-        split_tenant_id: TenantId,
-        parent_to_children: Vec<(TenantShardId, Vec<TenantShardPersistence>)>,
-    ) -> DatabaseResult<()> {
-        use crate::schema::tenant_shards::dsl::*;
-        self.with_conn(move |conn| -> DatabaseResult<()> {
-            conn.transaction(|conn| -> DatabaseResult<()> {
-                // Mark parent shards as splitting
-
-                let expect_parent_records = std::cmp::max(1, old_shard_count.0);
-
-                let updated = diesel::update(tenant_shards)
-                    .filter(tenant_id.eq(split_tenant_id.to_string()))
-                    .filter(shard_count.eq(old_shard_count.0 as i32))
-                    .set((splitting.eq(1),))
-                    .execute(conn)?;
-                if u8::try_from(updated)
-                    .map_err(|_| DatabaseError::Logical(
-                        format!("Overflow existing shard count {} while splitting", updated))
-                    )? != expect_parent_records {
-                    // Perhaps a deletion or another split raced with this attempt to split, mutating
-                    // the parent shards that we intend to split. In this case the split request should fail.
-                    return Err(DatabaseError::Logical(
-                        format!("Unexpected existing shard count {updated} when preparing tenant for split (expected {expect_parent_records})")
-                    ));
-                }
-
-                // FIXME: spurious clone to sidestep closure move rules
-                let parent_to_children = parent_to_children.clone();
-
-                // Insert child shards
-                for (parent_shard_id, children) in parent_to_children {
-                    let mut parent = crate::schema::tenant_shards::table
-                        .filter(tenant_id.eq(parent_shard_id.tenant_id.to_string()))
-                        .filter(shard_number.eq(parent_shard_id.shard_number.0 as i32))
-                        .filter(shard_count.eq(parent_shard_id.shard_count.0 as i32))
-                        .load::<TenantShardPersistence>(conn)?;
-                    let parent = if parent.len() != 1 {
-                        return Err(DatabaseError::Logical(format!(
-                            "Parent shard {parent_shard_id} not found"
-                        )));
-                    } else {
-                        parent.pop().unwrap()
-                    };
-                    for mut shard in children {
-                        // Carry the parent's generation into the child
-                        shard.generation = parent.generation;
-
-                        debug_assert!(shard.splitting == SplitState::Splitting);
-                        diesel::insert_into(tenant_shards)
-                            .values(shard)
-                            .execute(conn)?;
-                    }
-                }
-
-                Ok(())
-            })?;
-
-            Ok(())
-        })
-        .await
+    pub(crate) async fn begin_shard_split(&self, _tenant_id: TenantId) -> anyhow::Result<()> {
+        todo!();
     }

-    // When we finish shard splitting, we must atomically clean up the old shards
+    // TODO: when we finish shard splitting, we must atomically clean up the old shards
     // and insert the new shards, and clear the splitting marker.
     #[allow(dead_code)]
-    pub(crate) async fn complete_shard_split(
-        &self,
-        split_tenant_id: TenantId,
-        old_shard_count: ShardCount,
-    ) -> DatabaseResult<()> {
-        use crate::schema::tenant_shards::dsl::*;
-        self.with_conn(move |conn| -> DatabaseResult<()> {
-            conn.transaction(|conn| -> QueryResult<()> {
-                // Drop parent shards
-                diesel::delete(tenant_shards)
-                    .filter(tenant_id.eq(split_tenant_id.to_string()))
-                    .filter(shard_count.eq(old_shard_count.0 as i32))
-                    .execute(conn)?;
-
-                // Clear sharding flag
-                let updated = diesel::update(tenant_shards)
-                    .filter(tenant_id.eq(split_tenant_id.to_string()))
-                    .set((splitting.eq(0),))
-                    .execute(conn)?;
-                debug_assert!(updated > 0);
-
-                Ok(())
-            })?;
-
-            Ok(())
-        })
-        .await
+    pub(crate) async fn complete_shard_split(&self, _tenant_id: TenantId) -> anyhow::Result<()> {
+        todo!();
     }
 }
@@ -504,8 +403,6 @@ pub(crate) struct TenantShardPersistence {
     #[serde(default)]
     pub(crate) placement_policy: String,
-    #[serde(default)]
-    pub(crate) splitting: SplitState,
     #[serde(default)]
     pub(crate) config: String,
 }
@@ -1,46 +0,0 @@
-use diesel::pg::{Pg, PgValue};
-use diesel::{
-    deserialize::FromSql, deserialize::FromSqlRow, expression::AsExpression, serialize::ToSql,
-    sql_types::Int2,
-};
-use serde::{Deserialize, Serialize};
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, FromSqlRow, AsExpression)]
-#[diesel(sql_type = SplitStateSQLRepr)]
-#[derive(Deserialize, Serialize)]
-pub enum SplitState {
-    Idle = 0,
-    Splitting = 1,
-}
-
-impl Default for SplitState {
-    fn default() -> Self {
-        Self::Idle
-    }
-}
-
-type SplitStateSQLRepr = Int2;
-
-impl ToSql<SplitStateSQLRepr, Pg> for SplitState {
-    fn to_sql<'a>(
-        &'a self,
-        out: &'a mut diesel::serialize::Output<Pg>,
-    ) -> diesel::serialize::Result {
-        let raw_value: i16 = *self as i16;
-        let mut new_out = out.reborrow();
-        ToSql::<SplitStateSQLRepr, Pg>::to_sql(&raw_value, &mut new_out)
-    }
-}
-
-impl FromSql<SplitStateSQLRepr, Pg> for SplitState {
-    fn from_sql(pg_value: PgValue) -> diesel::deserialize::Result<Self> {
-        match FromSql::<SplitStateSQLRepr, Pg>::from_sql(pg_value).map(|v| match v {
-            0 => Some(Self::Idle),
-            1 => Some(Self::Splitting),
-            _ => None,
-        })? {
-            Some(v) => Ok(v),
-            None => Err(format!("Invalid SplitState value, was: {:?}", pg_value.as_bytes()).into()),
-        }
-    }
-}
@@ -20,7 +20,6 @@ diesel::table! {
         generation -> Int4,
         generation_pageserver -> Int8,
         placement_policy -> Varchar,
-        splitting -> Int2,
         config -> Text,
     }
 }
@@ -1,6 +1,5 @@
 use std::{
     cmp::Ordering,
-    collections::{BTreeMap, HashMap, HashSet},
+    collections::{BTreeMap, HashMap},
     str::FromStr,
     sync::Arc,
     time::{Duration, Instant},
@@ -24,14 +23,13 @@ use pageserver_api::{
     models::{
         LocationConfig, LocationConfigMode, ShardParameters, TenantConfig, TenantCreateRequest,
         TenantLocationConfigRequest, TenantLocationConfigResponse, TenantShardLocation,
-        TenantShardSplitRequest, TenantShardSplitResponse, TimelineCreateRequest, TimelineInfo,
+        TimelineCreateRequest, TimelineInfo,
     },
     shard::{ShardCount, ShardIdentity, ShardNumber, ShardStripeSize, TenantShardId},
 };
 use pageserver_client::mgmt_api;
 use tokio_util::sync::CancellationToken;
 use utils::{
     backoff,
     completion::Barrier,
     generation::Generation,
     http::error::ApiError,
@@ -42,11 +40,7 @@ use utils::{
 use crate::{
     compute_hook::{self, ComputeHook},
     node::Node,
-    persistence::{
-        split_state::SplitState, DatabaseError, NodePersistence, Persistence,
-        TenantShardPersistence,
-    },
-    reconciler::attached_location_conf,
+    persistence::{DatabaseError, NodePersistence, Persistence, TenantShardPersistence},
     scheduler::Scheduler,
     tenant_state::{
         IntentState, ObservedState, ObservedStateLocation, ReconcileResult, ReconcileWaitError,
@@ -151,71 +145,31 @@ impl Service {
(Both versions of this region are shown as the compare view rendered them, old alongside new; the extraction lost the red/green side markers, so the deeper-indented block is the old code and the shallower one the new.)
        // indeterminate, same as in [`ObservedStateLocation`])
        let mut observed = HashMap::new();

        let mut nodes_online = HashSet::new();

        // TODO: give Service a cancellation token for clean shutdown
        let cancel = CancellationToken::new();
        let nodes = {
            let locked = self.inner.read().unwrap();
            locked.nodes.clone()
        };

        // TODO: issue these requests concurrently
        {
            let nodes = {
                let locked = self.inner.read().unwrap();
                locked.nodes.clone()
            };
            for node in nodes.values() {
                let http_client = reqwest::ClientBuilder::new()
                    .timeout(Duration::from_secs(5))
                    .build()
                    .expect("Failed to construct HTTP client");
                let client = mgmt_api::Client::from_client(
                    http_client,
                    node.base_url(),
                    self.config.jwt_token.as_deref(),
                );
        for node in nodes.values() {
            let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref());

                fn is_fatal(e: &mgmt_api::Error) -> bool {
                    use mgmt_api::Error::*;
                    match e {
                        ReceiveBody(_) | ReceiveErrorBody(_) => false,
                        ApiError(StatusCode::SERVICE_UNAVAILABLE, _)
                        | ApiError(StatusCode::GATEWAY_TIMEOUT, _)
                        | ApiError(StatusCode::REQUEST_TIMEOUT, _) => false,
                        ApiError(_, _) => true,
                    }
            tracing::info!("Scanning shards on node {}...", node.id);
            match client.list_location_config().await {
                Err(e) => {
                    tracing::warn!("Could not contact pageserver {} ({e})", node.id);
                    // TODO: be more tolerant, apply a generous 5-10 second timeout with retries, in case
                    // pageserver is being restarted at the same time as we are
                }
                Ok(listing) => {
                    tracing::info!(
                        "Received {} shard statuses from pageserver {}, setting it to Active",
                        listing.tenant_shards.len(),
                        node.id
                    );

                let list_response = backoff::retry(
                    || client.list_location_config(),
                    is_fatal,
                    1,
                    5,
                    "Location config listing",
                    &cancel,
                )
                .await;
                let Some(list_response) = list_response else {
                    tracing::info!("Shutdown during startup_reconcile");
                    return;
                };

                tracing::info!("Scanning shards on node {}...", node.id);
                match list_response {
                    Err(e) => {
                        tracing::warn!("Could not contact pageserver {} ({e})", node.id);
                        // TODO: be more tolerant, do some retries, in case
                        // pageserver is being restarted at the same time as we are
                    }
                    Ok(listing) => {
                        tracing::info!(
                            "Received {} shard statuses from pageserver {}, setting it to Active",
                            listing.tenant_shards.len(),
                            node.id
                        );
                        nodes_online.insert(node.id);

                        for (tenant_shard_id, conf_opt) in listing.tenant_shards {
                            observed.insert(tenant_shard_id, (node.id, conf_opt));
                        }
            for (tenant_shard_id, conf_opt) in listing.tenant_shards {
                observed.insert(tenant_shard_id, (node.id, conf_opt));
            }
                }
            }
@@ -226,19 +180,8 @@ impl Service {
         let mut compute_notifications = Vec::new();

         // Populate intent and observed states for all tenants, based on reported state on pageservers
-        let (shard_count, nodes) = {
+        let shard_count = {
             let mut locked = self.inner.write().unwrap();

-            // Mark nodes online if they responded to us: nodes are offline by default after a restart.
-            let mut nodes = (*locked.nodes).clone();
-            for (node_id, node) in nodes.iter_mut() {
-                if nodes_online.contains(node_id) {
-                    node.availability = NodeAvailability::Active;
-                }
-            }
-            locked.nodes = Arc::new(nodes);
-            let nodes = locked.nodes.clone();
-
             for (tenant_shard_id, (node_id, observed_loc)) in observed {
                 let Some(tenant_state) = locked.tenants.get_mut(&tenant_shard_id) else {
                     cleanup.push((tenant_shard_id, node_id));
@@ -270,7 +213,7 @@ impl Service {
             }
         }

-        (locked.tenants.len(), nodes)
+        locked.tenants.len()
     };

     // TODO: if any tenant's intent now differs from its loaded generation_pageserver, we should clear that
@@ -331,8 +274,9 @@ impl Service {
         let stream = futures::stream::iter(compute_notifications.into_iter())
             .map(|(tenant_shard_id, node_id)| {
                 let compute_hook = compute_hook.clone();
-                let cancel = cancel.clone();
                 async move {
+                    // TODO: give Service a cancellation token for clean shutdown
+                    let cancel = CancellationToken::new();
                     if let Err(e) = compute_hook.notify(tenant_shard_id, node_id, &cancel).await {
                         tracing::error!(
                             tenant_shard_id=%tenant_shard_id,
@@ -438,7 +382,7 @@ impl Service {
             ))),
             config,
             persistence,
-            startup_complete: startup_complete.clone(),
+            startup_complete,
         });

         let result_task_this = this.clone();
@@ -532,7 +476,6 @@ impl Service {
             generation_pageserver: i64::MAX,
             placement_policy: serde_json::to_string(&PlacementPolicy::default()).unwrap(),
             config: serde_json::to_string(&TenantConfig::default()).unwrap(),
-            splitting: SplitState::default(),
         };

         match self.persistence.insert_tenant_shards(vec![tsp]).await {
@@ -775,7 +718,6 @@ impl Service {
                 generation_pageserver: i64::MAX,
                 placement_policy: serde_json::to_string(&placement_policy).unwrap(),
                 config: serde_json::to_string(&create_req.config).unwrap(),
-                splitting: SplitState::default(),
             })
             .collect();
         self.persistence
@@ -1035,10 +977,6 @@ impl Service {
             }
         };

-        // TODO: if we timeout/fail on reconcile, we should still succeed this request,
-        // because otherwise a broken compute hook causes a feedback loop where
-        // location_config returns 500 and gets retried forever.
-
         if let Some(create_req) = maybe_create {
             let create_resp = self.tenant_create(create_req).await?;
             result.shards = create_resp
@@ -1162,7 +1100,6 @@ impl Service {
         self.ensure_attached_wait(tenant_id).await?;

         // TODO: refuse to do this if shard splitting is in progress
-        // (https://github.com/neondatabase/neon/issues/6676)
         let targets = {
             let locked = self.inner.read().unwrap();
             let mut targets = Vec::new();
@@ -1243,7 +1180,6 @@ impl Service {
         self.ensure_attached_wait(tenant_id).await?;

         // TODO: refuse to do this if shard splitting is in progress
-        // (https://github.com/neondatabase/neon/issues/6676)
         let targets = {
             let locked = self.inner.read().unwrap();
             let mut targets = Vec::new();
@@ -1416,326 +1352,6 @@ impl Service {
         })
     }

-    pub(crate) async fn tenant_shard_split(
-        &self,
-        tenant_id: TenantId,
-        split_req: TenantShardSplitRequest,
-    ) -> Result<TenantShardSplitResponse, ApiError> {
-        let mut policy = None;
-        let mut shard_ident = None;
-
-        // TODO: put a cancellation token on Service for clean shutdown
-        let cancel = CancellationToken::new();
-
-        // A parent shard which will be split
-        struct SplitTarget {
-            parent_id: TenantShardId,
-            node: Node,
-            child_ids: Vec<TenantShardId>,
-        }
-
-        // Validate input, and calculate which shards we will create
-        let (old_shard_count, targets, compute_hook) = {
-            let locked = self.inner.read().unwrap();
-
-            let pageservers = locked.nodes.clone();
-
-            let mut targets = Vec::new();
-
-            // In case this is a retry, count how many already-split shards we found
-            let mut children_found = Vec::new();
-            let mut old_shard_count = None;
-
-            for (tenant_shard_id, shard) in
-                locked.tenants.range(TenantShardId::tenant_range(tenant_id))
-            {
-                match shard.shard.count.0.cmp(&split_req.new_shard_count) {
-                    Ordering::Equal => {
-                        // Already split this
-                        children_found.push(*tenant_shard_id);
-                        continue;
-                    }
-                    Ordering::Greater => {
-                        return Err(ApiError::BadRequest(anyhow::anyhow!(
-                            "Requested count {} but already have shards at count {}",
-                            split_req.new_shard_count,
-                            shard.shard.count.0
-                        )));
-                    }
-                    Ordering::Less => {
-                        // Fall through: this shard has lower count than requested,
-                        // is a candidate for splitting.
-                    }
-                }
-
-                match old_shard_count {
-                    None => old_shard_count = Some(shard.shard.count),
-                    Some(old_shard_count) => {
-                        if old_shard_count != shard.shard.count {
-                            // We may hit this case if a caller asked for two splits to
-                            // different sizes, before the first one is complete.
-                            // e.g. 1->2, 2->4, where the 4 call comes while we have a mixture
-                            // of shard_count=1 and shard_count=2 shards in the map.
-                            return Err(ApiError::Conflict(
-                                "Cannot split, currently mid-split".to_string(),
-                            ));
-                        }
-                    }
-                }
-                if policy.is_none() {
-                    policy = Some(shard.policy.clone());
-                }
-                if shard_ident.is_none() {
-                    shard_ident = Some(shard.shard);
-                }
-
-                if tenant_shard_id.shard_count == ShardCount(split_req.new_shard_count) {
-                    tracing::info!(
-                        "Tenant shard {} already has shard count {}",
-                        tenant_shard_id,
-                        split_req.new_shard_count
-                    );
-                    continue;
-                }
-
-                let node_id =
-                    shard
-                        .intent
-                        .attached
-                        .ok_or(ApiError::BadRequest(anyhow::anyhow!(
-                            "Cannot split a tenant that is not attached"
-                        )))?;
-
-                let node = pageservers
-                    .get(&node_id)
-                    .expect("Pageservers may not be deleted while referenced");
-
-                // TODO: if any reconciliation is currently in progress for this shard, wait for it.
-
-                targets.push(SplitTarget {
-                    parent_id: *tenant_shard_id,
-                    node: node.clone(),
-                    child_ids: tenant_shard_id.split(ShardCount(split_req.new_shard_count)),
-                });
-            }
-
-            if targets.is_empty() {
-                if children_found.len() == split_req.new_shard_count as usize {
-                    return Ok(TenantShardSplitResponse {
-                        new_shards: children_found,
-                    });
-                } else {
-                    // No shards found to split, and no existing children found: the
-                    // tenant doesn't exist at all.
-                    return Err(ApiError::NotFound(
-                        anyhow::anyhow!("Tenant {} not found", tenant_id).into(),
-                    ));
-                }
-            }
-
-            (old_shard_count, targets, locked.compute_hook.clone())
-        };
-
-        // unwrap safety: we would have returned above if we didn't find at least one shard to split
-        let old_shard_count = old_shard_count.unwrap();
-        let shard_ident = shard_ident.unwrap();
-        let policy = policy.unwrap();
-
-        // FIXME: we have dropped self.inner lock, and not yet written anything to the database: another
-        // request could occur here, deleting or mutating the tenant. begin_shard_split checks that the
-        // parent shards exist as expected, but it would be neater to do the above pre-checks within the
-        // same database transaction rather than pre-check in-memory and then maybe-fail the database write.
-        // (https://github.com/neondatabase/neon/issues/6676)
-
-        // Before creating any new child shards in memory or on the pageservers, persist them: this
-        // enables us to ensure that we will always be able to clean up if something goes wrong. This also
-        // acts as the protection against two concurrent attempts to split: one of them will get a database
-        // error trying to insert the child shards.
-        let mut child_tsps = Vec::new();
-        for target in &targets {
-            let mut this_child_tsps = Vec::new();
-            for child in &target.child_ids {
-                let mut child_shard = shard_ident;
-                child_shard.number = child.shard_number;
-                child_shard.count = child.shard_count;
-
-                this_child_tsps.push(TenantShardPersistence {
-                    tenant_id: child.tenant_id.to_string(),
-                    shard_number: child.shard_number.0 as i32,
-                    shard_count: child.shard_count.0 as i32,
-                    shard_stripe_size: shard_ident.stripe_size.0 as i32,
-                    // Note: this generation is a placeholder, [`Persistence::begin_shard_split`] will
-                    // populate the correct generation as part of its transaction, to protect us
-                    // against racing with changes in the state of the parent.
-                    generation: 0,
-                    generation_pageserver: target.node.id.0 as i64,
-                    placement_policy: serde_json::to_string(&policy).unwrap(),
-                    // TODO: get the config out of the map
-                    config: serde_json::to_string(&TenantConfig::default()).unwrap(),
-                    splitting: SplitState::Splitting,
-                });
-            }
-
-            child_tsps.push((target.parent_id, this_child_tsps));
-        }
-
-        if let Err(e) = self
-            .persistence
-            .begin_shard_split(old_shard_count, tenant_id, child_tsps)
-            .await
-        {
-            match e {
-                DatabaseError::Query(diesel::result::Error::DatabaseError(
-                    DatabaseErrorKind::UniqueViolation,
-                    _,
-                )) => {
-                    // Inserting a child shard violated a unique constraint: we raced with another call to
-                    // this function
-                    tracing::warn!("Conflicting attempt to split {tenant_id}: {e}");
-                    return Err(ApiError::Conflict("Tenant is already splitting".into()));
-                }
-                _ => return Err(ApiError::InternalServerError(e.into())),
-            }
-        }
-
-        // FIXME: we have now committed the shard split state to the database, so any subsequent
-        // failure needs to roll it back. We will later wrap this function in logic to roll back
-        // the split if it fails.
-        // (https://github.com/neondatabase/neon/issues/6676)
-
-        // TODO: issue split calls concurrently (this only matters once we're splitting
-        // N>1 shards into M shards -- initially we're usually splitting 1 shard into N).
-
-        for target in &targets {
-            let SplitTarget {
-                parent_id,
-                node,
-                child_ids,
-            } = target;
-            let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref());
-            let response = client
-                .tenant_shard_split(
-                    *parent_id,
-                    TenantShardSplitRequest {
-                        new_shard_count: split_req.new_shard_count,
-                    },
-                )
-                .await
-                .map_err(|e| ApiError::Conflict(format!("Failed to split {}: {}", parent_id, e)))?;
-
-            tracing::info!(
-                "Split {} into {}",
-                parent_id,
-                response
-                    .new_shards
-                    .iter()
-                    .map(|s| format!("{:?}", s))
-                    .collect::<Vec<_>>()
-                    .join(",")
-            );
-
-            if &response.new_shards != child_ids {
-                // This should never happen: the pageserver should agree with us on how shard splits work.
-                return Err(ApiError::InternalServerError(anyhow::anyhow!(
-                    "Splitting shard {} resulted in unexpected IDs: {:?} (expected {:?})",
-                    parent_id,
-                    response.new_shards,
-                    child_ids
-                )));
-            }
-        }
-
-        // TODO: if the pageserver restarted concurrently with our split API call,
-        // the actual generation of the child shard might differ from the generation
-        // we expect it to have. In order for our in-database generation to end up
-        // correct, we should carry the child generation back in the response and apply it here
-        // in complete_shard_split (and apply the correct generation in memory)
-        // (or, we can carry generation in the request and reject the request if
-        // it doesn't match, but that requires more retry logic on this side)
-
-        self.persistence
-            .complete_shard_split(tenant_id, old_shard_count)
-            .await?;
-
-        // Replace all the shards we just split with their children
-        let mut response = TenantShardSplitResponse {
-            new_shards: Vec::new(),
-        };
-        let mut child_locations = Vec::new();
-        {
-            let mut locked = self.inner.write().unwrap();
-            for target in targets {
-                let SplitTarget {
-                    parent_id,
-                    node: _node,
-                    child_ids,
-                } = target;
-                let (pageserver, generation, config) = {
-                    let old_state = locked
-                        .tenants
-                        .remove(&parent_id)
-                        .expect("It was present, we just split it");
-                    (
-                        old_state.intent.attached.unwrap(),
-                        old_state.generation,
-                        old_state.config.clone(),
-                    )
-                };
-
-                locked.tenants.remove(&parent_id);
-
-                for child in child_ids {
-                    let mut child_shard = shard_ident;
-                    child_shard.number = child.shard_number;
-                    child_shard.count = child.shard_count;
-
-                    let mut child_observed: HashMap<NodeId, ObservedStateLocation> = HashMap::new();
-                    child_observed.insert(
-                        pageserver,
-                        ObservedStateLocation {
-                            conf: Some(attached_location_conf(generation, &child_shard, &config)),
-                        },
-                    );
-
-                    let mut child_state = TenantState::new(child, child_shard, policy.clone());
-                    child_state.intent = IntentState::single(Some(pageserver));
-                    child_state.observed = ObservedState {
-                        locations: child_observed,
-                    };
-                    child_state.generation = generation;
-                    child_state.config = config.clone();
-
-                    child_locations.push((child, pageserver));
-
-                    locked.tenants.insert(child, child_state);
-                    response.new_shards.push(child);
-                }
-            }
-        }
-
-        // Send compute notifications for all the new shards
-        let mut failed_notifications = Vec::new();
-        for (child_id, child_ps) in child_locations {
-            if let Err(e) = compute_hook.notify(child_id, child_ps, &cancel).await {
-                tracing::warn!("Failed to update compute of {}->{} during split, proceeding anyway to complete split ({e})",
-                        child_id, child_ps);
-                failed_notifications.push(child_id);
-            }
-        }
-
-        // If we failed any compute notifications, make a note to retry later.
-        if !failed_notifications.is_empty() {
-            let mut locked = self.inner.write().unwrap();
-            for failed in failed_notifications {
-                if let Some(shard) = locked.tenants.get_mut(&failed) {
-                    shard.pending_compute_notification = true;
-                }
-            }
-        }
-
-        Ok(response)
-    }
-
     pub(crate) async fn tenant_shard_migrate(
         &self,
         tenant_shard_id: TenantShardId,
@@ -1804,45 +1420,6 @@ impl Service {
         Ok(TenantShardMigrateResponse {})
     }

-    /// This is for debug/support only: we simply drop all state for a tenant, without
-    /// detaching or deleting it on pageservers.
-    pub(crate) async fn tenant_drop(&self, tenant_id: TenantId) -> Result<(), ApiError> {
-        self.persistence.delete_tenant(tenant_id).await?;
-
-        let mut locked = self.inner.write().unwrap();
-        let mut shards = Vec::new();
-        for (tenant_shard_id, _) in locked.tenants.range(TenantShardId::tenant_range(tenant_id)) {
-            shards.push(*tenant_shard_id);
-        }
-
-        for shard in shards {
-            locked.tenants.remove(&shard);
-        }
-
-        Ok(())
-    }
-
-    /// This is for debug/support only: we simply drop all state for a node, without
-    /// detaching or deleting it on pageservers. We do not try and re-schedule any
-    /// tenants that were on this node.
-    ///
-    /// TODO: proper node deletion API that unhooks things more gracefully
-    pub(crate) async fn node_drop(&self, node_id: NodeId) -> Result<(), ApiError> {
-        self.persistence.delete_node(node_id).await?;
-
-        let mut locked = self.inner.write().unwrap();
-
-        for shard in locked.tenants.values_mut() {
-            shard.deref_node(node_id);
-        }
-
-        let mut nodes = (*locked.nodes).clone();
-        nodes.remove(&node_id);
-        locked.nodes = Arc::new(nodes);
-
-        Ok(())
-    }
-
     pub(crate) async fn node_list(&self) -> Result<Vec<NodePersistence>, ApiError> {
         // It is convenient to avoid taking the big lock and converting Node to a serializable
         // structure, by fetching from storage instead of reading in-memory state.
@@ -193,13 +193,6 @@ impl IntentState {
         result
     }

-    pub(crate) fn single(node_id: Option<NodeId>) -> Self {
-        Self {
-            attached: node_id,
-            secondary: vec![],
-        }
-    }
-
     /// When a node goes offline, we update intents to avoid using it
     /// as their attached pageserver.
     ///
@@ -293,9 +286,6 @@ impl TenantState {
         // self.intent refers to pageservers that are offline, and pick other
         // pageservers if so.

-        // TODO: respect the splitting bit on tenants: if they are currently splitting then we may not
-        // change their attach location.
-
         // Build the set of pageservers already in use by this tenant, to avoid scheduling
         // more work on the same pageservers we're already using.
         let mut used_pageservers = self.intent.all_pageservers();
@@ -534,18 +524,4 @@ impl TenantState {
             seq: self.sequence,
         })
     }

-    // If we had any state at all referring to this node ID, drop it. Does not
-    // attempt to reschedule.
-    pub(crate) fn deref_node(&mut self, node_id: NodeId) {
-        if self.intent.attached == Some(node_id) {
-            self.intent.attached = None;
-        }
-
-        self.intent.secondary.retain(|n| n != &node_id);
-
-        self.observed.locations.remove(&node_id);
-
-        debug_assert!(!self.intent.all_pageservers().contains(&node_id));
-    }
 }
@@ -1,17 +1,20 @@
 use crate::{background_process, local_env::LocalEnv};
 use camino::{Utf8Path, Utf8PathBuf};
+use diesel::{
+    backend::Backend,
+    query_builder::{AstPass, QueryFragment, QueryId},
+    Connection, PgConnection, QueryResult, RunQueryDsl,
+};
+use diesel_migrations::{HarnessWithOutput, MigrationHarness};
 use hyper::Method;
 use pageserver_api::{
-    models::{
-        ShardParameters, TenantCreateRequest, TenantShardSplitRequest, TenantShardSplitResponse,
-        TimelineCreateRequest, TimelineInfo,
-    },
+    models::{ShardParameters, TenantCreateRequest, TimelineCreateRequest, TimelineInfo},
     shard::TenantShardId,
 };
 use pageserver_client::mgmt_api::ResponseErrorMessageExt;
 use postgres_backend::AuthType;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
-use std::str::FromStr;
+use std::{env, str::FromStr};
 use tokio::process::Command;
 use tracing::instrument;
 use url::Url;
@@ -267,6 +270,37 @@ impl AttachmentService {
         .expect("non-Unicode path")
     }

+    /// In order to access database migrations, we need to find the Neon source tree
+    async fn find_source_root(&self) -> anyhow::Result<Utf8PathBuf> {
+        // We assume that either pwd or our binary is in the source tree. The former is usually
+        // true for automated test runners, the latter is usually true for developer workstations. Often
+        // both are true, which is fine.
+        let candidate_start_points = [
+            // Current working directory
+            Utf8PathBuf::from_path_buf(std::env::current_dir()?).unwrap(),
+            // Directory containing the binary we're running inside
+            Utf8PathBuf::from_path_buf(env::current_exe()?.parent().unwrap().to_owned()).unwrap(),
+        ];
+
+        // For each candidate start point, search through ancestors looking for a neon.git source tree root
+        for start_point in &candidate_start_points {
+            // Start from the build dir: assumes we are running out of a built neon source tree
+            for path in start_point.ancestors() {
+                // A crude approximation: the root of the source tree is whatever contains a "control_plane"
+                // subdirectory.
+                let control_plane = path.join("control_plane");
+                if tokio::fs::try_exists(&control_plane).await? {
+                    return Ok(path.to_owned());
+                }
+            }
+        }
+
+        // Fall-through
+        Err(anyhow::anyhow!(
+            "Could not find control_plane src dir, after searching ancestors of {candidate_start_points:?}"
+        ))
+    }
+
     /// Find the directory containing postgres binaries, such as `initdb` and `pg_ctl`
     ///
     /// This usually uses ATTACHMENT_SERVICE_POSTGRES_VERSION of postgres, but will fall back
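The ancestor search above boils down to `Path::ancestors` plus a marker-directory test. A synchronous sketch of the same walk (the diff's version uses `tokio::fs` and tries both the cwd and the executable's directory):

```rust
// Sketch: walk up from a starting path until a directory containing
// a "control_plane" subdirectory is found, treating that as the repo root.
use std::path::{Path, PathBuf};

fn find_source_root(start: &Path) -> Option<PathBuf> {
    start
        .ancestors()
        .find(|p| p.join("control_plane").is_dir())
        .map(|p| p.to_owned())
}
```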
@@ -306,32 +340,69 @@ impl AttachmentService {
    ///
    /// Returns the database url
    pub async fn setup_database(&self) -> anyhow::Result<String> {
        const DB_NAME: &str = "attachment_service";
        let database_url = format!("postgresql://localhost:{}/{DB_NAME}", self.postgres_port);
        let database_url = format!(
            "postgresql://localhost:{}/attachment_service",
            self.postgres_port
        );
        println!("Running attachment service database setup...");
        fn change_database_of_url(database_url: &str, default_database: &str) -> (String, String) {
            let base = ::url::Url::parse(database_url).unwrap();
            let database = base.path_segments().unwrap().last().unwrap().to_owned();
            let mut new_url = base.join(default_database).unwrap();
            new_url.set_query(base.query());
            (database, new_url.into())
        }

        let pg_bin_dir = self.get_pg_bin_dir().await?;
        let createdb_path = pg_bin_dir.join("createdb");
        let output = Command::new(&createdb_path)
            .args([
                "-h",
                "localhost",
                "-p",
                &format!("{}", self.postgres_port),
                &DB_NAME,
            ])
            .output()
            .await
            .expect("Failed to spawn createdb");
        #[derive(Debug, Clone)]
        pub struct CreateDatabaseStatement {
            db_name: String,
        }

        if !output.status.success() {
            let stderr = String::from_utf8(output.stderr).expect("Non-UTF8 output from createdb");
            if stderr.contains("already exists") {
                tracing::info!("Database {DB_NAME} already exists");
            } else {
                anyhow::bail!("createdb failed with status {}: {stderr}", output.status);
        impl CreateDatabaseStatement {
            pub fn new(db_name: &str) -> Self {
                CreateDatabaseStatement {
                    db_name: db_name.to_owned(),
                }
            }
        }

        impl<DB: Backend> QueryFragment<DB> for CreateDatabaseStatement {
            fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
                out.push_sql("CREATE DATABASE ");
                out.push_identifier(&self.db_name)?;
                Ok(())
            }
        }

        impl<Conn> RunQueryDsl<Conn> for CreateDatabaseStatement {}

        impl QueryId for CreateDatabaseStatement {
            type QueryId = ();

            const HAS_STATIC_QUERY_ID: bool = false;
        }
        if PgConnection::establish(&database_url).is_err() {
            let (database, postgres_url) = change_database_of_url(&database_url, "postgres");
            println!("Creating database: {database}");
            let mut conn = PgConnection::establish(&postgres_url)?;
            CreateDatabaseStatement::new(&database).execute(&mut conn)?;
        }
        let mut conn = PgConnection::establish(&database_url)?;

        let migrations_dir = self
            .find_source_root()
            .await?
            .join("control_plane/attachment_service/migrations");

        let migrations = diesel_migrations::FileBasedMigrations::from_path(migrations_dir)?;
        println!("Running migrations in {}", migrations.path().display());
        HarnessWithOutput::write_to_stdout(&mut conn)
            .run_pending_migrations(migrations)
            .map(|_| ())
            .map_err(|e| anyhow::anyhow!(e))?;

        println!("Migrations complete");

        Ok(database_url)
    }

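The `setup_database` rewrite above leans on two small helpers worth spelling out. diesel has no query builder for `CREATE DATABASE`, so the code hand-rolls one as a `QueryFragment`: `push_identifier` takes care of quoting the database name, and `HAS_STATIC_QUERY_ID: bool = false` keeps the one-off DDL statement out of the prepared-statement cache. `change_database_of_url` swaps the database component of a connection URL so the code can connect to the maintenance `postgres` database before the target database exists. A sketch of its expected behaviour (the port is hypothetical):

```rust
// Sketch only: exercises the change_database_of_url helper defined above.
let url = "postgresql://localhost:1234/attachment_service";
let (database, admin_url) = change_database_of_url(url, "postgres");
assert_eq!(database, "attachment_service");
assert_eq!(admin_url, "postgresql://localhost:1234/postgres");
```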
@@ -577,7 +648,7 @@ impl AttachmentService {
    ) -> anyhow::Result<TenantShardMigrateResponse> {
        self.dispatch(
            Method::PUT,
            format!("control/v1/tenant/{tenant_shard_id}/migrate"),
            format!("tenant/{tenant_shard_id}/migrate"),
            Some(TenantShardMigrateRequest {
                tenant_shard_id,
                node_id,
@@ -586,20 +657,6 @@ impl AttachmentService {
        .await
    }

    #[instrument(skip(self), fields(%tenant_id, %new_shard_count))]
    pub async fn tenant_split(
        &self,
        tenant_id: TenantId,
        new_shard_count: u8,
    ) -> anyhow::Result<TenantShardSplitResponse> {
        self.dispatch(
            Method::PUT,
            format!("control/v1/tenant/{tenant_id}/shard_split"),
            Some(TenantShardSplitRequest { new_shard_count }),
        )
        .await
    }

    #[instrument(skip_all, fields(node_id=%req.node_id))]
    pub async fn node_register(&self, req: NodeRegisterRequest) -> anyhow::Result<()> {
        self.dispatch::<_, ()>(Method::POST, "control/v1/node".to_string(), Some(req))

@@ -72,6 +72,7 @@ where
    let log_path = datadir.join(format!("{process_name}.log"));
    let process_log_file = fs::OpenOptions::new()
        .create(true)
        .write(true)
        .append(true)
        .open(&log_path)
        .with_context(|| {

@@ -575,26 +575,6 @@ async fn handle_tenant(
            println!("{tenant_table}");
            println!("{shard_table}");
        }
        Some(("shard-split", matches)) => {
            let tenant_id = get_tenant_id(matches, env)?;
            let shard_count: u8 = matches.get_one::<u8>("shard-count").cloned().unwrap_or(0);

            let attachment_service = AttachmentService::from_env(env);
            let result = attachment_service
                .tenant_split(tenant_id, shard_count)
                .await?;
            println!(
                "Split tenant {} into shards {}",
                tenant_id,
                result
                    .new_shards
                    .iter()
                    .map(|s| format!("{:?}", s))
                    .collect::<Vec<_>>()
                    .join(",")
            );
        }

        Some((sub_name, _)) => bail!("Unexpected tenant subcommand '{}'", sub_name),
        None => bail!("no tenant subcommand provided"),
    }
@@ -1014,13 +994,12 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
                .get_one::<String>("endpoint_id")
                .ok_or_else(|| anyhow!("No endpoint ID was provided to stop"))?;
            let destroy = sub_args.get_flag("destroy");
            let mode = sub_args.get_one::<String>("mode").expect("has a default");

            let endpoint = cplane
                .endpoints
                .get(endpoint_id.as_str())
                .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
            endpoint.stop(mode, destroy)?;
            endpoint.stop(destroy)?;
        }

        _ => bail!("Unexpected endpoint subcommand '{sub_name}'"),
@@ -1304,7 +1283,7 @@ async fn try_stop_all(env: &local_env::LocalEnv, immediate: bool) {
    match ComputeControlPlane::load(env.clone()) {
        Ok(cplane) => {
            for (_k, node) in cplane.endpoints {
                if let Err(e) = node.stop(if immediate { "immediate" } else { "fast " }, false) {
                if let Err(e) = node.stop(false) {
                    eprintln!("postgres stop failed: {e:#}");
                }
            }
@@ -1545,11 +1524,6 @@ fn cli() -> Command {
            .subcommand(Command::new("status")
                .about("Human readable summary of the tenant's shards and attachment locations")
                .arg(tenant_id_arg.clone()))
            .subcommand(Command::new("shard-split")
                .about("Increase the number of shards in the tenant")
                .arg(tenant_id_arg.clone())
                .arg(Arg::new("shard-count").value_parser(value_parser!(u8)).long("shard-count").action(ArgAction::Set).help("Number of shards in the new tenant (default 1)"))
            )
        )
        .subcommand(
            Command::new("pageserver")
@@ -1653,16 +1627,7 @@ fn cli() -> Command {
                    .long("destroy")
                    .action(ArgAction::SetTrue)
                    .required(false)
                )
                .arg(
                    Arg::new("mode")
                        .help("Postgres shutdown mode, passed to \"pg_ctl -m <mode>\"")
                        .long("mode")
                        .action(ArgAction::Set)
                        .required(false)
                        .value_parser(["smart", "fast", "immediate"])
                        .default_value("fast")
                )
            )
        )

    )

@@ -761,8 +761,22 @@ impl Endpoint {
        }
    }

    pub fn stop(&self, mode: &str, destroy: bool) -> Result<()> {
        self.pg_ctl(&["-m", mode, "stop"], &None)?;
    pub fn stop(&self, destroy: bool) -> Result<()> {
        // If we are going to destroy the data directory,
        // use immediate shutdown mode; otherwise,
        // shut down gracefully to leave the data directory sane.
        //
        // Postgres is always started from scratch, so stop
        // without destroy is only used for testing and debugging.
        //
        self.pg_ctl(
            if destroy {
                &["-m", "immediate", "stop"]
            } else {
                &["stop"]
            },
            &None,
        )?;

        // Also wait for the compute_ctl process to die. It might have some
        // cleanup work to do after postgres stops, like syncing safekeepers,

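For context on the modes being collapsed here: `pg_ctl -m smart` waits for clients to disconnect, `fast` (the default) terminates connections but still performs a clean shutdown, and `immediate` aborts without a clean shutdown, leaving crash recovery to run on the next start. That is why the destroy path can afford `immediate`: the data directory is deleted afterwards anyway.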
@@ -90,8 +90,10 @@ pub enum ComputeFeature {
    /// track short-lived connections as user activity.
    ActivityMonitorExperimental,

    /// Pre-install and initialize anon extension for every database in the cluster
    AnonExtension,
    // Use latest version of remote extensions
    // This is needed to allow us to test new versions of extensions before
    // they are merged into the main branch.
    RemoteExtensionsUseLatest,

    /// This is a special feature flag that is used to represent unknown feature flags.
    /// Basically all unknown to enum flags are represented as this one. See unit test
@@ -155,8 +157,12 @@ impl RemoteExtSpec {
        //
        // Keep it in sync with path generation in
        // https://github.com/neondatabase/build-custom-extensions/tree/main
        //
        // if ComputeFeature::RemoteExtensionsUseLatest is enabled
        // use "latest" as the build_tag
        let archive_path_str =
            format!("{build_tag}/{pg_major_version}/extensions/{real_ext_name}.tar.zst");

        Ok((
            real_ext_name.to_string(),
            RemotePath::from_string(&archive_path_str)?,

@@ -192,16 +192,6 @@ pub struct TimelineCreateRequest {
    pub pg_version: Option<u32>,
}

#[derive(Serialize, Deserialize)]
pub struct TenantShardSplitRequest {
    pub new_shard_count: u8,
}

#[derive(Serialize, Deserialize)]
pub struct TenantShardSplitResponse {
    pub new_shards: Vec<TenantShardId>,
}

/// Parameters that apply to all shards in a tenant. Used during tenant creation.
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]

@@ -88,36 +88,12 @@ impl TenantShardId {
    pub fn is_unsharded(&self) -> bool {
        self.shard_number == ShardNumber(0) && self.shard_count == ShardCount(0)
    }

    /// Convenience for dropping the tenant_id and just getting the ShardIndex: this
    /// is useful when logging from code that is already in a span that includes tenant ID, to
    /// keep messages reasonably terse.
    pub fn to_index(&self) -> ShardIndex {
        ShardIndex {
            shard_number: self.shard_number,
            shard_count: self.shard_count,
        }
    }

    /// Calculate the children of this TenantShardId when splitting the overall tenant into
    /// the given number of shards.
    pub fn split(&self, new_shard_count: ShardCount) -> Vec<TenantShardId> {
        let effective_old_shard_count = std::cmp::max(self.shard_count.0, 1);
        let mut child_shards = Vec::new();
        for shard_number in 0..ShardNumber(new_shard_count.0).0 {
            // Key mapping is based on a round robin mapping of key hash modulo shard count,
            // so our child shards are the ones which the same keys would map to.
            if shard_number % effective_old_shard_count == self.shard_number.0 {
                child_shards.push(TenantShardId {
                    tenant_id: self.tenant_id,
                    shard_number: ShardNumber(shard_number),
                    shard_count: new_shard_count,
                })
            }
        }

        child_shards
    }
}

/// Formatting helper
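To make the round-robin rule in `split` concrete: a child shard number `s` (with `0 <= s < new_shard_count`) belongs to parent shard `n` of `old_shard_count` exactly when `s % old_shard_count == n`. Splitting shard 1 of 2 into 8, for example, yields children {1, 3, 5, 7} — the shard numbers congruent to 1 modulo 2 — which is precisely what the `shard_id_split` test below asserts.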
@@ -817,108 +793,4 @@ mod tests {
        let shard = key_to_shard_number(ShardCount(10), DEFAULT_STRIPE_SIZE, &key);
        assert_eq!(shard, ShardNumber(8));
    }

    #[test]
    fn shard_id_split() {
        let tenant_id = TenantId::generate();
        let parent = TenantShardId::unsharded(tenant_id);

        // Unsharded into 2
        assert_eq!(
            parent.split(ShardCount(2)),
            vec![
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(2),
                    shard_number: ShardNumber(0)
                },
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(2),
                    shard_number: ShardNumber(1)
                }
            ]
        );

        // Unsharded into 4
        assert_eq!(
            parent.split(ShardCount(4)),
            vec![
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(4),
                    shard_number: ShardNumber(0)
                },
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(4),
                    shard_number: ShardNumber(1)
                },
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(4),
                    shard_number: ShardNumber(2)
                },
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(4),
                    shard_number: ShardNumber(3)
                }
            ]
        );

        // count=1 into 2 (check this works the same as unsharded.)
        let parent = TenantShardId {
            tenant_id,
            shard_count: ShardCount(1),
            shard_number: ShardNumber(0),
        };
        assert_eq!(
            parent.split(ShardCount(2)),
            vec![
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(2),
                    shard_number: ShardNumber(0)
                },
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(2),
                    shard_number: ShardNumber(1)
                }
            ]
        );

        // count=2 into count=8
        let parent = TenantShardId {
            tenant_id,
            shard_count: ShardCount(2),
            shard_number: ShardNumber(1),
        };
        assert_eq!(
            parent.split(ShardCount(8)),
            vec![
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(8),
                    shard_number: ShardNumber(1)
                },
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(8),
                    shard_number: ShardNumber(3)
                },
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(8),
                    shard_number: ShardNumber(5)
                },
                TenantShardId {
                    tenant_id,
                    shard_count: ShardCount(8),
                    shard_number: ShardNumber(7)
                },
            ]
        );
    }
}

@@ -13,6 +13,5 @@ rand.workspace = true
tokio.workspace = true
tracing.workspace = true
thiserror.workspace = true
serde.workspace = true

workspace_hack.workspace = true

@@ -7,7 +7,6 @@ pub mod framed;

use byteorder::{BigEndian, ReadBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap, fmt, io, str};

// re-export for use in utils pageserver_feedback.rs
@@ -124,7 +123,7 @@ impl StartupMessageParams {
    }
}

#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct CancelKeyData {
    pub backend_pid: i32,
    pub cancel_key: i32,

@@ -191,7 +191,6 @@ impl RemoteStorage for AzureBlobStorage {
        &self,
        prefix: Option<&RemotePath>,
        mode: ListingMode,
        max_keys: Option<NonZeroU32>,
    ) -> anyhow::Result<Listing, DownloadError> {
        // get the passed prefix or if it is not set use prefix_in_bucket value
        let list_prefix = prefix
@@ -224,8 +223,6 @@ impl RemoteStorage for AzureBlobStorage {

        let mut response = builder.into_stream();
        let mut res = Listing::default();
        // NonZeroU32 doesn't support subtraction apparently
        let mut max_keys = max_keys.map(|mk| mk.get());
        while let Some(l) = response.next().await {
            let entry = l.map_err(to_download_error)?;
            let prefix_iter = entry
@@ -238,18 +235,7 @@ impl RemoteStorage for AzureBlobStorage {
                .blobs
                .blobs()
                .map(|k| self.name_to_relative_path(&k.name));

            for key in blob_iter {
                res.keys.push(key);
                if let Some(mut mk) = max_keys {
                    assert!(mk > 0);
                    mk -= 1;
                    if mk == 0 {
                        return Ok(res); // limit reached
                    }
                    max_keys = Some(mk);
                }
            }
            res.keys.extend(blob_iter);
        }
        Ok(res)
    }

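The countdown in the removed listing loop is a small workaround: `NonZeroU32` implements neither `Sub` nor a decrement, so the code unwraps to a plain `u32` once, decrements, and returns early at zero. The same pattern in isolation, detached from the Azure types (a sketch):

```rust
use std::num::NonZeroU32;

// Sketch only: collect at most `max_keys` items, where None means unlimited.
fn take_limited(items: impl IntoIterator<Item = u32>, max_keys: Option<NonZeroU32>) -> Vec<u32> {
    let mut remaining = max_keys.map(|mk| mk.get());
    let mut out = Vec::new();
    for item in items {
        out.push(item);
        if let Some(mut r) = remaining {
            r -= 1;
            if r == 0 {
                return out; // limit reached
            }
            remaining = Some(r);
        }
    }
    out
}
```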
@@ -13,15 +13,9 @@ mod azure_blob;
mod local_fs;
mod s3_bucket;
mod simulate_failures;
mod support;

use std::{
    collections::HashMap,
    fmt::Debug,
    num::{NonZeroU32, NonZeroUsize},
    pin::Pin,
    sync::Arc,
    time::SystemTime,
    collections::HashMap, fmt::Debug, num::NonZeroUsize, pin::Pin, sync::Arc, time::SystemTime,
};

use anyhow::{bail, Context};
@@ -160,7 +154,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
        prefix: Option<&RemotePath>,
    ) -> Result<Vec<RemotePath>, DownloadError> {
        let result = self
            .list(prefix, ListingMode::WithDelimiter, None)
            .list(prefix, ListingMode::WithDelimiter)
            .await?
            .prefixes;
        Ok(result)
@@ -176,17 +170,8 @@ pub trait RemoteStorage: Send + Sync + 'static {
    /// whereas,
    /// list_prefixes("foo/bar/") = ["cat", "dog"]
    /// See `test_real_s3.rs` for more details.
    ///
    /// max_keys limits max number of keys returned; None means unlimited.
    async fn list_files(
        &self,
        prefix: Option<&RemotePath>,
        max_keys: Option<NonZeroU32>,
    ) -> Result<Vec<RemotePath>, DownloadError> {
        let result = self
            .list(prefix, ListingMode::NoDelimiter, max_keys)
            .await?
            .keys;
    async fn list_files(&self, prefix: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
        let result = self.list(prefix, ListingMode::NoDelimiter).await?.keys;
        Ok(result)
    }

@@ -194,8 +179,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
        &self,
        prefix: Option<&RemotePath>,
        _mode: ListingMode,
        max_keys: Option<NonZeroU32>,
    ) -> Result<Listing, DownloadError>;
    ) -> anyhow::Result<Listing, DownloadError>;

    /// Streams the local file contents into remote into the remote storage entry.
    async fn upload(
@@ -285,19 +269,6 @@ impl std::fmt::Display for DownloadError {

impl std::error::Error for DownloadError {}

impl DownloadError {
    /// Returns true if the error should not be retried with backoff
    pub fn is_permanent(&self) -> bool {
        use DownloadError::*;
        match self {
            BadInput(_) => true,
            NotFound => true,
            Cancelled => true,
            Other(_) => false,
        }
    }
}

#[derive(Debug)]
pub enum TimeTravelError {
    /// Validation or other error happened due to user input.
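`is_permanent` (removed in this hunk) lets retry loops separate errors worth backing off on from errors that no retry will fix. A hedged sketch of the intended call pattern, where `fetch` and the retry budget are hypothetical:

```rust
// Sketch only: `fetch` stands in for any operation returning Result<Download, DownloadError>.
let mut attempt = 0u32;
let download = loop {
    match fetch().await {
        Ok(d) => break Ok(d),
        // Bad input, missing keys and cancellation will not be fixed by retrying.
        Err(e) if e.is_permanent() => break Err(e),
        Err(_) if attempt < 3 => {
            attempt += 1;
            tokio::time::sleep(std::time::Duration::from_millis(100 << attempt)).await;
        }
        Err(e) => break Err(e),
    }
};
```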
@@ -353,31 +324,24 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
        &self,
        prefix: Option<&RemotePath>,
        mode: ListingMode,
        max_keys: Option<NonZeroU32>,
    ) -> anyhow::Result<Listing, DownloadError> {
        match self {
            Self::LocalFs(s) => s.list(prefix, mode, max_keys).await,
            Self::AwsS3(s) => s.list(prefix, mode, max_keys).await,
            Self::AzureBlob(s) => s.list(prefix, mode, max_keys).await,
            Self::Unreliable(s) => s.list(prefix, mode, max_keys).await,
            Self::LocalFs(s) => s.list(prefix, mode).await,
            Self::AwsS3(s) => s.list(prefix, mode).await,
            Self::AzureBlob(s) => s.list(prefix, mode).await,
            Self::Unreliable(s) => s.list(prefix, mode).await,
        }
    }

    // A function for listing all the files in a "directory"
    // Example:
    // list_files("foo/bar") = ["foo/bar/a.txt", "foo/bar/b.txt"]
    //
    // max_keys limits max number of keys returned; None means unlimited.
    pub async fn list_files(
        &self,
        folder: Option<&RemotePath>,
        max_keys: Option<NonZeroU32>,
    ) -> Result<Vec<RemotePath>, DownloadError> {
    pub async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
        match self {
            Self::LocalFs(s) => s.list_files(folder, max_keys).await,
            Self::AwsS3(s) => s.list_files(folder, max_keys).await,
            Self::AzureBlob(s) => s.list_files(folder, max_keys).await,
            Self::Unreliable(s) => s.list_files(folder, max_keys).await,
            Self::LocalFs(s) => s.list_files(folder).await,
            Self::AwsS3(s) => s.list_files(folder).await,
            Self::AzureBlob(s) => s.list_files(folder).await,
            Self::Unreliable(s) => s.list_files(folder).await,
        }
    }

@@ -4,9 +4,7 @@
//! This storage is used in tests, but can also be used in cases when a certain persistent
//! volume is mounted to the local FS.

use std::{
    borrow::Cow, future::Future, io::ErrorKind, num::NonZeroU32, pin::Pin, time::SystemTime,
};
use std::{borrow::Cow, future::Future, io::ErrorKind, pin::Pin, time::SystemTime};

use anyhow::{bail, ensure, Context};
use bytes::Bytes;
@@ -20,7 +18,9 @@ use tokio_util::{io::ReaderStream, sync::CancellationToken};
use tracing::*;
use utils::{crashsafe::path_with_suffix_extension, fs_ext::is_directory_empty};

use crate::{Download, DownloadError, Listing, ListingMode, RemotePath, TimeTravelError};
use crate::{
    Download, DownloadError, DownloadStream, Listing, ListingMode, RemotePath, TimeTravelError,
};

use super::{RemoteStorage, StorageMetadata};

@@ -164,7 +164,6 @@ impl RemoteStorage for LocalFs {
        &self,
        prefix: Option<&RemotePath>,
        mode: ListingMode,
        max_keys: Option<NonZeroU32>,
    ) -> Result<Listing, DownloadError> {
        let mut result = Listing::default();

@@ -181,9 +180,6 @@ impl RemoteStorage for LocalFs {
                    !path.is_dir()
                })
                .collect();
            if let Some(max_keys) = max_keys {
                result.keys.truncate(max_keys.get() as usize);
            }

            return Ok(result);
        }
@@ -369,33 +365,27 @@ impl RemoteStorage for LocalFs {
                    format!("Failed to open source file {target_path:?} to use in the download")
                })
                .map_err(DownloadError::Other)?;

            let len = source
                .metadata()
                .await
                .context("query file length")
                .map_err(DownloadError::Other)?
                .len();

            source
                .seek(io::SeekFrom::Start(start_inclusive))
                .await
                .context("Failed to seek to the range start in a local storage file")
                .map_err(DownloadError::Other)?;

            let metadata = self
                .read_storage_metadata(&target_path)
                .await
                .map_err(DownloadError::Other)?;

            let source = source.take(end_exclusive.unwrap_or(len) - start_inclusive);
            let source = ReaderStream::new(source);

            let download_stream: DownloadStream = match end_exclusive {
                Some(end_exclusive) => Box::pin(ReaderStream::new(
                    source.take(end_exclusive - start_inclusive),
                )),
                None => Box::pin(ReaderStream::new(source)),
            };
            Ok(Download {
                metadata,
                last_modified: None,
                etag: None,
                download_stream: Box::pin(source),
                download_stream,
            })
        } else {
            Err(DownloadError::NotFound)
@@ -524,8 +514,10 @@ mod fs_tests {
    use futures_util::Stream;
    use std::{collections::HashMap, io::Write};

    async fn read_and_check_metadata(
    async fn read_and_assert_remote_file_contents(
        storage: &LocalFs,
        #[allow(clippy::ptr_arg)]
        // have to use &Utf8PathBuf due to `storage.local_path` parameter requirements
        remote_storage_path: &RemotePath,
        expected_metadata: Option<&StorageMetadata>,
    ) -> anyhow::Result<String> {
@@ -604,7 +596,7 @@ mod fs_tests {
        let upload_name = "upload_1";
        let upload_target = upload_dummy_file(&storage, upload_name, None).await?;

        let contents = read_and_check_metadata(&storage, &upload_target, None).await?;
        let contents = read_and_assert_remote_file_contents(&storage, &upload_target, None).await?;
        assert_eq!(
            dummy_contents(upload_name),
            contents,
@@ -626,7 +618,7 @@ mod fs_tests {
        let upload_target = upload_dummy_file(&storage, upload_name, None).await?;

        let full_range_download_contents =
            read_and_check_metadata(&storage, &upload_target, None).await?;
            read_and_assert_remote_file_contents(&storage, &upload_target, None).await?;
        assert_eq!(
            dummy_contents(upload_name),
            full_range_download_contents,
@@ -668,22 +660,6 @@ mod fs_tests {
            "Second part bytes should be returned when requested"
        );

        let suffix_bytes = storage
            .download_byte_range(&upload_target, 13, None)
            .await?
            .download_stream;
        let suffix_bytes = aggregate(suffix_bytes).await?;
        let suffix = std::str::from_utf8(&suffix_bytes)?;
        assert_eq!(upload_name, suffix);

        let all_bytes = storage
            .download_byte_range(&upload_target, 0, None)
            .await?
            .download_stream;
        let all_bytes = aggregate(all_bytes).await?;
        let all_bytes = std::str::from_utf8(&all_bytes)?;
        assert_eq!(dummy_contents("upload_1"), all_bytes);

        Ok(())
    }

@@ -760,7 +736,7 @@ mod fs_tests {
        upload_dummy_file(&storage, upload_name, Some(metadata.clone())).await?;

        let full_range_download_contents =
            read_and_check_metadata(&storage, &upload_target, Some(&metadata)).await?;
            read_and_assert_remote_file_contents(&storage, &upload_target, Some(&metadata)).await?;
        assert_eq!(
            dummy_contents(upload_name),
            full_range_download_contents,
@@ -796,12 +772,12 @@ mod fs_tests {
        let child = upload_dummy_file(&storage, "grandparent/parent/child", None).await?;
        let uncle = upload_dummy_file(&storage, "grandparent/uncle", None).await?;

        let listing = storage.list(None, ListingMode::NoDelimiter, None).await?;
        let listing = storage.list(None, ListingMode::NoDelimiter).await?;
        assert!(listing.prefixes.is_empty());
        assert_eq!(listing.keys, [uncle.clone(), child.clone()].to_vec());

        // Delimiter: should only go one deep
        let listing = storage.list(None, ListingMode::WithDelimiter, None).await?;
        let listing = storage.list(None, ListingMode::WithDelimiter).await?;

        assert_eq!(
            listing.prefixes,
@@ -814,7 +790,6 @@ mod fs_tests {
            .list(
                Some(&RemotePath::from_string("timelines/some_timeline/grandparent").unwrap()),
                ListingMode::WithDelimiter,
                None,
            )
            .await?;
        assert_eq!(

@@ -7,7 +7,6 @@
use std::{
    borrow::Cow,
    collections::HashMap,
    num::NonZeroU32,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
@@ -46,9 +45,8 @@ use utils::backoff;

use super::StorageMetadata;
use crate::{
    support::PermitCarrying, ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode,
    RemotePath, RemoteStorage, S3Config, TimeTravelError, MAX_KEYS_PER_DELETE,
    REMOTE_STORAGE_PREFIX_SEPARATOR,
    ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, RemotePath, RemoteStorage,
    S3Config, TimeTravelError, MAX_KEYS_PER_DELETE, REMOTE_STORAGE_PREFIX_SEPARATOR,
};

pub(super) mod metrics;
@@ -65,6 +63,7 @@ pub struct S3Bucket {
    concurrency_limiter: ConcurrencyLimiter,
}

#[derive(Default)]
struct GetObjectRequest {
    bucket: String,
    key: String,
@@ -233,8 +232,24 @@ impl S3Bucket {

        let started_at = ScopeGuard::into_inner(started_at);

        let object_output = match get_object {
            Ok(object_output) => object_output,
        match get_object {
            Ok(object_output) => {
                let metadata = object_output.metadata().cloned().map(StorageMetadata);
                let etag = object_output.e_tag.clone();
                let last_modified = object_output.last_modified.and_then(|t| t.try_into().ok());

                let body = object_output.body;
                let body = ByteStreamAsStream::from(body);
                let body = PermitCarrying::new(permit, body);
                let body = TimedDownload::new(started_at, body);

                Ok(Download {
                    metadata,
                    etag,
                    last_modified,
                    download_stream: Box::pin(body),
                })
            }
            Err(SdkError::ServiceError(e)) if matches!(e.err(), GetObjectError::NoSuchKey(_)) => {
                // Count this in the AttemptOutcome::Ok bucket, because 404 is not
                // an error: we expect to sometimes fetch an object and find it missing,
@@ -244,7 +259,7 @@ impl S3Bucket {
                    AttemptOutcome::Ok,
                    started_at,
                );
                return Err(DownloadError::NotFound);
                Err(DownloadError::NotFound)
            }
            Err(e) => {
                metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
@@ -253,27 +268,11 @@ impl S3Bucket {
                    started_at,
                );

                return Err(DownloadError::Other(
                Err(DownloadError::Other(
                    anyhow::Error::new(e).context("download s3 object"),
                ));
                ))
            }
        };

        let metadata = object_output.metadata().cloned().map(StorageMetadata);
        let etag = object_output.e_tag;
        let last_modified = object_output.last_modified.and_then(|t| t.try_into().ok());

        let body = object_output.body;
        let body = ByteStreamAsStream::from(body);
        let body = PermitCarrying::new(permit, body);
        let body = TimedDownload::new(started_at, body);

        Ok(Download {
            metadata,
            etag,
            last_modified,
            download_stream: Box::pin(body),
        })
    }
    }

    async fn delete_oids(
@@ -355,6 +354,33 @@ impl Stream for ByteStreamAsStream {
        // sense and Stream::size_hint does not really
}

pin_project_lite::pin_project! {
    /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
    struct PermitCarrying<S> {
        permit: tokio::sync::OwnedSemaphorePermit,
        #[pin]
        inner: S,
    }
}

impl<S> PermitCarrying<S> {
    fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
        Self { permit, inner }
    }
}

impl<S: Stream<Item = std::io::Result<Bytes>>> Stream for PermitCarrying<S> {
    type Item = <S as Stream>::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

pin_project_lite::pin_project! {
    /// Times and tracks the outcome of the request.
    struct TimedDownload<S> {
@@ -409,11 +435,8 @@ impl RemoteStorage for S3Bucket {
        &self,
        prefix: Option<&RemotePath>,
        mode: ListingMode,
        max_keys: Option<NonZeroU32>,
    ) -> Result<Listing, DownloadError> {
        let kind = RequestKind::List;
        // s3 sdk wants i32
        let mut max_keys = max_keys.map(|mk| mk.get() as i32);
        let mut result = Listing::default();

        // get the passed prefix or if it is not set use prefix_in_bucket value
@@ -437,20 +460,13 @@ impl RemoteStorage for S3Bucket {
            let _guard = self.permit(kind).await;
            let started_at = start_measuring_requests(kind);

            // min of two Options, returning Some if one is value and another is
            // None (None is smaller than anything, so plain min doesn't work).
            let request_max_keys = self
                .max_keys_per_list_response
                .into_iter()
                .chain(max_keys.into_iter())
                .min();
            let mut request = self
                .client
                .list_objects_v2()
                .bucket(self.bucket_name.clone())
                .set_prefix(list_prefix.clone())
                .set_continuation_token(continuation_token)
                .set_max_keys(request_max_keys);
                .set_max_keys(self.max_keys_per_list_response);

            if let ListingMode::WithDelimiter = mode {
                request = request.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
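The "min of two Options" comment deserves unpacking: `Option` orders `None` below any `Some`, so `std::cmp::min(Some(1000), None)` is `None`, which here would wrongly read as "no limit". Chaining both options into an iterator and taking `min` makes an absent limit contribute nothing instead. In isolation (a sketch):

```rust
fn min_limit(a: Option<i32>, b: Option<i32>) -> Option<i32> {
    // A None drops out of the iterator, so it never wins as the "smallest" limit.
    a.into_iter().chain(b).min()
}

fn main() {
    assert_eq!(min_limit(Some(1000), Some(2)), Some(2));
    assert_eq!(min_limit(None, Some(2)), Some(2));
    assert_eq!(min_limit(None, None), None); // both unlimited
}
```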
@@ -480,14 +496,6 @@ impl RemoteStorage for S3Bucket {
                let object_path = object.key().expect("response does not contain a key");
                let remote_path = self.s3_object_to_relative_path(object_path);
                result.keys.push(remote_path);
                if let Some(mut mk) = max_keys {
                    assert!(mk > 0);
                    mk -= 1;
                    if mk == 0 {
                        return Ok(result); // limit reached
                    }
                    max_keys = Some(mk);
                }
            }

            result.prefixes.extend(

@@ -4,7 +4,6 @@
use bytes::Bytes;
use futures::stream::Stream;
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Mutex;
use std::time::SystemTime;
use std::{collections::hash_map::Entry, sync::Arc};
@@ -61,7 +60,7 @@ impl UnreliableWrapper {
    /// On the first attempts of this operation, return an error. After 'attempts_to_fail'
    /// attempts, let the operation go ahead, and clear the counter.
    ///
    fn attempt(&self, op: RemoteOp) -> anyhow::Result<u64> {
    fn attempt(&self, op: RemoteOp) -> Result<u64, DownloadError> {
        let mut attempts = self.attempts.lock().unwrap();

        match attempts.entry(op) {
@@ -79,13 +78,13 @@ impl UnreliableWrapper {
                } else {
                    let error =
                        anyhow::anyhow!("simulated failure of remote operation {:?}", e.key());
                    Err(error)
                    Err(DownloadError::Other(error))
                }
            }
            Entry::Vacant(e) => {
                let error = anyhow::anyhow!("simulated failure of remote operation {:?}", e.key());
                e.insert(1);
                Err(error)
                Err(DownloadError::Other(error))
            }
        }
    }
@@ -106,30 +105,22 @@ impl RemoteStorage for UnreliableWrapper {
        &self,
        prefix: Option<&RemotePath>,
    ) -> Result<Vec<RemotePath>, DownloadError> {
        self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))
            .map_err(DownloadError::Other)?;
        self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))?;
        self.inner.list_prefixes(prefix).await
    }

    async fn list_files(
        &self,
        folder: Option<&RemotePath>,
        max_keys: Option<NonZeroU32>,
    ) -> Result<Vec<RemotePath>, DownloadError> {
        self.attempt(RemoteOp::ListPrefixes(folder.cloned()))
            .map_err(DownloadError::Other)?;
        self.inner.list_files(folder, max_keys).await
    async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
        self.attempt(RemoteOp::ListPrefixes(folder.cloned()))?;
        self.inner.list_files(folder).await
    }

    async fn list(
        &self,
        prefix: Option<&RemotePath>,
        mode: ListingMode,
        max_keys: Option<NonZeroU32>,
    ) -> Result<Listing, DownloadError> {
        self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))
            .map_err(DownloadError::Other)?;
        self.inner.list(prefix, mode, max_keys).await
        self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))?;
        self.inner.list(prefix, mode).await
    }

    async fn upload(
@@ -146,8 +137,7 @@ impl RemoteStorage for UnreliableWrapper {
    }

    async fn download(&self, from: &RemotePath) -> Result<Download, DownloadError> {
        self.attempt(RemoteOp::Download(from.clone()))
            .map_err(DownloadError::Other)?;
        self.attempt(RemoteOp::Download(from.clone()))?;
        self.inner.download(from).await
    }

@@ -160,8 +150,7 @@ impl RemoteStorage for UnreliableWrapper {
        // Note: We treat any download_byte_range as an "attempt" of the same
        // operation. We don't pay attention to the ranges. That's good enough
        // for now.
        self.attempt(RemoteOp::Download(from.clone()))
            .map_err(DownloadError::Other)?;
        self.attempt(RemoteOp::Download(from.clone()))?;
        self.inner
            .download_byte_range(from, start_inclusive, end_exclusive)
            .await
@@ -204,7 +193,7 @@ impl RemoteStorage for UnreliableWrapper {
        cancel: &CancellationToken,
    ) -> Result<(), TimeTravelError> {
        self.attempt(RemoteOp::TimeTravelRecover(prefix.map(|p| p.to_owned())))
            .map_err(TimeTravelError::Other)?;
            .map_err(|e| TimeTravelError::Other(anyhow::Error::new(e)))?;
        self.inner
            .time_travel_recover(prefix, timestamp, done_if_after, cancel)
            .await

@@ -1,33 +0,0 @@
use std::{
    pin::Pin,
    task::{Context, Poll},
};

use futures_util::Stream;

pin_project_lite::pin_project! {
    /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
    pub(crate) struct PermitCarrying<S> {
        permit: tokio::sync::OwnedSemaphorePermit,
        #[pin]
        inner: S,
    }
}

impl<S> PermitCarrying<S> {
    pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
        Self { permit, inner }
    }
}

impl<S: Stream> Stream for PermitCarrying<S> {
    type Item = <S as Stream>::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
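`PermitCarrying` — deleted here as a shared helper after being inlined into `s3_bucket.rs` above — ties an `OwnedSemaphorePermit` to a stream so the concurrency limiter counts in-flight download bodies rather than just in-flight requests: the permit is released only when the wrapped stream is dropped. A sketch of the acquisition side (the limit of 10 is hypothetical):

```rust
use std::sync::Arc;
use futures_util::Stream;
use tokio::sync::{AcquireError, Semaphore};

// Sketch only: wrap a download body so the semaphore slot is held until the body is dropped.
async fn limited<S: Stream>(
    limiter: Arc<Semaphore>, // e.g. Arc::new(Semaphore::new(10))
    stream: S,
) -> Result<PermitCarrying<S>, AcquireError> {
    let permit = limiter.acquire_owned().await?;
    Ok(PermitCarrying::new(permit, stream))
}
```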
@@ -1,8 +1,8 @@
use anyhow::Context;
use camino::Utf8Path;
use remote_storage::RemotePath;
use std::collections::HashSet;
use std::sync::Arc;
use std::{collections::HashSet, num::NonZeroU32};
use test_context::test_context;
use tracing::debug;

@@ -103,7 +103,7 @@ async fn list_files_works(ctx: &mut MaybeEnabledStorageWithSimpleTestBlobs) -> a
    let base_prefix =
        RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?;
    let root_files = test_client
        .list_files(None, None)
        .list_files(None)
        .await
        .context("client list root files failure")?
        .into_iter()
@@ -113,17 +113,8 @@ async fn list_files_works(ctx: &mut MaybeEnabledStorageWithSimpleTestBlobs) -> a
        ctx.remote_blobs.clone(),
        "remote storage list_files on root mismatches with the uploads."
    );

    // Test that max_keys limit works. In total there are about 21 files (see
    // upload_simple_remote_data call in test_real_s3.rs).
    let limited_root_files = test_client
        .list_files(None, Some(NonZeroU32::new(2).unwrap()))
        .await
        .context("client list root files failure")?;
    assert_eq!(limited_root_files.len(), 2);

    let nested_remote_files = test_client
        .list_files(Some(&base_prefix), None)
        .list_files(Some(&base_prefix))
        .await
        .context("client list nested files failure")?
        .into_iter()

@@ -70,7 +70,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
}

async fn list_files(client: &Arc<GenericRemoteStorage>) -> anyhow::Result<HashSet<RemotePath>> {
    Ok(retry(|| client.list_files(None, None))
    Ok(retry(|| client.list_files(None))
        .await
        .context("list root files failure")?
        .into_iter()

@@ -27,11 +27,6 @@ impl Barrier {
            b.wait().await
        }
    }

    /// Return true if a call to wait() would complete immediately
    pub fn is_ready(&self) -> bool {
        futures::future::FutureExt::now_or_never(self.0.wait()).is_some()
    }
}

impl PartialEq for Barrier {

@@ -1,6 +1,6 @@
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc, Mutex, MutexGuard,
    Arc,
};
use tokio::sync::Semaphore;

@@ -12,7 +12,7 @@ use tokio::sync::Semaphore;
///
/// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
pub struct OnceCell<T> {
    inner: Mutex<Inner<T>>,
    inner: tokio::sync::RwLock<Inner<T>>,
    initializers: AtomicUsize,
}

@@ -50,7 +50,7 @@ impl<T> OnceCell<T> {
        let sem = Semaphore::new(1);
        sem.close();
        Self {
            inner: Mutex::new(Inner {
            inner: tokio::sync::RwLock::new(Inner {
                init_semaphore: Arc::new(sem),
                value: Some(value),
            }),
@@ -61,18 +61,18 @@ impl<T> OnceCell<T> {
    /// Returns a guard to an existing initialized value, or uniquely initializes the value before
    /// returning the guard.
    ///
    /// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
    /// Initializing might wait on any existing [`GuardMut::take_and_deinit`] deinitialization.
    ///
    /// Initialization is panic-safe and cancellation-safe.
    pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
    pub async fn get_mut_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardMut<'_, T>, E>
    where
        F: FnOnce(InitPermit) -> Fut,
        Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
    {
        let sem = {
            let guard = self.inner.lock().unwrap();
            let guard = self.inner.write().await;
            if guard.value.is_some() {
                return Ok(Guard(guard));
                return Ok(GuardMut(guard));
            }
            guard.init_semaphore.clone()
        };
@@ -88,29 +88,72 @@ impl<T> OnceCell<T> {
                let permit = InitPermit(permit);
                let (value, _permit) = factory(permit).await?;

                let guard = self.inner.lock().unwrap();
                let guard = self.inner.write().await;

                Ok(Self::set0(value, guard))
            }
            Err(_closed) => {
                let guard = self.inner.lock().unwrap();
                let guard = self.inner.write().await;
                assert!(
                    guard.value.is_some(),
                    "semaphore got closed, must be initialized"
                );
                return Ok(Guard(guard));
                return Ok(GuardMut(guard));
            }
        }
    }

    /// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
    /// Returns a guard to an existing initialized value, or uniquely initializes the value before
    /// returning the guard.
    ///
    /// Initialization is panic-safe and cancellation-safe.
    pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardRef<'_, T>, E>
    where
        F: FnOnce(InitPermit) -> Fut,
        Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
    {
        let sem = {
            let guard = self.inner.read().await;
            if guard.value.is_some() {
                return Ok(GuardRef(guard));
            }
            guard.init_semaphore.clone()
        };

        let permit = {
            // increment the count for the duration of queued
            let _guard = CountWaitingInitializers::start(self);
            sem.acquire_owned().await
        };

        match permit {
            Ok(permit) => {
                let permit = InitPermit(permit);
                let (value, _permit) = factory(permit).await?;

                let guard = self.inner.write().await;

                Ok(Self::set0(value, guard).downgrade())
            }
            Err(_closed) => {
                let guard = self.inner.read().await;
                assert!(
                    guard.value.is_some(),
                    "semaphore got closed, must be initialized"
                );
                return Ok(GuardRef(guard));
            }
        }
    }

    /// Assuming a permit is held after previous call to [`GuardMut::take_and_deinit`], it can be used
    /// to complete initializing the inner value.
    ///
    /// # Panics
    ///
    /// If the inner has already been initialized.
    pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
        let guard = self.inner.lock().unwrap();
    pub async fn set(&self, value: T, _permit: InitPermit) -> GuardMut<'_, T> {
        let guard = self.inner.write().await;

        // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
        // give more permits right now.
@@ -122,21 +165,31 @@ impl<T> OnceCell<T> {
        Self::set0(value, guard)
    }

    fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
    fn set0(value: T, mut guard: tokio::sync::RwLockWriteGuard<'_, Inner<T>>) -> GuardMut<'_, T> {
        if guard.value.is_some() {
            drop(guard);
            unreachable!("we won permit, must not be initialized");
        }
        guard.value = Some(value);
        guard.init_semaphore.close();
        Guard(guard)
        GuardMut(guard)
    }

    /// Returns a guard to an existing initialized value, if any.
    pub fn get(&self) -> Option<Guard<'_, T>> {
        let guard = self.inner.lock().unwrap();
    pub async fn get_mut(&self) -> Option<GuardMut<'_, T>> {
        let guard = self.inner.write().await;
        if guard.value.is_some() {
            Some(Guard(guard))
            Some(GuardMut(guard))
        } else {
            None
        }
    }

    /// Returns a guard to an existing initialized value, if any.
    pub async fn get(&self) -> Option<GuardRef<'_, T>> {
        let guard = self.inner.read().await;
        if guard.value.is_some() {
            Some(GuardRef(guard))
        } else {
            None
        }
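With the `std::sync::Mutex` swapped for `tokio::sync::RwLock`, every accessor becomes `async` and splits into a write-locking flavour (`get_mut`, `get_mut_or_init`, returning `GuardMut`) and a read-locking one (`get`, `get_or_init`, returning `GuardRef`), so readers no longer serialize on a single lock. A hedged usage sketch of the read-path API with a trivial infallible factory:

```rust
use std::convert::Infallible;

// Sketch only: assumes the OnceCell defined above is in scope.
#[tokio::test]
async fn once_cell_usage_sketch() {
    let cell = OnceCell::<u32>::default();
    let guard = cell
        .get_or_init(|permit| async move { Ok::<_, Infallible>((42, permit)) })
        .await
        .unwrap();
    // Later callers share the initialized value through read guards.
    assert_eq!(*guard, 42);
}
```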
@@ -168,9 +221,9 @@ impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
/// Uninteresting guard object to allow short-lived access to inspect or clone the held,
/// initialized value.
#[derive(Debug)]
pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
pub struct GuardMut<'a, T>(tokio::sync::RwLockWriteGuard<'a, Inner<T>>);

impl<T> std::ops::Deref for Guard<'_, T> {
impl<T> std::ops::Deref for GuardMut<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
@@ -181,7 +234,7 @@ impl<T> std::ops::Deref for Guard<'_, T> {
    }
}

impl<T> std::ops::DerefMut for Guard<'_, T> {
impl<T> std::ops::DerefMut for GuardMut<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0
            .value
@@ -190,7 +243,7 @@ impl<T> std::ops::DerefMut for Guard<'_, T> {
    }
}

impl<'a, T> Guard<'a, T> {
impl<'a, T> GuardMut<'a, T> {
    /// Take the current value, and a new permit for its deinitialization.
    ///
    /// The permit will be on a semaphore part of the new internal value, and any following
@@ -208,6 +261,24 @@ impl<'a, T> Guard<'a, T> {
            .map(|v| (v, InitPermit(permit)))
            .expect("guard is not created unless value has been initialized")
    }

    pub fn downgrade(self) -> GuardRef<'a, T> {
        GuardRef(self.0.downgrade())
    }
}

#[derive(Debug)]
pub struct GuardRef<'a, T>(tokio::sync::RwLockReadGuard<'a, Inner<T>>);

impl<T> std::ops::Deref for GuardRef<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.0
            .value
            .as_ref()
            .expect("guard is not created unless value has been initialized")
    }
}

/// Type held by OnceCell (de)initializing task.
@@ -248,7 +319,7 @@ mod tests {
            barrier.wait().await;
            let won = {
                let g = cell
                    .get_or_init(|permit| {
                    .get_mut_or_init(|permit| {
                        counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
                        async {
                            counters.future_polled.fetch_add(1, Ordering::Relaxed);
@@ -295,7 +366,11 @@ mod tests {
            let cell = cell.clone();
            let deinitialization_started = deinitialization_started.clone();
            async move {
                let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
                let (answer, _permit) = cell
                    .get_mut()
                    .await
                    .expect("initialized to value")
                    .take_and_deinit();
                assert_eq!(answer, initial);

                deinitialization_started.wait().await;
@@ -306,7 +381,7 @@ mod tests {
        deinitialization_started.wait().await;

        let started_at = tokio::time::Instant::now();
        cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
        cell.get_mut_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
            .await
            .unwrap();

@@ -318,21 +393,21 @@ mod tests {

        jh.await.unwrap();

        assert_eq!(*cell.get().unwrap(), reinit);
        assert_eq!(*cell.get_mut().await.unwrap(), reinit);
    }

    #[test]
    fn reinit_with_deinit_permit() {
    #[tokio::test]
    async fn reinit_with_deinit_permit() {
        let cell = Arc::new(OnceCell::new(42));

        let (mol, permit) = cell.get().unwrap().take_and_deinit();
        cell.set(5, permit);
        assert_eq!(*cell.get().unwrap(), 5);
        let (mol, permit) = cell.get_mut().await.unwrap().take_and_deinit();
        cell.set(5, permit).await;
        assert_eq!(*cell.get_mut().await.unwrap(), 5);

        let (five, permit) = cell.get().unwrap().take_and_deinit();
        let (five, permit) = cell.get_mut().await.unwrap().take_and_deinit();
        assert_eq!(5, five);
        cell.set(mol, permit);
        assert_eq!(*cell.get().unwrap(), 42);
        cell.set(mol, permit).await;
        assert_eq!(*cell.get_mut().await.unwrap(), 42);
    }

    #[tokio::test]
@@ -340,13 +415,13 @@ mod tests {
        let cell = OnceCell::default();

        for _ in 0..10 {
            cell.get_or_init(|_permit| async { Err("whatever error") })
            cell.get_mut_or_init(|_permit| async { Err("whatever error") })
                .await
                .unwrap_err();
        }

        let g = cell
            .get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
            .get_mut_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
            .await
            .unwrap();
        assert_eq!(*g, "finally success");
@@ -358,7 +433,7 @@ mod tests {

        let barrier = tokio::sync::Barrier::new(2);

        let initializer = cell.get_or_init(|permit| async {
        let initializer = cell.get_mut_or_init(|permit| async {
            barrier.wait().await;
            futures::future::pending::<()>().await;

@@ -372,10 +447,10 @@ mod tests {

        // now initializer is dropped

        assert!(cell.get().is_none());
        assert!(cell.get_mut().await.is_none());

        let g = cell
            .get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
            .get_mut_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
            .await
            .unwrap();
        assert_eq!(*g, "now initialized");

@@ -56,18 +56,10 @@ pub enum ForceAwaitLogicalSize {

impl Client {
    pub fn new(mgmt_api_endpoint: String, jwt: Option<&str>) -> Self {
        Self::from_client(reqwest::Client::new(), mgmt_api_endpoint, jwt)
    }

    pub fn from_client(
        client: reqwest::Client,
        mgmt_api_endpoint: String,
        jwt: Option<&str>,
    ) -> Self {
        Self {
            mgmt_api_endpoint,
            authorization_header: jwt.map(|jwt| format!("Bearer {jwt}")),
            client,
            client: reqwest::Client::new(),
        }
    }

@@ -318,22 +310,6 @@ impl Client {
            .map_err(Error::ReceiveBody)
    }

    pub async fn tenant_shard_split(
        &self,
        tenant_shard_id: TenantShardId,
        req: TenantShardSplitRequest,
    ) -> Result<TenantShardSplitResponse> {
        let uri = format!(
            "{}/v1/tenant/{}/shard_split",
            self.mgmt_api_endpoint, tenant_shard_id
        );
        self.request(Method::PUT, &uri, req)
            .await?
            .json()
            .await
            .map_err(Error::ReceiveBody)
    }

    pub async fn timeline_list(
        &self,
        tenant_shard_id: &TenantShardId,

@@ -274,10 +274,6 @@ fn start_pageserver(
    set_launch_timestamp_metric(launch_ts);
    #[cfg(target_os = "linux")]
    metrics::register_internal(Box::new(metrics::more_process_metrics::Collector::new())).unwrap();
    metrics::register_internal(Box::new(
        pageserver::metrics::tokio_epoll_uring::Collector::new(),
    ))
    .unwrap();
    pageserver::preinitialize_metrics();

    // If any failpoints were set from FAILPOINTS environment variable,

@@ -623,7 +623,6 @@ impl std::fmt::Display for EvictionLayer {
    }
}

#[derive(Default)]
pub(crate) struct DiskUsageEvictionInfo {
    /// Timeline's largest layer (remote or resident)
    pub max_layer_size: Option<u64>,
@@ -855,27 +854,19 @@ async fn collect_eviction_candidates(

        let total = tenant_candidates.len();

        let tenant_candidates =
            tenant_candidates
                .into_iter()
                .enumerate()
                .map(|(i, mut candidate)| {
                    // as we iterate this reverse sorted list, the most recently accessed layer will always
                    // be 1.0; this is for us to evict it last.
                    candidate.relative_last_activity =
                        eviction_order.relative_last_activity(total, i);
        for (i, mut candidate) in tenant_candidates.into_iter().enumerate() {
            // as we iterate this reverse sorted list, the most recently accessed layer will always
            // be 1.0; this is for us to evict it last.
            candidate.relative_last_activity = eviction_order.relative_last_activity(total, i);

                    let partition = if cumsum > min_resident_size as i128 {
                        MinResidentSizePartition::Above
                    } else {
                        MinResidentSizePartition::Below
                    };
                    cumsum += i128::from(candidate.layer.get_file_size());

                    (partition, candidate)
                });

        candidates.extend(tenant_candidates);
            let partition = if cumsum > min_resident_size as i128 {
                MinResidentSizePartition::Above
            } else {
                MinResidentSizePartition::Below
            };
            cumsum += i128::from(candidate.layer.get_file_size());
            candidates.push((partition, candidate));
        }
    }

    // Note: the same tenant ID might be hit twice, if it transitions from attached to
@@ -891,41 +882,21 @@ async fn collect_eviction_candidates(
    );

    for secondary_tenant in secondary_tenants {
        // for secondary tenants we use a sum of on_disk layers and already evicted layers. this is
        // to prevent repeated disk usage based evictions from completely draining less often
        // updating secondaries.
        let (mut layer_info, total_layers) = secondary_tenant.get_layers_for_eviction();

        debug_assert!(
            total_layers >= layer_info.resident_layers.len(),
            "total_layers ({total_layers}) must be at least the resident_layers.len() ({})",
            layer_info.resident_layers.len()
        );
        let mut layer_info = secondary_tenant.get_layers_for_eviction();

        layer_info
            .resident_layers
            .sort_unstable_by_key(|layer_info| std::cmp::Reverse(layer_info.last_activity_ts));

        let tenant_candidates =
            layer_info
                .resident_layers
                .into_iter()
                .enumerate()
                .map(|(i, mut candidate)| {
                    candidate.relative_last_activity =
                        eviction_order.relative_last_activity(total_layers, i);
                    (
                        // Secondary locations' layers are always considered above the min resident size,
                        // i.e. secondary locations are permitted to be trimmed to zero layers if all
                        // the layers have sufficiently old access times.
                        MinResidentSizePartition::Above,
                        candidate,
                    )
                });

        candidates.extend(tenant_candidates);

        tokio::task::yield_now().await;
        candidates.extend(layer_info.resident_layers.into_iter().map(|candidate| {
            (
                // Secondary locations' layers are always considered above the min resident size,
                // i.e. secondary locations are permitted to be trimmed to zero layers if all
                // the layers have sufficiently old access times.
                MinResidentSizePartition::Above,
                candidate,
            )
        }));
    }

    debug_assert!(MinResidentSizePartition::Above < MinResidentSizePartition::Below,

@@ -19,14 +19,11 @@ use pageserver_api::models::ShardParameters;
use pageserver_api::models::TenantDetails;
use pageserver_api::models::TenantLocationConfigResponse;
use pageserver_api::models::TenantShardLocation;
use pageserver_api::models::TenantShardSplitRequest;
use pageserver_api::models::TenantShardSplitResponse;
use pageserver_api::models::TenantState;
use pageserver_api::models::{
DownloadRemoteLayersTaskSpawnRequest, LocationConfigMode, TenantAttachRequest,
TenantLoadRequest, TenantLocationConfigRequest,
};
use pageserver_api::shard::ShardCount;
use pageserver_api::shard::TenantShardId;
use remote_storage::GenericRemoteStorage;
use remote_storage::TimeTravelError;
@@ -878,7 +875,7 @@ async fn tenant_reset_handler(
let state = get_state(&request);
state
.tenant_manager
.reset_tenant(tenant_shard_id, drop_cache.unwrap_or(false), &ctx)
.reset_tenant(tenant_shard_id, drop_cache.unwrap_or(false), ctx)
.await
.map_err(ApiError::InternalServerError)?;

@@ -1107,25 +1104,6 @@ async fn tenant_size_handler(
)
}

async fn tenant_shard_split_handler(
mut request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let req: TenantShardSplitRequest = json_request(&mut request).await?;

let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
let state = get_state(&request);
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Warn);

let new_shards = state
.tenant_manager
.shard_split(tenant_shard_id, ShardCount(req.new_shard_count), &ctx)
.await
.map_err(ApiError::InternalServerError)?;

json_response(StatusCode::OK, TenantShardSplitResponse { new_shards })
}

async fn layer_map_info_handler(
request: Request<Body>,
_cancel: CancellationToken,
@@ -2085,9 +2063,6 @@ pub fn make_router(
.put("/v1/tenant/config", |r| {
api_handler(r, update_tenant_config_handler)
})
.put("/v1/tenant/:tenant_shard_id/shard_split", |r| {
api_handler(r, tenant_shard_split_handler)
})
.get("/v1/tenant/:tenant_shard_id/config", |r| {
api_handler(r, get_tenant_config_handler)
})

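The shard_split route above can be exercised with any HTTP client. A hedged sketch using reqwest (the listen address and the exact JSON field name are assumptions read off the handler code, not taken from pageserver documentation):

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Hypothetical tenant shard id and management API address.
    let tenant_shard_id = "3aa8fcc61f6657a6c0b87dc966686b8d";
    let resp = reqwest::Client::new()
        .put(format!(
            "http://127.0.0.1:9898/v1/tenant/{tenant_shard_id}/shard_split"
        ))
        .json(&json!({ "new_shard_count": 4 }))
        .send()
        .await?
        .error_for_status()?;
    // The handler replies with TenantShardSplitResponse { new_shards }.
    println!("{}", resp.text().await?);
    Ok(())
}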
@@ -2400,72 +2400,6 @@ impl<F: Future<Output = Result<O, E>>, O, E> Future for MeasuredRemoteOp<F> {
}
}

pub mod tokio_epoll_uring {
use metrics::UIntGauge;

pub struct Collector {
descs: Vec<metrics::core::Desc>,
systems_created: UIntGauge,
systems_destroyed: UIntGauge,
}

const NMETRICS: usize = 2;

impl metrics::core::Collector for Collector {
fn desc(&self) -> Vec<&metrics::core::Desc> {
self.descs.iter().collect()
}

fn collect(&self) -> Vec<metrics::proto::MetricFamily> {
let mut mfs = Vec::with_capacity(NMETRICS);
let tokio_epoll_uring::metrics::Metrics {
systems_created,
systems_destroyed,
} = tokio_epoll_uring::metrics::global();
self.systems_created.set(systems_created);
mfs.extend(self.systems_created.collect());
self.systems_destroyed.set(systems_destroyed);
mfs.extend(self.systems_destroyed.collect());
mfs
}
}

impl Collector {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
let mut descs = Vec::new();

let systems_created = UIntGauge::new(
"pageserver_tokio_epoll_uring_systems_created",
"counter of tokio-epoll-uring systems that were created",
)
.unwrap();
descs.extend(
metrics::core::Collector::desc(&systems_created)
.into_iter()
.cloned(),
);

let systems_destroyed = UIntGauge::new(
"pageserver_tokio_epoll_uring_systems_destroyed",
"counter of tokio-epoll-uring systems that were destroyed",
)
.unwrap();
descs.extend(
metrics::core::Collector::desc(&systems_destroyed)
.into_iter()
.cloned(),
);

Self {
descs,
systems_created,
systems_destroyed,
}
}
}
}

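A custom Collector like the one above only takes effect once registered, so the registry calls collect() on every scrape and the gauges are refreshed from tokio-epoll-uring's global counters. A minimal sketch using the plain prometheus crate (whether Neon's metrics wrapper exposes the registry exactly this way is an assumption):

fn register_collector() -> anyhow::Result<()> {
    // Collector is the type defined in the module above; default_registry()
    // is the registry scraped by the /metrics endpoint.
    prometheus::default_registry().register(Box::new(Collector::new()))?;
    Ok(())
}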
pub fn preinitialize_metrics() {
// Python tests need these and on some we do alerting.
//

@@ -91,8 +91,8 @@ const ACTIVE_TENANT_TIMEOUT: Duration = Duration::from_millis(30000);
/// `tokio_tar` already read the first such block. Read the second all-zeros block,
/// and check that there is no more data after the EOF marker.
///
/// 'tar' command can also write extra blocks of zeros, up to a record
/// size, controlled by the --record-size argument. Ignore them too.
/// XXX: Currently, any trailing data after the EOF marker prints a warning.
/// Perhaps it should be a hard error?
async fn read_tar_eof(mut reader: (impl AsyncRead + Unpin)) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
let mut buf = [0u8; 512];
@@ -113,24 +113,17 @@ async fn read_tar_eof(mut reader: (impl AsyncRead + Unpin)) -> anyhow::Result<()
anyhow::bail!("invalid tar EOF marker");
}

// Drain any extra zero-blocks after the EOF marker
// Drain any data after the EOF marker
let mut trailing_bytes = 0;
let mut seen_nonzero_bytes = false;
loop {
let nbytes = reader.read(&mut buf).await?;
trailing_bytes += nbytes;
if !buf.iter().all(|&x| x == 0) {
seen_nonzero_bytes = true;
}
if nbytes == 0 {
break;
}
}
if seen_nonzero_bytes {
anyhow::bail!("unexpected non-zero bytes after the tar archive");
}
if trailing_bytes % 512 != 0 {
anyhow::bail!("unexpected number of zeros ({trailing_bytes}), not divisible by tar block size (512 bytes), after the tar archive");
if trailing_bytes > 0 {
warn!("ignored {trailing_bytes} unexpected bytes after the tar archive");
}
Ok(())
}
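
For reference, the format rule read_tar_eof() relies on: a tar archive ends with two 512-byte blocks of zeros, and tar may pad with further zero blocks up to its record size. A small self-contained check of that property on an in-memory buffer (a sketch, independent of tokio_tar):

// Does `archive` end with at least two all-zero 512-byte blocks?
fn has_tar_eof_marker(archive: &[u8]) -> bool {
    const BLOCK: usize = 512;
    archive.len() % BLOCK == 0
        && archive.len() >= 2 * BLOCK
        && archive[archive.len() - 2 * BLOCK..].iter().all(|&b| b == 0)
}

fn main() {
    let archive = vec![0u8; 1024]; // empty archive: just the EOF marker
    assert!(has_tar_eof_marker(&archive));
}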

@@ -576,8 +576,8 @@ pub fn shutdown_token() -> CancellationToken {

/// Has the current task been requested to shut down?
pub fn is_shutdown_requested() -> bool {
if let Ok(true_or_false) = SHUTDOWN_TOKEN.try_with(|t| t.is_cancelled()) {
true_or_false
if let Ok(cancel) = SHUTDOWN_TOKEN.try_with(|t| t.clone()) {
cancel.is_cancelled()
} else {
if !cfg!(test) {
warn!("is_shutdown_requested() called in an unexpected task or thread");

@@ -53,7 +53,6 @@ use self::metadata::TimelineMetadata;
use self::mgr::GetActiveTenantError;
use self::mgr::GetTenantError;
use self::mgr::TenantsMap;
use self::remote_timeline_client::upload::upload_index_part;
use self::remote_timeline_client::RemoteTimelineClient;
use self::timeline::uninit::TimelineExclusionError;
use self::timeline::uninit::TimelineUninitMark;
@@ -1377,7 +1376,7 @@ impl Tenant {
async move {
debug!("starting index part download");

let index_part = client.download_index_file(&cancel_clone).await;
let index_part = client.download_index_file(cancel_clone).await;

debug!("finished index part download");

@@ -2398,67 +2397,6 @@ impl Tenant {
pub(crate) fn get_generation(&self) -> Generation {
self.generation
}

/// This function partially shuts down the tenant (it shuts down the Timelines) and is fallible,
/// and can leave the tenant in a bad state if it fails. The caller is responsible for
/// resetting this tenant to a valid state if we fail.
pub(crate) async fn split_prepare(
&self,
child_shards: &Vec<TenantShardId>,
) -> anyhow::Result<()> {
let timelines = self.timelines.lock().unwrap().clone();
for timeline in timelines.values() {
let Some(tl_client) = &timeline.remote_client else {
anyhow::bail!("Remote storage is mandatory");
};

let Some(remote_storage) = &self.remote_storage else {
anyhow::bail!("Remote storage is mandatory");
};

// We do not block timeline creation/deletion during splits inside the pageserver: it is up to higher levels
// to ensure that they do not start a split if currently in the process of doing these.

// Upload an index from the parent: this is partly to provide freshness for the
// child tenants that will copy it, and partly for general ease-of-debugging: there will
// always be a parent shard index in the same generation as we wrote the child shard index.
tl_client.schedule_index_upload_for_file_changes()?;
tl_client.wait_completion().await?;

// Shut down the timeline's remote client: this means that the indices we write
// for child shards will not be invalidated by the parent shard deleting layers.
tl_client.shutdown().await?;

// Download methods can still be used after shutdown, as they don't flow through the remote client's
// queue. In principle the RemoteTimelineClient could provide this without downloading it, but this
// operation is rare, so it's simpler to just download it (and robustly guarantees that the index
// we use here really is the remotely persistent one).
let result = tl_client
.download_index_file(&self.cancel)
.instrument(info_span!("download_index_file", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), timeline_id=%timeline.timeline_id))
.await?;
let index_part = match result {
MaybeDeletedIndexPart::Deleted(_) => {
anyhow::bail!("Timeline deletion happened concurrently with split")
}
MaybeDeletedIndexPart::IndexPart(p) => p,
};

for child_shard in child_shards {
upload_index_part(
remote_storage,
child_shard,
&timeline.timeline_id,
self.generation,
&index_part,
&self.cancel,
)
.await?;
}
}

Ok(())
}
}

/// Given a Vec of timelines and their ancestors (timeline_id, ancestor_id),
@@ -3794,10 +3732,6 @@ impl Tenant {

Ok(())
}

pub(crate) fn get_tenant_conf(&self) -> TenantConfOpt {
self.tenant_conf.read().unwrap().tenant_conf
}
}

fn remove_timeline_and_uninit_mark(

@@ -2,7 +2,6 @@
//! page server.

use camino::{Utf8DirEntry, Utf8Path, Utf8PathBuf};
use itertools::Itertools;
use pageserver_api::key::Key;
use pageserver_api::models::ShardParameters;
use pageserver_api::shard::{ShardCount, ShardIdentity, ShardNumber, TenantShardId};
@@ -23,7 +22,7 @@ use tokio_util::sync::CancellationToken;
use tracing::*;

use remote_storage::GenericRemoteStorage;
use utils::{completion, crashsafe};
use utils::crashsafe;

use crate::config::PageServerConf;
use crate::context::{DownloadBehavior, RequestContext};
@@ -645,6 +644,8 @@ pub(crate) async fn shutdown_all_tenants() {
}

async fn shutdown_all_tenants0(tenants: &std::sync::RwLock<TenantsMap>) {
use utils::completion;

let mut join_set = JoinSet::new();

// Atomically, 1. create the shutdown tasks and 2. prevent creation of new tenants.
@@ -1199,7 +1200,7 @@ impl TenantManager {
&self,
tenant_shard_id: TenantShardId,
drop_cache: bool,
ctx: &RequestContext,
ctx: RequestContext,
) -> anyhow::Result<()> {
let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let Some(old_slot) = slot_guard.get_old_value() else {
@@ -1252,7 +1253,7 @@ impl TenantManager {
None,
self.tenants,
SpawnMode::Normal,
ctx,
&ctx,
)?;

slot_guard.upsert(TenantSlot::Attached(tenant))?;
@@ -1374,164 +1375,6 @@ impl TenantManager {
slot_guard.revert();
result
}

#[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), new_shard_count=%new_shard_count.0))]
pub(crate) async fn shard_split(
&self,
tenant_shard_id: TenantShardId,
new_shard_count: ShardCount,
ctx: &RequestContext,
) -> anyhow::Result<Vec<TenantShardId>> {
let tenant = get_tenant(tenant_shard_id, true)?;

// Plan: identify what the new child shards will be
let effective_old_shard_count = std::cmp::max(tenant_shard_id.shard_count.0, 1);
if new_shard_count <= ShardCount(effective_old_shard_count) {
anyhow::bail!("Requested shard count is not an increase");
}
let expansion_factor = new_shard_count.0 / effective_old_shard_count;
if !expansion_factor.is_power_of_two() {
anyhow::bail!("Requested split is not a power of two");
}

let parent_shard_identity = tenant.shard_identity;
let parent_tenant_conf = tenant.get_tenant_conf();
let parent_generation = tenant.generation;

let child_shards = tenant_shard_id.split(new_shard_count);
tracing::info!(
"Shard {} splits into: {}",
tenant_shard_id.to_index(),
child_shards
.iter()
.map(|id| format!("{}", id.to_index()))
.join(",")
);

// Phase 1: Write out child shards' remote index files, in the parent tenant's current generation
if let Err(e) = tenant.split_prepare(&child_shards).await {
// If [`Tenant::split_prepare`] fails, we must reload the tenant, because it might
// have been left in a partially-shut-down state.
tracing::warn!("Failed to prepare for split: {e}, reloading Tenant before returning");
self.reset_tenant(tenant_shard_id, false, ctx).await?;
return Err(e);
}

self.resources.deletion_queue_client.flush_advisory();

// Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant
drop(tenant);
let mut parent_slot_guard =
tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let parent = match parent_slot_guard.get_old_value() {
Some(TenantSlot::Attached(t)) => t,
Some(TenantSlot::Secondary(_)) => anyhow::bail!("Tenant location in secondary mode"),
Some(TenantSlot::InProgress(_)) => {
// tenant_map_acquire_slot never returns InProgress; if a slot was InProgress
// it would return an error.
unreachable!()
}
None => {
// We don't actually need the parent shard to still be attached to do our work, but it's
// a weird enough situation that the caller probably didn't want us to continue working
// if they had detached the tenant they requested the split on.
anyhow::bail!("Detached parent shard in the middle of split!")
}
};

// TODO: hardlink layers from the parent into the child shard directories so that they don't immediately re-download
// TODO: erase the dentries from the parent

// Take a snapshot of where the parent's WAL ingest had got to: we will wait for
// child shards to reach this point.
let mut target_lsns = HashMap::new();
for timeline in parent.timelines.lock().unwrap().clone().values() {
target_lsns.insert(timeline.timeline_id, timeline.get_last_record_lsn());
}

// TODO: we should have the parent shard stop its WAL ingest here, it's a waste of resources
// and could slow down the children trying to catch up.

// Phase 3: Spawn the child shards
for child_shard in &child_shards {
let mut child_shard_identity = parent_shard_identity;
child_shard_identity.count = child_shard.shard_count;
child_shard_identity.number = child_shard.shard_number;

let child_location_conf = LocationConf {
mode: LocationMode::Attached(AttachedLocationConfig {
generation: parent_generation,
attach_mode: AttachmentMode::Single,
}),
shard: child_shard_identity,
tenant_conf: parent_tenant_conf,
};

self.upsert_location(
*child_shard,
child_location_conf,
None,
SpawnMode::Normal,
ctx,
)
.await?;
}

// Phase 4: wait for the child shards' WAL ingest to catch up to the target LSNs
for child_shard_id in &child_shards {
let child_shard = {
let locked = TENANTS.read().unwrap();
let peek_slot =
tenant_map_peek_slot(&locked, child_shard_id, TenantSlotPeekMode::Read)?;
peek_slot.and_then(|s| s.get_attached()).cloned()
};
if let Some(t) = child_shard {
let timelines = t.timelines.lock().unwrap().clone();
for timeline in timelines.values() {
let Some(target_lsn) = target_lsns.get(&timeline.timeline_id) else {
continue;
};

tracing::info!(
"Waiting for child shard {}/{} to reach target lsn {}...",
child_shard_id,
timeline.timeline_id,
target_lsn
);
if let Err(e) = timeline.wait_lsn(*target_lsn, ctx).await {
// Failure here might mean shutdown, in any case this part is an optimization
// and we shouldn't hold up the split operation.
tracing::warn!(
"Failed to wait for timeline {} to reach lsn {target_lsn}: {e}",
timeline.timeline_id
);
} else {
tracing::info!(
"Child shard {}/{} reached target lsn {}",
child_shard_id,
timeline.timeline_id,
target_lsn
);
}
}
}
}

// Phase 5: Shut down the parent shard.
let (_guard, progress) = completion::channel();
match parent.shutdown(progress, false).await {
Ok(()) => {}
Err(other) => {
other.wait().await;
}
}
parent_slot_guard.drop_old_value()?;

// Phase 6: Release the InProgress on the parent shard
drop(parent_slot_guard);

Ok(child_shards)
}
}

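The validation at the top of shard_split() only accepts power-of-two expansions. Extracted as a standalone sketch (the function name is hypothetical; the logic mirrors the code above):

fn validate_split(old_shard_count: u8, new_shard_count: u8) -> Result<u8, &'static str> {
    // An unsharded tenant has shard_count == 0; treat it as 1 for the math.
    let effective_old = old_shard_count.max(1);
    if new_shard_count <= effective_old {
        return Err("Requested shard count is not an increase");
    }
    let expansion_factor = new_shard_count / effective_old;
    if !expansion_factor.is_power_of_two() {
        return Err("Requested split is not a power of two");
    }
    Ok(expansion_factor)
}

fn main() {
    assert_eq!(validate_split(0, 4), Ok(4)); // 1 -> 4 shards
    assert!(validate_split(2, 6).is_err()); // factor 3 is rejected
}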
#[derive(Debug, thiserror::Error)]
@@ -2366,6 +2209,8 @@ async fn remove_tenant_from_memory<V, F>(
where
F: std::future::Future<Output = anyhow::Result<V>>,
{
use utils::completion;

let mut slot_guard =
tenant_map_acquire_slot_impl(&tenant_shard_id, tenants, TenantSlotAcquireMode::MustExist)?;

@@ -217,7 +217,6 @@ use crate::metrics::{
};
use crate::task_mgr::shutdown_token;
use crate::tenant::debug_assert_current_span_has_tenant_and_timeline_id;
use crate::tenant::remote_timeline_client::download::download_retry;
use crate::tenant::storage_layer::AsLayerDesc;
use crate::tenant::upload_queue::Delete;
use crate::tenant::TIMELINES_SEGMENT_NAME;
@@ -263,11 +262,6 @@ pub(crate) const INITDB_PRESERVED_PATH: &str = "initdb-preserved.tar.zst";
/// Default buffer size when interfacing with [`tokio::fs::File`].
pub(crate) const BUFFER_SIZE: usize = 32 * 1024;

/// This timeout is intended to deal with hangs in lower layers, e.g. stuck TCP flows. It is not
/// intended to be snappy enough for prompt shutdown, as we have a CancellationToken for that.
pub(crate) const UPLOAD_TIMEOUT: Duration = Duration::from_secs(120);
pub(crate) const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(120);

pub enum MaybeDeletedIndexPart {
IndexPart(IndexPart),
Deleted(IndexPart),
@@ -331,6 +325,11 @@ pub struct RemoteTimelineClient {
cancel: CancellationToken,
}

/// This timeout is intended to deal with hangs in lower layers, e.g. stuck TCP flows. It is not
/// intended to be snappy enough for prompt shutdown, as we have a CancellationToken for that.
const UPLOAD_TIMEOUT: Duration = Duration::from_secs(120);
const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(120);

/// Wrapper for timeout_cancellable that flattens result and converts TimeoutCancellableError to anyhow.
///
/// This is a convenience for the various upload functions. In future
@@ -507,7 +506,7 @@ impl RemoteTimelineClient {
/// Download index file
pub async fn download_index_file(
&self,
cancel: &CancellationToken,
cancel: CancellationToken,
) -> Result<MaybeDeletedIndexPart, DownloadError> {
let _unfinished_gauge_guard = self.metrics.call_begin(
&RemoteOpFileKind::Index,
@@ -1148,17 +1147,22 @@ impl RemoteTimelineClient {

let cancel = shutdown_token();

let remaining = download_retry(
let remaining = backoff::retry(
|| async {
self.storage_impl
.list_files(Some(&timeline_storage_path), None)
.list_files(Some(&timeline_storage_path))
.await
},
"list remaining files",
|_e| false,
FAILED_DOWNLOAD_WARN_THRESHOLD,
FAILED_REMOTE_OP_RETRIES,
"list_prefixes",
&cancel,
)
.await
.context("list files remaining files")?;
.ok_or_else(|| anyhow::anyhow!("Cancelled!"))
.and_then(|x| x)
.context("list prefixes")?;

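The backoff::retry calling convention used above: it takes the operation, a predicate for permanent (non-retryable) errors, a warn threshold, a retry limit, a description, and a cancellation token, and yields None when cancelled, hence the .ok_or_else(...).and_then(|x| x) flattening. A hedged sketch of a caller (the fetch_listing() stub and the numeric thresholds are hypothetical):

async fn list_with_retry(
    cancel: &tokio_util::sync::CancellationToken,
) -> anyhow::Result<Vec<String>> {
    let listing = backoff::retry(
        || async { fetch_listing().await }, // the retried operation
        |_e| false,                         // no error is treated as permanent
        3,                                  // warn after this many failures
        10,                                 // give up after this many retries
        "list remaining files",
        cancel,
    )
    .await
    // None means the cancellation token fired mid-loop...
    .ok_or_else(|| anyhow::anyhow!("cancelled"))
    // ...otherwise flatten the inner Result from the last attempt.
    .and_then(|res| res)?;
    Ok(listing)
}

async fn fetch_listing() -> anyhow::Result<Vec<String>> {
    Ok(Vec::new()) // placeholder stand-in for a remote listing call
}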
// We will delete the current index_part object last, since it acts as a deletion
// marker via its deleted_at attribute
@@ -1347,7 +1351,6 @@ impl RemoteTimelineClient {
/// queue.
///
async fn perform_upload_task(self: &Arc<Self>, task: Arc<UploadTask>) {
let cancel = shutdown_token();
// Loop to retry until it completes.
loop {
// If we're requested to shut down, close up shop and exit.
@@ -1359,7 +1362,7 @@ impl RemoteTimelineClient {
// the Future, but we're not 100% sure if the remote storage library
// is cancellation safe, so we don't dare to do that. Hopefully, the
// upload finishes or times out soon enough.
if cancel.is_cancelled() {
if task_mgr::is_shutdown_requested() {
info!("upload task cancelled by shutdown request");
match self.stop() {
Ok(()) => {}
@@ -1470,7 +1473,7 @@ impl RemoteTimelineClient {
retries,
DEFAULT_BASE_BACKOFF_SECONDS,
DEFAULT_MAX_BACKOFF_SECONDS,
&cancel,
&shutdown_token(),
)
.await;
}
@@ -1987,7 +1990,7 @@ mod tests {

// Download back the index.json, and check that the list of files is correct
let initial_index_part = match client
.download_index_file(&CancellationToken::new())
.download_index_file(CancellationToken::new())
.await
.unwrap()
{
@@ -2081,7 +2084,7 @@ mod tests {

// Download back the index.json, and check that the list of files is correct
let index_part = match client
.download_index_file(&CancellationToken::new())
.download_index_file(CancellationToken::new())
.await
.unwrap()
{
@@ -2283,7 +2286,7 @@ mod tests {
let client = test_state.build_client(get_generation);

let download_r = client
.download_index_file(&CancellationToken::new())
.download_index_file(CancellationToken::new())
.await
.expect("download should always succeed");
assert!(matches!(download_r, MaybeDeletedIndexPart::IndexPart(_)));

@@ -216,15 +216,16 @@ pub async fn list_remote_timelines(
anyhow::bail!("storage-sync-list-remote-timelines");
});

let cancel_inner = cancel.clone();
let listing = download_retry_forever(
|| {
download_cancellable(
&cancel,
storage.list(Some(&remote_path), ListingMode::WithDelimiter, None),
&cancel_inner,
storage.list(Some(&remote_path), ListingMode::WithDelimiter),
)
},
&format!("list timelines for {tenant_shard_id}"),
&cancel,
cancel,
)
.await?;

@@ -257,18 +258,19 @@ async fn do_download_index_part(
tenant_shard_id: &TenantShardId,
timeline_id: &TimelineId,
index_generation: Generation,
cancel: &CancellationToken,
cancel: CancellationToken,
) -> Result<IndexPart, DownloadError> {
use futures::stream::StreamExt;

let remote_path = remote_index_path(tenant_shard_id, timeline_id, index_generation);

let cancel_inner = cancel.clone();
let index_part_bytes = download_retry_forever(
|| async {
// Cancellation: it is safe to cancel this future because we're just downloading into
// a memory buffer, not touching local disk.
let index_part_download =
download_cancellable(cancel, storage.download(&remote_path)).await?;
download_cancellable(&cancel_inner, storage.download(&remote_path)).await?;

let mut index_part_bytes = Vec::new();
let mut stream = std::pin::pin!(index_part_download.download_stream);
@@ -286,7 +288,7 @@ async fn do_download_index_part(
.await?;

let index_part: IndexPart = serde_json::from_slice(&index_part_bytes)
.with_context(|| format!("deserialize index part file at {remote_path:?}"))
.with_context(|| format!("download index part file at {remote_path:?}"))
.map_err(DownloadError::Other)?;

Ok(index_part)
@@ -303,7 +305,7 @@ pub(super) async fn download_index_part(
tenant_shard_id: &TenantShardId,
timeline_id: &TimelineId,
my_generation: Generation,
cancel: &CancellationToken,
cancel: CancellationToken,
) -> Result<IndexPart, DownloadError> {
debug_assert_current_span_has_tenant_and_timeline_id();

@@ -323,8 +325,14 @@ pub(super) async fn download_index_part(
// index in our generation.
//
// This is an optimization to avoid doing the listing for the general case below.
let res =
do_download_index_part(storage, tenant_shard_id, timeline_id, my_generation, cancel).await;
let res = do_download_index_part(
storage,
tenant_shard_id,
timeline_id,
my_generation,
cancel.clone(),
)
.await;
match res {
Ok(index_part) => {
tracing::debug!(
@@ -349,7 +357,7 @@ pub(super) async fn download_index_part(
tenant_shard_id,
timeline_id,
my_generation.previous(),
cancel,
cancel.clone(),
)
.await;
match res {
@@ -371,13 +379,18 @@ pub(super) async fn download_index_part(
// objects, and select the highest one with a generation <= my_generation. Constructing the prefix is equivalent
// to constructing a full index path with no generation, because the generation is a suffix.
let index_prefix = remote_index_path(tenant_shard_id, timeline_id, Generation::none());

let indices = download_retry(
|| async { storage.list_files(Some(&index_prefix), None).await },
"list index_part files",
cancel,
let indices = backoff::retry(
|| async { storage.list_files(Some(&index_prefix)).await },
|_| false,
FAILED_DOWNLOAD_WARN_THRESHOLD,
FAILED_REMOTE_OP_RETRIES,
"listing index_part files",
&cancel,
)
.await?;
.await
.ok_or_else(|| anyhow::anyhow!("Cancelled"))
.and_then(|x| x)
.map_err(DownloadError::Other)?;

// General case logic for which index to use: the latest index whose generation
// is <= our own. See "Finding the remote indices for timelines" in docs/rfcs/025-generation-numbers.md
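The selection rule referenced here can be sketched over plain object names. The exact generation-suffix encoding below is an assumption for illustration (the RFC cited above defines the real format):

// Hypothetical sketch: among listed index_part objects, pick the one whose
// generation suffix is highest while still <= our own generation.
fn pick_index<'a>(keys: &[&'a str], my_generation: u32) -> Option<&'a str> {
    keys.iter()
        .filter_map(|key| {
            // Assumed naming: "index_part.json-00000007" (hex-encoded suffix).
            let (_, suffix) = key.rsplit_once('-')?;
            let generation = u32::from_str_radix(suffix, 16).ok()?;
            (generation <= my_generation).then_some((generation, *key))
        })
        .max_by_key(|(generation, _)| *generation)
        .map(|(_, key)| key)
}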
@@ -434,6 +447,8 @@ pub(crate) async fn download_initdb_tar_zst(
"{INITDB_PATH}.download-{timeline_id}.{TEMP_FILE_SUFFIX}"
));

let cancel_inner = cancel.clone();

let file = download_retry(
|| async {
let file = OpenOptions::new()
@@ -446,11 +461,13 @@ pub(crate) async fn download_initdb_tar_zst(
.with_context(|| format!("tempfile creation {temp_path}"))
.map_err(DownloadError::Other)?;

let download = match download_cancellable(cancel, storage.download(&remote_path)).await
let download = match download_cancellable(&cancel_inner, storage.download(&remote_path))
.await
{
Ok(dl) => dl,
Err(DownloadError::NotFound) => {
download_cancellable(cancel, storage.download(&remote_preserved_path)).await?
download_cancellable(&cancel_inner, storage.download(&remote_preserved_path))
.await?
}
Err(other) => Err(other)?,
};
@@ -499,7 +516,7 @@ pub(crate) async fn download_initdb_tar_zst(
/// with backoff.
///
/// (See similar logic for uploads in `perform_upload_task`)
pub(super) async fn download_retry<T, O, F>(
async fn download_retry<T, O, F>(
op: O,
description: &str,
cancel: &CancellationToken,
@@ -510,7 +527,7 @@ where
{
backoff::retry(
op,
DownloadError::is_permanent,
|e| matches!(e, DownloadError::BadInput(_) | DownloadError::NotFound),
FAILED_DOWNLOAD_WARN_THRESHOLD,
FAILED_REMOTE_OP_RETRIES,
description,
@@ -524,7 +541,7 @@ where
async fn download_retry_forever<T, O, F>(
op: O,
description: &str,
cancel: &CancellationToken,
cancel: CancellationToken,
) -> Result<T, DownloadError>
where
O: FnMut() -> F,
@@ -532,11 +549,11 @@ where
{
backoff::retry(
op,
DownloadError::is_permanent,
|e| matches!(e, DownloadError::BadInput(_) | DownloadError::NotFound),
FAILED_DOWNLOAD_WARN_THRESHOLD,
u32::MAX,
description,
cancel,
&cancel,
)
.await
.ok_or_else(|| DownloadError::Cancelled)

@@ -27,7 +27,7 @@ use super::index::LayerFileMetadata;
use tracing::info;

/// Serializes and uploads the given index part data to the remote storage.
pub(crate) async fn upload_index_part<'a>(
pub(super) async fn upload_index_part<'a>(
storage: &'a GenericRemoteStorage,
tenant_shard_id: &TenantShardId,
timeline_id: &TimelineId,

@@ -160,7 +160,7 @@ impl SecondaryTenant {
&self.tenant_shard_id
}

pub(crate) fn get_layers_for_eviction(self: &Arc<Self>) -> (DiskUsageEvictionInfo, usize) {
pub(crate) fn get_layers_for_eviction(self: &Arc<Self>) -> DiskUsageEvictionInfo {
self.detail.lock().unwrap().get_layers_for_eviction(self)
}


@@ -146,15 +146,14 @@ impl SecondaryDetail {
}
}

/// Additionally returns the total number of layers, used for more stable relative access time
/// based eviction.
pub(super) fn get_layers_for_eviction(
&self,
parent: &Arc<SecondaryTenant>,
) -> (DiskUsageEvictionInfo, usize) {
let mut result = DiskUsageEvictionInfo::default();
let mut total_layers = 0;

) -> DiskUsageEvictionInfo {
let mut result = DiskUsageEvictionInfo {
max_layer_size: None,
resident_layers: Vec::new(),
};
for (timeline_id, timeline_detail) in &self.timelines {
result
.resident_layers
@@ -170,10 +169,6 @@ impl SecondaryDetail {
relative_last_activity: finite_f32::FiniteF32::ZERO,
}
}));

// total might be missing currently downloading layers, but as a lower-than-actual
// value it is a good enough approximation.
total_layers += timeline_detail.on_disk_layers.len() + timeline_detail.evicted_at.len();
}
result.max_layer_size = result
.resident_layers
@@ -188,7 +183,7 @@ impl SecondaryDetail {
result.resident_layers.len()
);

(result, total_layers)
result
}
}

@@ -317,7 +312,9 @@ impl JobGenerator<PendingDownload, RunningDownload, CompleteDownload, DownloadCo
.tenant_manager
.get_secondary_tenant_shard(*tenant_shard_id);
let Some(tenant) = tenant else {
return Err(anyhow::anyhow!("Not found or not in Secondary mode"));
{
return Err(anyhow::anyhow!("Not found or not in Secondary mode"));
}
};

Ok(PendingDownload {
@@ -392,9 +389,9 @@ impl JobGenerator<PendingDownload, RunningDownload, CompleteDownload, DownloadCo
}

CompleteDownload {
secondary_state,
completed_at: Instant::now(),
}
secondary_state,
completed_at: Instant::now(),
}
}.instrument(info_span!(parent: None, "secondary_download", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()))))
}
}
@@ -533,7 +530,7 @@ impl<'a> TenantDownloader<'a> {
.map_err(UpdateError::from)?;
let mut heatmap_bytes = Vec::new();
let mut body = tokio_util::io::StreamReader::new(download.download_stream);
let _size = tokio::io::copy_buf(&mut body, &mut heatmap_bytes).await?;
let _size = tokio::io::copy(&mut body, &mut heatmap_bytes).await?;
Ok(heatmap_bytes)
},
|e| matches!(e, UpdateError::NoData | UpdateError::Cancelled),

@@ -300,8 +300,8 @@ impl Layer {
})
}

pub(crate) fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
self.0.info(reset)
pub(crate) async fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
self.0.info(reset).await
}

pub(crate) fn access_stats(&self) -> &LayerAccessStats {
@@ -612,10 +612,10 @@ impl LayerInner {
let mut rx = self.status.subscribe();

let strong = {
match self.inner.get() {
match self.inner.get_mut().await {
Some(mut either) => {
self.wanted_evicted.store(true, Ordering::Relaxed);
either.downgrade()
ResidentOrWantedEvicted::downgrade(&mut either)
}
None => return Err(EvictionError::NotFound),
}
@@ -641,7 +641,7 @@ impl LayerInner {
// use however late (compared to the initial expressing of wanted) as the
// "outcome" now
LAYER_IMPL_METRICS.inc_broadcast_lagged();
match self.inner.get() {
match self.inner.get_mut().await {
Some(_) => Err(EvictionError::Downloaded),
None => Ok(()),
}
@@ -759,7 +759,7 @@ impl LayerInner {
// use the already held initialization permit because it is impossible to hit the
// below paths anymore, essentially limiting the max loop iterations to 2.
let (value, init_permit) = download(init_permit).await?;
let mut guard = self.inner.set(value, init_permit);
let mut guard = self.inner.set(value, init_permit).await;
let (strong, _upgraded) = guard
.get_and_upgrade()
.expect("init creates strong reference, we held the init permit");
@@ -767,7 +767,7 @@ impl LayerInner {
}

let (weak, permit) = {
let mut locked = self.inner.get_or_init(download).await?;
let mut locked = self.inner.get_mut_or_init(download).await?;

if let Some((strong, upgraded)) = locked.get_and_upgrade() {
if upgraded {
@@ -989,12 +989,12 @@ impl LayerInner {
}
}

fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
async fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
let layer_file_name = self.desc.filename().file_name();

// this is not accurate: we could have the file locally but there was a cancellation
// and now we are not in sync, or we are currently downloading it.
let remote = self.inner.get().is_none();
let remote = self.inner.get_mut().await.is_none();

let access_stats = self.access_stats.as_api_model(reset);

@@ -1053,7 +1053,7 @@ impl LayerInner {
LAYER_IMPL_METRICS.inc_eviction_cancelled(EvictionCancelled::LayerGone);
return;
};
match this.evict_blocking(version) {
match tokio::runtime::Handle::current().block_on(this.evict_blocking(version)) {
Ok(()) => LAYER_IMPL_METRICS.inc_completed_evictions(),
Err(reason) => LAYER_IMPL_METRICS.inc_eviction_cancelled(reason),
}
@@ -1061,7 +1061,7 @@ impl LayerInner {
}
}

fn evict_blocking(&self, only_version: usize) -> Result<(), EvictionCancelled> {
async fn evict_blocking(&self, only_version: usize) -> Result<(), EvictionCancelled> {
// deleted or detached timeline, don't do anything.
let Some(timeline) = self.timeline.upgrade() else {
return Err(EvictionCancelled::TimelineGone);
@@ -1070,7 +1070,7 @@ impl LayerInner {
// to avoid starting a new download while we evict, keep holding on to the
// permit.
let _permit = {
let maybe_downloaded = self.inner.get();
let maybe_downloaded = self.inner.get_mut().await;

let (_weak, permit) = match maybe_downloaded {
Some(mut guard) => {

@@ -1268,7 +1268,7 @@ impl Timeline {
let mut historic_layers = Vec::new();
for historic_layer in layer_map.iter_historic_layers() {
let historic_layer = guard.get_from_desc(&historic_layer);
historic_layers.push(historic_layer.info(reset));
historic_layers.push(historic_layer.info(reset).await);
}

LayerMapInfo {

@@ -314,9 +314,6 @@ lfc_change_limit_hook(int newval, void *extra)
lfc_ctl->used -= 1;
}
lfc_ctl->limit = new_size;
if (new_size == 0) {
lfc_ctl->generation += 1;
}
neon_log(DEBUG1, "set local file cache limit to %d", new_size);

LWLockRelease(lfc_lock);

133
pgxn/neon/neon.c
@@ -11,23 +11,16 @@
#include "postgres.h"
#include "fmgr.h"

#include "miscadmin.h"
#include "access/xact.h"
#include "access/xlog.h"
#include "storage/buf_internals.h"
#include "storage/bufmgr.h"
#include "catalog/pg_type.h"
#include "postmaster/bgworker.h"
#include "postmaster/interrupt.h"
#include "replication/slot.h"
#include "replication/walsender.h"
#include "storage/procsignal.h"
#include "tcop/tcopprot.h"
#include "funcapi.h"
#include "access/htup_details.h"
#include "utils/pg_lsn.h"
#include "utils/guc.h"
#include "utils/wait_event.h"

#include "neon.h"
#include "walproposer.h"
@@ -37,130 +30,6 @@
PG_MODULE_MAGIC;
void _PG_init(void);

static int logical_replication_max_time_lag = 3600;

static void
InitLogicalReplicationMonitor(void)
{
BackgroundWorker bgw;

DefineCustomIntVariable(
"neon.logical_replication_max_time_lag",
"Threshold for dropping unused logical replication slots",
NULL,
&logical_replication_max_time_lag,
3600, 0, INT_MAX,
PGC_SIGHUP,
GUC_UNIT_S,
NULL, NULL, NULL);

memset(&bgw, 0, sizeof(bgw));
bgw.bgw_flags = BGWORKER_SHMEM_ACCESS;
bgw.bgw_start_time = BgWorkerStart_RecoveryFinished;
snprintf(bgw.bgw_library_name, BGW_MAXLEN, "neon");
snprintf(bgw.bgw_function_name, BGW_MAXLEN, "LogicalSlotsMonitorMain");
snprintf(bgw.bgw_name, BGW_MAXLEN, "Logical replication monitor");
snprintf(bgw.bgw_type, BGW_MAXLEN, "Logical replication monitor");
bgw.bgw_restart_time = 5;
bgw.bgw_notify_pid = 0;
bgw.bgw_main_arg = (Datum) 0;

RegisterBackgroundWorker(&bgw);
}

typedef struct
{
NameData name;
bool dropped;
XLogRecPtr confirmed_flush_lsn;
TimestampTz last_updated;
} SlotStatus;

/*
 * Unused logical replication slots pin WAL and prevent deletion of snapshots.
 */
PGDLLEXPORT void
LogicalSlotsMonitorMain(Datum main_arg)
{
SlotStatus* slots;
TimestampTz now, last_checked;

/* Establish signal handlers. */
pqsignal(SIGUSR1, procsignal_sigusr1_handler);
pqsignal(SIGHUP, SignalHandlerForConfigReload);
pqsignal(SIGTERM, die);

BackgroundWorkerUnblockSignals();

slots = (SlotStatus*)calloc(max_replication_slots, sizeof(SlotStatus));
last_checked = GetCurrentTimestamp();

for (;;)
{
(void) WaitLatch(MyLatch,
WL_LATCH_SET | WL_EXIT_ON_PM_DEATH | WL_TIMEOUT,
logical_replication_max_time_lag*1000/2,
PG_WAIT_EXTENSION);
ResetLatch(MyLatch);
CHECK_FOR_INTERRUPTS();

now = GetCurrentTimestamp();

if (now - last_checked > logical_replication_max_time_lag*USECS_PER_SEC)
{
int n_active_slots = 0;
last_checked = now;

LWLockAcquire(ReplicationSlotControlLock, LW_SHARED);
for (int i = 0; i < max_replication_slots; i++)
{
ReplicationSlot *s = &ReplicationSlotCtl->replication_slots[i];

/* Consider only logical replication slots */
if (!s->in_use || !SlotIsLogical(s))
continue;

if (s->active_pid != 0)
{
n_active_slots += 1;
continue;
}

/* Check if there was some activity with the slot since the last check */
if (s->data.confirmed_flush != slots[i].confirmed_flush_lsn)
{
slots[i].confirmed_flush_lsn = s->data.confirmed_flush;
slots[i].last_updated = now;
}
else if (now - slots[i].last_updated > logical_replication_max_time_lag*USECS_PER_SEC)
{
slots[i].name = s->data.name;
slots[i].dropped = true;
}
}
LWLockRelease(ReplicationSlotControlLock);

/*
 * If there are no active subscriptions, then no new snapshots are generated
 * and so there is no need to force slot deletion.
 */
if (n_active_slots != 0)
{
for (int i = 0; i < max_replication_slots; i++)
{
if (slots[i].dropped)
{
elog(LOG, "Dropping logical replication slot because it was not updated for more than %ld seconds",
(now - slots[i].last_updated)/USECS_PER_SEC);
ReplicationSlotDrop(slots[i].name.data, true);
slots[i].dropped = false;
}
}
}
}
}
}

void
_PG_init(void)
{
@@ -175,8 +44,6 @@ _PG_init(void)
pg_init_libpagestore();
pg_init_walproposer();

InitLogicalReplicationMonitor();

InitControlPlaneConnector();

pg_init_extension_server();

@@ -19,7 +19,6 @@ chrono.workspace = true
clap.workspace = true
consumption_metrics.workspace = true
dashmap.workspace = true
env_logger.workspace = true
futures.workspace = true
git-version.workspace = true
hashbrown.workspace = true
@@ -60,8 +59,6 @@ scopeguard.workspace = true
serde.workspace = true
serde_json.workspace = true
sha2.workspace = true
smol_str.workspace = true
smallvec.workspace = true
socket2.workspace = true
sync_wrapper.workspace = true
task-local-extensions.workspace = true
@@ -78,7 +75,6 @@ tracing-subscriber.workspace = true
tracing-utils.workspace = true
tracing.workspace = true
url.workspace = true
urlencoding.workspace = true
utils.workspace = true
uuid.workspace = true
webpki-roots.workspace = true
@@ -87,6 +83,7 @@ native-tls.workspace = true
postgres-native-tls.workspace = true
postgres-protocol.workspace = true
redis.workspace = true
smol_str.workspace = true

workspace_hack.workspace = true

@@ -5,8 +5,7 @@ pub use backend::BackendType;

mod credentials;
pub use credentials::{
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
ComputeUserInfoParseError, IpPattern,
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, IpPattern,
};

mod password_hack;
@@ -15,12 +14,8 @@ use password_hack::PasswordHackPayload;

mod flow;
pub use flow::*;
use tokio::time::error::Elapsed;

use crate::{
console,
error::{ReportableError, UserFacingError},
};
use crate::{console, error::UserFacingError};
use std::io;
use thiserror::Error;

@@ -36,6 +31,9 @@ pub enum AuthErrorImpl {
#[error(transparent)]
GetAuthInfo(#[from] console::errors::GetAuthInfoError),

#[error(transparent)]
WakeCompute(#[from] console::errors::WakeComputeError),

/// SASL protocol errors (includes [SCRAM](crate::scram)).
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
@@ -69,9 +67,6 @@ pub enum AuthErrorImpl {

#[error("Too many connections to this endpoint. Please try again later.")]
TooManyConnections,

#[error("Authentication timed out")]
UserTimeout(Elapsed),
}

#[derive(Debug, Error)]
@@ -98,10 +93,6 @@ impl AuthError {
pub fn is_auth_failed(&self) -> bool {
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
}

pub fn user_timeout(elapsed: Elapsed) -> Self {
AuthErrorImpl::UserTimeout(elapsed).into()
}
}

impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
@@ -116,6 +107,7 @@ impl UserFacingError for AuthError {
match self.0.as_ref() {
Link(e) => e.to_string_client(),
GetAuthInfo(e) => e.to_string_client(),
WakeCompute(e) => e.to_string_client(),
Sasl(e) => e.to_string_client(),
AuthFailed(_) => self.to_string(),
BadAuthMethod(_) => self.to_string(),
@@ -124,26 +116,6 @@ impl UserFacingError for AuthError {
Io(_) => "Internal error".to_string(),
IpAddressNotAllowed => self.to_string(),
TooManyConnections => self.to_string(),
UserTimeout(_) => self.to_string(),
}
}
}

impl ReportableError for AuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
use AuthErrorImpl::*;
match self.0.as_ref() {
Link(e) => e.get_error_kind(),
GetAuthInfo(e) => e.get_error_kind(),
Sasl(e) => e.get_error_kind(),
AuthFailed(_) => crate::error::ErrorKind::User,
BadAuthMethod(_) => crate::error::ErrorKind::User,
MalformedPassword(_) => crate::error::ErrorKind::User,
MissingEndpointName => crate::error::ErrorKind::User,
Io(_) => crate::error::ErrorKind::ClientDisconnect,
IpAddressNotAllowed => crate::error::ErrorKind::User,
TooManyConnections => crate::error::ErrorKind::RateLimit,
UserTimeout(_) => crate::error::ErrorKind::User,
}
}
}

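The ReportableError impl removed above mapped each auth failure onto a coarse crate::error::ErrorKind for reporting. A hedged sketch of how a caller can bucket failures with that mapping (the label strings are made up; only the User, RateLimit, and ClientDisconnect variants are visible in this diff):

fn failure_bucket(err: &AuthError) -> &'static str {
    use crate::error::ErrorKind;
    match err.get_error_kind() {
        ErrorKind::User => "user",
        ErrorKind::RateLimit => "rate_limit",
        ErrorKind::ClientDisconnect => "client_disconnect",
        // Assumption: other variants exist beyond those shown in this diff.
        _ => "other",
    }
}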
@@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange;
use crate::cache::Cached;
use crate::console::errors::GetAuthInfoError;
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
use crate::console::{AuthSecret, NodeInfo};
use crate::console::AuthSecret;
use crate::context::RequestMonitoring;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::wake_compute::wake_compute;
use crate::proxy::NeonOptions;
use crate::stream::Stream;
use crate::{
@@ -26,6 +26,7 @@ use crate::{
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use futures::TryFutureExt;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -55,11 +56,11 @@ impl<T> std::ops::Deref for MaybeOwned<'_, T> {
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum BackendType<'a, T, D> {
pub enum BackendType<'a, T> {
/// Cloud API (V2).
Console(MaybeOwned<'a, ConsoleBackend>, T),
/// Authentication via a web browser.
Link(MaybeOwned<'a, url::ApiUrl>, D),
Link(MaybeOwned<'a, url::ApiUrl>),
}

pub trait TestBackend: Send + Sync + 'static {
@@ -67,10 +68,9 @@ pub trait TestBackend: Send + Sync + 'static {
fn get_allowed_ips_and_secret(
&self,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
}

impl std::fmt::Display for BackendType<'_, (), ()> {
impl std::fmt::Display for BackendType<'_, ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use BackendType::*;
match self {
@@ -85,50 +85,51 @@ impl std::fmt::Display for BackendType<'_, (), ()> {
#[cfg(test)]
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
}
}
}

impl<T, D> BackendType<'_, T, D> {
impl<T> BackendType<'_, T> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
pub fn as_ref(&self) -> BackendType<'_, &T> {
use BackendType::*;
match self {
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
Link(c) => Link(MaybeOwned::Borrowed(c)),
}
}
}

impl<'a, T, D> BackendType<'a, T, D> {
impl<'a, T> BackendType<'a, T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> {
use BackendType::*;
match self {
Console(c, x) => Console(c, f(x)),
Link(c, x) => Link(c, x),
}
}
}
impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
use BackendType::*;
match self {
Console(c, x) => x.map(|x| Console(c, x)),
Link(c, x) => Ok(Link(c, x)),
Link(c) => Link(c),
}
}
}

pub struct ComputeCredentials {
impl<'a, T, E> BackendType<'a, Result<T, E>> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub fn transpose(self) -> Result<BackendType<'a, T>, E> {
use BackendType::*;
match self {
Console(c, x) => x.map(|x| Console(c, x)),
Link(c) => Ok(Link(c)),
}
}
}

pub struct ComputeCredentials<T> {
pub info: ComputeUserInfo,
pub keys: ComputeCredentialKeys,
pub keys: T,
}

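A runnable mini-version of the map/transpose pattern above, with a two-variant stand-in for BackendType (names here are illustrative, not the proxy's types):

enum Backend<T> {
    Console(T),
    Link,
}

impl<T> Backend<T> {
    fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<R> {
        match self {
            Backend::Console(x) => Backend::Console(f(x)),
            Backend::Link => Backend::Link,
        }
    }
}

impl<T, E> Backend<Result<T, E>> {
    fn transpose(self) -> Result<Backend<T>, E> {
        match self {
            Backend::Console(x) => x.map(Backend::Console),
            Backend::Link => Ok(Backend::Link),
        }
    }
}

fn main() {
    // Parse inside the enum, then hoist the Result out, as the proxy does
    // with fallible credential parsing.
    let parsed = Backend::Console("42").map(|s| s.parse::<i32>()).transpose();
    assert!(matches!(parsed, Ok(Backend::Console(42))));
}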
#[derive(Debug, Clone)]
@@ -151,6 +152,7 @@ impl ComputeUserInfo {
}

pub enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
}
@@ -185,21 +187,19 @@ async fn auth_quirks(
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info, unauthenticated_password) = match user_info.try_into() {
Err(info) => {
let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
.await?;

ctx.set_endpoint_id(res.info.endpoint.clone());
tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));
let password = match res.keys {
ComputeCredentialKeys::Password(p) => p,
_ => unreachable!("password hack should return a password"),
};
(res.info, Some(password))

(res.info, Some(res.keys))
}
Ok(info) => (info, None),
};
@@ -253,7 +253,7 @@ async fn authenticate_with_secret(
unauthenticated_password: Option<Vec<u8>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
if let Some(password) = unauthenticated_password {
let auth_outcome = validate_password_and_exchange(&password, secret)?;
let keys = match auth_outcome {
@@ -275,22 +275,21 @@ async fn authenticate_with_secret(
// Perform cleartext auth if we're allowed to do that.
// Currently, we use it for websocket connections (latency).
if allow_cleartext {
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
return hacks::authenticate_cleartext(ctx, info, client, secret).await;
return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await;
}

// Finally, proceed with the main auth flow (SCRAM-based).
classic::authenticate(ctx, info, client, config, secret).await
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
}

impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<EndpointId> {
use BackendType::*;

match self {
Console(_, user_info) => user_info.endpoint_id.clone(),
Link(_, _) => Some("link".into()),
Link(_) => Some("link".into()),
}
}

@@ -300,7 +299,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {

match self {
Console(_, user_info) => &user_info.user,
Link(_, _) => "link",
Link(_) => "link",
}
}

@@ -312,7 +311,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
use BackendType::*;

let res = match self {
@@ -323,17 +322,33 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
"performing authentication using the console"
);

let credentials =
let compute_credentials =
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
BackendType::Console(api, credentials)

let mut num_retries = 0;
let mut node =
wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?;

ctx.set_project(node.aux.clone());

match compute_credentials.keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => node.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
};

(node, BackendType::Console(api, compute_credentials.info))
}
// NOTE: this auth backend doesn't use client credentials.
Link(url, _) => {
Link(url) => {
info!("performing link authentication");

let info = link::authenticate(ctx, &url, client).await?;
let node_info = link::authenticate(ctx, &url, client).await?;

BackendType::Link(url, info)
(
CachedNodeInfo::new_uncached(node_info),
BackendType::Link(url),
)
}
};

@@ -342,18 +357,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
}
}

impl BackendType<'_, ComputeUserInfo, &()> {
pub async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
Link(_, _) => Ok(Cached::new_uncached(None)),
}
}

impl BackendType<'_, ComputeUserInfo> {
pub async fn get_allowed_ips_and_secret(
&self,
ctx: &mut RequestMonitoring,
@@ -361,51 +365,21 @@ impl BackendType<'_, ComputeUserInfo, &()> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
}

#[async_trait::async_trait]
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
async fn wake_compute(
/// When applicable, wake the compute node, gaining its connection info in the process.
/// The link auth flow doesn't support this, so we return [`None`] in that case.
pub async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
use BackendType::*;

match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, info) => Ok(Cached::new_uncached(info.clone())),
}
}

fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
BackendType::Link(_, _) => None,
}
}
}

#[async_trait::async_trait]
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;

match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
}
}

fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
|
||||
BackendType::Link(_, _) => None,
|
||||
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
|
||||
Link(_) => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
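For context on the hunk above: one side folds compute wake-up into `authenticate` itself, returning the already-woken node next to a credential-free backend instead of a credential-carrying one. A minimal sketch of that shape, using stand-in types rather than the crate's real ones:

// Stand-in types; the real CachedNodeInfo/BackendType live in the proxy crate.
struct NodeInfo;
struct ComputeUserInfo;
enum Backend {
    Console(ComputeUserInfo),
    Link,
}

async fn authenticate_and_wake() -> Result<(NodeInfo, Backend), &'static str> {
    // 1. validate the client's credentials (console flow) or run the link flow;
    // 2. wake the compute node and apply the credential keys to its config;
    // 3. return both, so the caller never needs a separate wake_compute call.
    Ok((NodeInfo, Backend::Console(ComputeUserInfo)))
}
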
@@ -4,7 +4,7 @@ use crate::{
compute,
config::AuthenticationConfig,
console::AuthSecret,
context::RequestMonitoring,
metrics::LatencyTimer,
sasl,
stream::{PqStream, Stream},
};
@@ -12,12 +12,12 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};

pub(super) async fn authenticate(
ctx: &mut RequestMonitoring,
creds: ComputeUserInfo,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
@@ -27,11 +27,13 @@ pub(super) async fn authenticate(
}
AuthSecret::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret, &mut *ctx);
let scram = auth::Scram(&secret);

let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
async {
// pause the timer while we communicate with the client
let _paused = latency_timer.pause();

flow.begin(scram).await.map_err(|error| {
warn!(?error, "error sending scram acknowledgement");
@@ -43,9 +45,9 @@ pub(super) async fn authenticate(
}
)
.await
.map_err(|e| {
.map_err(|error| {
warn!("error processing scram messages error = authentication timed out, execution time exeeded {} seconds", config.scram_protocol_timeout.as_secs());
|
||||
auth::AuthError::user_timeout(e)
auth::io::Error::new(auth::io::ErrorKind::TimedOut, error)
})??;

let client_key = match auth_outcome {

@@ -4,7 +4,7 @@ use super::{
use crate::{
auth::{self, AuthFlow},
console::AuthSecret,
context::RequestMonitoring,
metrics::LatencyTimer,
sasl,
stream::{self, Stream},
};
@@ -16,16 +16,15 @@ use tracing::{info, warn};
/// These properties are beneficial for serverless JS workers, so we
/// use this mechanism for websocket connections.
pub async fn authenticate_cleartext(
ctx: &mut RequestMonitoring,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
latency_timer: &mut LatencyTimer,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
warn!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);

// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause();
let _paused = latency_timer.pause();

let auth_outcome = AuthFlow::new(client)
.begin(auth::CleartextPassword(secret))
@@ -48,15 +47,14 @@ pub async fn authenticate_cleartext(
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
/// and passwords are not yet validated (we don't know how to validate them!)
pub async fn password_hack_no_authentication(
ctx: &mut RequestMonitoring,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
) -> auth::Result<ComputeCredentials> {
latency_timer: &mut LatencyTimer,
) -> auth::Result<ComputeCredentials<Vec<u8>>> {
warn!("project not specified, resorting to the password hack auth flow");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);

// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause();
let _paused = latency_timer.pause();

let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
@@ -73,6 +71,6 @@ pub async fn password_hack_no_authentication(
options: info.options,
endpoint: payload.endpoint,
},
keys: ComputeCredentialKeys::Password(payload.password),
keys: payload.password,
})
}

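For readers unfamiliar with the "password hack" above: it smuggles the endpoint name inside the password field for clients that cannot send SNI or connection options. A sketch of the split, assuming a `project=<endpoint>;<password>` layout; the actual parser lives in `auth::password_hack` and may differ in detail:

// Illustrative only: separate the endpoint name from the real password.
fn split_password_hack(raw: &[u8]) -> Option<(String, Vec<u8>)> {
    let s = std::str::from_utf8(raw).ok()?;
    let rest = s.strip_prefix("project=")?;
    let (endpoint, password) = rest.split_once(';')?;
    Some((endpoint.to_string(), password.as_bytes().to_vec()))
}
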
@@ -2,7 +2,7 @@ use crate::{
auth, compute,
console::{self, provider::NodeInfo},
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
error::UserFacingError,
stream::PqStream,
waiters,
};
@@ -14,6 +14,10 @@ use tracing::{info, info_span};

#[derive(Debug, Error)]
pub enum LinkAuthError {
/// Authentication error reported by the console.
#[error("Authentication failed: {0}")]
AuthFailed(String),

#[error(transparent)]
WaiterRegister(#[from] waiters::RegisterError),

@@ -26,16 +30,10 @@ pub enum LinkAuthError {

impl UserFacingError for LinkAuthError {
fn to_string_client(&self) -> String {
"Internal error".to_string()
}
}

impl ReportableError for LinkAuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
use LinkAuthError::*;
match self {
LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service,
LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service,
LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
AuthFailed(_) => self.to_string(),
_ => "Internal error".to_string(),
}
}
}
@@ -61,8 +59,6 @@ pub(super) async fn authenticate(
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
ctx.set_auth_method(crate::context::AuthMethod::Web);

// registering waiter can fail if we get unlucky with rng.
// just try again.
let (psql_session_id, waiter) = loop {

@@ -1,12 +1,8 @@
//! User credentials used in authentication.

use crate::{
auth::password_hack::parse_endpoint_param,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI,
proxy::NeonOptions,
serverless::SERVERLESS_DRIVER_SNI,
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI,
EndpointId, RoleName,
};
use itertools::Itertools;
@@ -43,12 +39,6 @@ pub enum ComputeUserInfoParseError {

impl UserFacingError for ComputeUserInfoParseError {}

impl ReportableError for ComputeUserInfoParseError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}

/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -99,9 +89,6 @@ impl ComputeUserInfoMaybeEndpoint {
// record the values if we have them
ctx.set_application(params.get("application_name").map(SmolStr::from));
ctx.set_user(user.clone());
if let Some(dbname) = params.get("database") {
ctx.set_dbname(dbname.into());
}

// Project name might be passed via PG's command-line options.
let endpoint_option = params

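The comment in `link::authenticate` above ("registering waiter can fail if we get unlucky with rng, just try again") describes a simple retry loop. A sketch of its shape; `try_register` and the id type stand in for the crate's waiters API:

// Illustrative only: keep drawing random session ids until one registers.
fn try_register(id: u64) -> Result<(), ()> {
    // stand-in: fails when the id is already taken (an rng collision)
    if id % 2 == 0 { Ok(()) } else { Err(()) }
}

fn register_waiter(mut next_random: impl FnMut() -> u64) -> u64 {
    loop {
        let psql_session_id = next_random();
        if try_register(psql_session_id).is_ok() {
            return psql_session_id;
        }
        // collision: draw another id and try again
    }
}
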
@@ -4,11 +4,9 @@ use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
use crate::{
config::TlsServerEndPoint,
console::AuthSecret,
context::RequestMonitoring,
sasl, scram,
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
@@ -25,7 +23,7 @@ pub trait AuthMethod {
pub struct Begin;

/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
pub struct Scram<'a>(pub &'a scram::ServerSecret);

impl AuthMethod for Scram<'_> {
#[inline(always)]
@@ -140,11 +138,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret, ctx) = self.state;

// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause();

// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
@@ -155,15 +148,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
return Err(super::AuthError::bad_auth_method(sasl.method));
}

match sasl.method {
SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => {
ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
}
_ => {}
}
info!("client chooses {}", sasl.method);

let secret = self.state.0;
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(
secret,
@@ -180,7 +167,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
}

pub(crate) fn validate_password_and_exchange(
pub(super) fn validate_password_and_exchange(
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {

@@ -240,9 +240,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
?unexpected,
"unexpected startup packet, rejecting connection"
);
stream
.throw_error_str(ERR_INSECURE_CONNECTION, proxy::error::ErrorKind::User)
.await?
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?
}
}
}
@@ -274,10 +272,5 @@ async fn handle_client(
let client = tokio::net::TcpStream::connect(destination).await?;

let metrics_aux: MetricsAuxInfo = Default::default();

// doesn't yet matter as pg-sni-router doesn't report analytics logs
ctx.set_success();
ctx.log();

proxy::proxy::passthrough::proxy_pass(tls_stream, client, metrics_aux).await
proxy::proxy::passthrough::proxy_pass(ctx, tls_stream, client, metrics_aux).await
}

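The `AuthFlow` touched above uses the typestate pattern: the state type parameter (`Begin`, `Scram`, `CleartextPassword`) decides which methods exist, so calling `authenticate` before `begin` simply does not compile. A stripped-down sketch with illustrative types:

struct AuthFlow<State> {
    state: State,
}
struct Begin;
struct Scram;

impl AuthFlow<Begin> {
    // moving to the Scram state consumes the Begin-state flow
    fn begin(self, method: Scram) -> AuthFlow<Scram> {
        AuthFlow { state: method }
    }
}

impl AuthFlow<Scram> {
    fn authenticate(self) {
        // only reachable after begin(): enforced by the type system
    }
}
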
@@ -1,8 +1,6 @@
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
@@ -14,7 +12,6 @@ use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::RateLimiterConfig;
use proxy::redis::notifications;
use proxy::redis::publisher::RedisPublisherClient;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;

@@ -25,7 +22,6 @@ use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::info;
@@ -92,9 +88,6 @@ struct ProxyCliArgs {
/// path to directory with TLS certificates for client postgres connections
#[clap(long)]
certs_dir: Option<String>,
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
@@ -133,9 +126,6 @@ struct ProxyCliArgs {
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
endpoint_rps_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
#[clap(long, default_value_t = 100)]
initial_limit: usize,
@@ -175,10 +165,6 @@ struct SqlOverHttpArgs {
#[clap(long, default_value_t = 20)]
sql_over_http_pool_max_conns_per_endpoint: usize,

/// How many connections to pool for each endpoint. Excess connections are discarded
#[clap(long, default_value_t = 20000)]
sql_over_http_pool_max_total_conns: usize,

/// How long pooled connections should remain idle for before closing
#[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
sql_over_http_idle_timeout: tokio::time::Duration,
@@ -232,19 +218,6 @@ async fn main() -> anyhow::Result<()> {
let cancellation_token = CancellationToken::new();

let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
let cancel_map = CancelMap::default();
let redis_publisher = match &args.redis_notifications {
Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
url,
args.region.clone(),
&config.redis_rps_limit,
)?))),
None => None,
};
let cancellation_handler = Arc::new(CancellationHandler::new(
cancel_map.clone(),
redis_publisher,
));

// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
@@ -254,7 +227,6 @@ async fn main() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));

// TODO: rename the argument to something like serverless.
@@ -269,7 +241,6 @@ async fn main() -> anyhow::Result<()> {
serverless_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));
}

@@ -293,12 +264,7 @@ async fn main() -> anyhow::Result<()> {
let cache = api.caches.project_info.clone();
if let Some(url) = args.redis_notifications {
info!("Starting redis notifications listener ({url})");
maintenance_tasks.spawn(notifications::task_main(
url.to_owned(),
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
@@ -410,7 +376,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
}
AuthBackend::Link => {
let url = args.uri.parse()?;
auth::BackendType::Link(MaybeOwned::Owned(url), ())
auth::BackendType::Link(MaybeOwned::Owned(url))
}
};
let http_config = HttpConfig {
@@ -421,7 +387,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
pool_shards: args.sql_over_http.sql_over_http_pool_shards,
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
},
};
let authentication_config = AuthenticationConfig {
@@ -430,8 +395,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {

let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
let mut redis_rps_limit = args.redis_rps_limit.clone();
RateBucketInfo::validate(&mut redis_rps_limit)?;

let config = Box::leak(Box::new(ProxyConfig {
tls_config,
@@ -443,8 +406,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
require_client_ip: args.require_client_ip,
disable_ip_check_for_http: args.disable_ip_check_for_http,
endpoint_rps_limit,
redis_rps_limit,
handshake_timeout: args.handshake_timeout,
// TODO: add this argument
region: args.region.clone(),
}));

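One detail worth calling out in `build_config` above: the config is `Box::leak`ed so that every spawned task can hold a plain `&'static ProxyConfig` instead of an `Arc`. A tiny sketch of the pattern with a stand-in struct:

// The config is built once at startup and intentionally never freed;
// the leak buys a 'static lifetime for the whole process.
struct Config {
    region: String,
}

fn leak_config(region: String) -> &'static Config {
    Box::leak(Box::new(Config { region }))
}
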
@@ -1,86 +1,25 @@
use async_trait::async_trait;
use anyhow::Context;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use uuid::Uuid;

use crate::{
error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
redis::publisher::RedisPublisherClient,
};

pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;

/// Enables serving `CancelRequest`s.
///
/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler {
map: CancelMap,
redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
}
#[derive(Default)]
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);

#[derive(Debug, Error)]
pub enum CancelError {
#[error("{0}")]
IO(#[from] std::io::Error),
#[error("{0}")]
Postgres(#[from] tokio_postgres::Error),
}

impl ReportableError for CancelError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
CancelError::IO(_) => crate::error::ErrorKind::Compute,
CancelError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
}
}
}

impl CancellationHandler {
pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
Self { map, redis_client }
}
impl CancelMap {
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
) -> Result<(), CancelError> {
let from = "from_client";
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
if let Some(redis_client) = &self.redis_client {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
info!("publishing cancellation key to Redis");
match redis_client.lock().await.try_publish(key, session_id).await {
Ok(()) => {
info!("cancellation key successfuly published to Redis");
|
||||
}
Err(e) => {
tracing::error!("failed to publish a message: {e}");
return Err(CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
)));
}
}
}
return Ok(());
};
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
let cancel_closure = self
.0
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("query cancellation key not found: {key}"))?;

info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
}
@@ -97,7 +36,7 @@ impl CancellationHandler {

// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.map.entry(key) {
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => continue,
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
@@ -109,46 +48,18 @@ impl CancellationHandler {
info!("registered new query cancellation key {key}");
Session {
key,
cancellation_handler: self,
cancel_map: self,
}
}

#[cfg(test)]
fn contains(&self, session: &Session) -> bool {
self.map.contains_key(&session.key)
self.0.contains_key(&session.key)
}

#[cfg(test)]
fn is_empty(&self) -> bool {
self.map.is_empty()
}
}

#[async_trait]
pub trait NotificationsCancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
}

#[async_trait]
impl NotificationsCancellationHandler for CancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
let from = "from_redis";
let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
match cancel_closure {
Some(cancel_closure) => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
cancel_closure.try_cancel_query().await
}
None => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
tracing::warn!("query cancellation key not found: {key}");
Ok(())
}
}
self.0.is_empty()
}
}

@@ -170,7 +81,7 @@ impl CancelClosure {
}

/// Cancels the query running on user's compute node.
async fn try_cancel_query(self) -> Result<(), CancelError> {
pub async fn try_cancel_query(self) -> anyhow::Result<()> {
let socket = TcpStream::connect(self.socket_addr).await?;
self.cancel_token.cancel_query_raw(socket, NoTls).await?;

@@ -183,7 +94,7 @@ pub struct Session {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancellation_handler: Arc<CancellationHandler>,
cancel_map: Arc<CancelMap>,
}

impl Session {
@@ -191,9 +102,7 @@ impl Session {
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancellation_handler
.map
.insert(self.key, Some(cancel_closure));
self.cancel_map.0.insert(self.key, Some(cancel_closure));

self.key
}
@@ -201,7 +110,7 @@ impl Session {

impl Drop for Session {
fn drop(&mut self) {
self.cancellation_handler.map.remove(&self.key);
self.cancel_map.0.remove(&self.key);
info!("dropped query cancellation key {}", &self.key);
}
}
@@ -212,16 +121,13 @@ mod tests {

#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler {
map: CancelMap::default(),
redis_client: None,
});
let cancel_map: Arc<CancelMap> = Default::default();

let session = cancellation_handler.clone().get_session();
assert!(cancellation_handler.contains(&session));
let session = cancel_map.clone().get_session();
assert!(cancel_map.contains(&session));
drop(session);
// Check that the session has been dropped.
assert!(cancellation_handler.is_empty());
assert!(cancel_map.is_empty());

Ok(())
}

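The "Random key collisions" comment above is backed by the dashmap entry API: a key is only claimed if its slot is vacant, otherwise a new key is drawn. A sketch of that loop, with the key source reduced to a caller-provided closure:

// Mirrors the collision-safe registration in get_session above.
use dashmap::DashMap;

fn register_key(map: &DashMap<u64, Option<()>>, mut next_key: impl FnMut() -> u64) -> u64 {
    loop {
        let key = next_key();
        match map.entry(key) {
            // someone already owns this key: draw a new one
            dashmap::mapref::entry::Entry::Occupied(_) => continue,
            dashmap::mapref::entry::Entry::Vacant(e) => {
                e.insert(None); // reserve now; the CancelClosure arrives later
                return key;
            }
        }
    }
}
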
@@ -1,10 +1,6 @@
use crate::{
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::{errors::WakeComputeError, messages::MetricsAuxInfo},
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_DB_CONNECTIONS_GAUGE,
auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
proxy::neon_option,
};
use futures::{FutureExt, TryFutureExt};
@@ -62,20 +58,6 @@ impl UserFacingError for ConnectionError {
}
}

impl ReportableError for ConnectionError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
}
}
}

/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;

@@ -93,7 +75,7 @@ impl ConnCfg {
}

/// Reuse password or auth keys from the other config.
pub fn reuse_password(&mut self, other: Self) {
pub fn reuse_password(&mut self, other: &Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
@@ -253,8 +235,6 @@ pub struct PostgresConnection {
pub params: std::collections::HashMap<String, String>,
/// Query cancellation token.
pub cancel_closure: CancelClosure,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,

_guage: IntCounterPairGuard,
}
@@ -265,7 +245,6 @@ impl ConnCfg {
&self,
ctx: &mut RequestMonitoring,
allow_self_signed_compute: bool,
aux: MetricsAuxInfo,
timeout: Duration,
) -> Result<PostgresConnection, ConnectionError> {
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
@@ -300,7 +279,6 @@ impl ConnCfg {
stream,
params,
cancel_closure,
aux,
_guage: NUM_DB_CONNECTIONS_GAUGE
.with_label_values(&[ctx.protocol])
.guard(),

@@ -13,7 +13,7 @@ use x509_parser::oid_registry;

pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::BackendType<'static, (), ()>,
pub auth_backend: auth::BackendType<'static, ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,
@@ -21,9 +21,7 @@ pub struct ProxyConfig {
pub require_client_ip: bool,
pub disable_ip_check_for_http: bool,
pub endpoint_rps_limit: Vec<RateBucketInfo>,
pub redis_rps_limit: Vec<RateBucketInfo>,
pub region: String,
pub handshake_timeout: Duration,
}

#[derive(Debug)]

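The `reuse_password` signature change above is an ownership question: taking `other: Self` consumes the donor config, while `other: &Self` only borrows it, at the cost of a clone. A sketch of the borrowing variant with a stand-in `ConnCfg` (the real one wraps `tokio_postgres::Config`):

#[derive(Default)]
struct ConnCfg {
    password: Option<Vec<u8>>,
}

impl ConnCfg {
    fn reuse_password(&mut self, other: &Self) {
        // cloning is the price of borrowing: the donor keeps its password
        if let Some(p) = &other.password {
            self.password = Some(p.clone());
        }
    }
}
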
@@ -4,10 +4,7 @@ pub mod neon;

use super::messages::MetricsAuxInfo;
use crate::{
auth::{
backend::{ComputeCredentialKeys, ComputeUserInfo},
IpPattern,
},
auth::{backend::ComputeUserInfo, IpPattern},
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
compute,
config::{CacheOptions, ProjectInfoCacheOptions},
@@ -23,7 +20,7 @@ use tracing::info;

pub mod errors {
use crate::{
error::{io_error, ReportableError, UserFacingError},
error::{io_error, UserFacingError},
http,
proxy::retry::ShouldRetry,
};
@@ -84,15 +81,6 @@ pub mod errors {
}
}

impl ReportableError for ApiError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
}
}
}

impl ShouldRetry for ApiError {
fn could_retry(&self) -> bool {
match self {
@@ -162,16 +150,6 @@ pub mod errors {
}
}
}

impl ReportableError for GetAuthInfoError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
}
}
}

#[derive(Debug, Error)]
pub enum WakeComputeError {
#[error("Console responded with a malformed compute address: {0}")]
@@ -216,16 +194,6 @@ pub mod errors {
}
}
}

impl ReportableError for WakeComputeError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
WakeComputeError::ApiError(e) => e.get_error_kind(),
WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
}
}
}
}

/// Auth secret which is managed by the cloud.
@@ -264,34 +232,6 @@ pub struct NodeInfo {
pub allow_self_signed_compute: bool,
}

impl NodeInfo {
pub async fn connect(
&self,
ctx: &mut RequestMonitoring,
timeout: Duration,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config
.connect(
ctx,
self.allow_self_signed_compute,
self.aux.clone(),
timeout,
)
.await
}
pub fn reuse_settings(&mut self, other: Self) {
self.allow_self_signed_compute = other.allow_self_signed_compute;
self.config.reuse_password(other.config);
}

pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
match keys {
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
};
}
}

pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;

@@ -176,7 +176,9 @@ impl super::Api for Api {
_ctx: &mut RequestMonitoring,
_user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
self.do_wake_compute().map_ok(Cached::new_uncached).await
self.do_wake_compute()
.map_ok(CachedNodeInfo::new_uncached)
.await
}
}


@@ -188,7 +188,6 @@ impl super::Api for Api {
ep,
Arc::new(auth_info.allowed_ips),
);
ctx.set_project_id(project_id);
}
// When we just got a secret, we don't need to invalidate it.
Ok(Cached::new_uncached(auth_info.secret))
@@ -222,7 +221,6 @@ impl super::Api for Api {
self.caches
.project_info
.insert_allowed_ips(&project_id, ep, allowed_ips.clone());
ctx.set_project_id(project_id);
}
Ok((
Cached::new_uncached(allowed_ips),

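A note on the `Cached` / `new_uncached` pairing that recurs above: values can come straight from the console (uncached) or from the `TimedLru`, and the wrapper remembers which, so a later invalidation is a no-op for uncached values. A simplified stand-in, assuming that reading of the type (the real `Cached` carries a reference to its owning cache):

struct Cached<V> {
    value: V,
    cached: bool, // the real type holds a cache token for invalidation instead
}

impl<V> Cached<V> {
    fn new_uncached(value: V) -> Self {
        Cached { value, cached: false }
    }

    fn invalidate(&mut self) {
        if self.cached {
            // only cache-backed values have an entry to evict
        }
    }
}
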
@@ -8,10 +8,8 @@ use tokio::sync::mpsc;
use uuid::Uuid;

use crate::{
console::messages::MetricsAuxInfo,
error::ErrorKind,
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
BranchId, DbName, EndpointId, ProjectId, RoleName,
console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer, BranchId,
EndpointId, ProjectId, RoleName,
};

pub mod parquet;
@@ -34,11 +32,9 @@ pub struct RequestMonitoring {
project: Option<ProjectId>,
branch: Option<BranchId>,
endpoint_id: Option<EndpointId>,
dbname: Option<DbName>,
user: Option<RoleName>,
application: Option<SmolStr>,
error_kind: Option<ErrorKind>,
pub(crate) auth_method: Option<AuthMethod>,
success: bool,

// extra
@@ -47,15 +43,6 @@ pub struct RequestMonitoring {
pub latency_timer: LatencyTimer,
}

#[derive(Clone, Debug)]
pub enum AuthMethod {
// aka link aka passwordless
Web,
ScramSha256,
ScramSha256Plus,
Cleartext,
}

impl RequestMonitoring {
pub fn new(
session_id: Uuid,
@@ -73,11 +60,9 @@ impl RequestMonitoring {
project: None,
branch: None,
endpoint_id: None,
dbname: None,
user: None,
application: None,
error_kind: None,
auth_method: None,
success: false,

sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
@@ -104,10 +89,6 @@ impl RequestMonitoring {
self.project = Some(x.project_id);
}

pub fn set_project_id(&mut self, project_id: ProjectId) {
self.project = Some(project_id);
}

pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
crate::metrics::CONNECTING_ENDPOINTS
.with_label_values(&[self.protocol])
@@ -119,30 +100,10 @@ impl RequestMonitoring {
self.application = app.or_else(|| self.application.clone());
}

pub fn set_dbname(&mut self, dbname: DbName) {
self.dbname = Some(dbname);
}

pub fn set_user(&mut self, user: RoleName) {
self.user = Some(user);
}

pub fn set_auth_method(&mut self, auth_method: AuthMethod) {
self.auth_method = Some(auth_method);
}

pub fn set_error_kind(&mut self, kind: ErrorKind) {
ERROR_BY_KIND
.with_label_values(&[kind.to_metric_label()])
.inc();
if let Some(ep) = &self.endpoint_id {
ENDPOINT_ERRORS_BY_KIND
.with_label_values(&[kind.to_metric_label()])
.measure(ep);
}
self.error_kind = Some(kind);
}

pub fn set_success(&mut self) {
self.success = true;
}

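Reading the `RequestMonitoring` hunks together, the intended lifecycle is: create one context per connection, fill fields in as they become known, record the outcome, then emit a single log record. A compressed sketch with the method set abbreviated:

// Stand-in for the crate's RequestMonitoring; only the shape is shown.
#[derive(Default)]
struct RequestMonitoring {
    user: Option<String>,
    success: bool,
}

impl RequestMonitoring {
    fn set_user(&mut self, user: String) {
        self.user = Some(user);
    }
    fn set_success(&mut self) {
        self.success = true;
    }
    fn log(self) {
        // consumes self: the finished record is handed to the parquet writer
    }
}

fn lifecycle() {
    let mut ctx = RequestMonitoring::default();
    ctx.set_user("alice".into()); // as soon as the startup packet arrives
    ctx.set_success();            // once proxy_pass is reached
    ctx.log();                    // exactly one record per connection
}
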
@@ -84,10 +84,8 @@ struct RequestData {
username: Option<String>,
application_name: Option<String>,
endpoint_id: Option<String>,
database: Option<String>,
project: Option<String>,
branch: Option<String>,
auth_method: Option<&'static str>,
error: Option<&'static str>,
/// Success is counted if we form a HTTP response with sql rows inside
/// Or if we make it to proxy_pass
@@ -106,18 +104,11 @@ impl From<RequestMonitoring> for RequestData {
username: value.user.as_deref().map(String::from),
application_name: value.application.as_deref().map(String::from),
endpoint_id: value.endpoint_id.as_deref().map(String::from),
database: value.dbname.as_deref().map(String::from),
project: value.project.as_deref().map(String::from),
branch: value.branch.as_deref().map(String::from),
auth_method: value.auth_method.as_ref().map(|x| match x {
super::AuthMethod::Web => "web",
super::AuthMethod::ScramSha256 => "scram_sha_256",
super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus",
super::AuthMethod::Cleartext => "cleartext",
}),
protocol: value.protocol,
region: value.region,
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
error: value.error_kind.as_ref().map(|e| e.to_str()),
success: value.success,
duration_us: SystemTime::from(value.first_packet)
.elapsed()
@@ -440,10 +431,8 @@ mod tests {
application_name: Some("test".to_owned()),
username: Some(hex::encode(rng.gen::<[u8; 4]>())),
endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
database: Some(hex::encode(rng.gen::<[u8; 16]>())),
project: Some(hex::encode(rng.gen::<[u8; 16]>())),
branch: Some(hex::encode(rng.gen::<[u8; 16]>())),
auth_method: None,
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
region: "us-east-1",
error: None,
@@ -516,15 +505,15 @@ mod tests {
assert_eq!(
file_stats,
[
(1313727, 3, 6000),
(1313720, 3, 6000),
(1313780, 3, 6000),
(1313737, 3, 6000),
(1313867, 3, 6000),
(1313709, 3, 6000),
(1313501, 3, 6000),
(1313737, 3, 6000),
(438118, 1, 2000)
(1087635, 3, 6000),
(1087288, 3, 6000),
(1087444, 3, 6000),
(1087572, 3, 6000),
(1087468, 3, 6000),
(1087500, 3, 6000),
(1087533, 3, 6000),
(1087566, 3, 6000),
(362671, 1, 2000)
],
);

@@ -554,11 +543,11 @@ mod tests {
assert_eq!(
file_stats,
[
(1219459, 5, 10000),
(1225609, 5, 10000),
(1227403, 5, 10000),
(1226765, 5, 10000),
(1218043, 5, 10000)
(1028637, 5, 10000),
(1031969, 5, 10000),
(1019900, 5, 10000),
(1020365, 5, 10000),
(1025010, 5, 10000)
],
);

@@ -590,11 +579,11 @@ mod tests {
assert_eq!(
file_stats,
[
(1205106, 5, 10000),
(1204837, 5, 10000),
(1205130, 5, 10000),
(1205118, 5, 10000),
(1205373, 5, 10000)
(1210770, 6, 12000),
(1211036, 6, 12000),
(1210990, 6, 12000),
(1210861, 6, 12000),
(202073, 1, 2000)
],
);

@@ -619,15 +608,15 @@ mod tests {
assert_eq!(
file_stats,
[
(1313727, 3, 6000),
(1313720, 3, 6000),
(1313780, 3, 6000),
(1313737, 3, 6000),
(1313867, 3, 6000),
(1313709, 3, 6000),
(1313501, 3, 6000),
(1313737, 3, 6000),
(438118, 1, 2000)
(1087635, 3, 6000),
(1087288, 3, 6000),
(1087444, 3, 6000),
(1087572, 3, 6000),
(1087468, 3, 6000),
(1087500, 3, 6000),
(1087533, 3, 6000),
(1087566, 3, 6000),
(362671, 1, 2000)
],
);

@@ -664,7 +653,7 @@ mod tests {
// files are smaller than the size threshold, but they took too long to fill so were flushed early
assert_eq!(
file_stats,
[(658383, 2, 3001), (658097, 2, 3000), (657893, 2, 2999)],
[(545264, 2, 3001), (545025, 2, 3000), (544857, 2, 2999)],
);

tmpdir.close().unwrap();

@@ -17,7 +17,7 @@ pub fn log_error<E: fmt::Display>(e: E) -> E {
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
/// is way too convenient and tends to proliferate all across the codebase,
/// ultimately leading to accidental leaks of sensitive data.
pub trait UserFacingError: ReportableError {
pub trait UserFacingError: fmt::Display {
/// Format the error for client, stripping all sensitive info.
///
/// Although this might be a no-op for many types, it's highly
@@ -29,13 +29,13 @@ pub trait UserFacingError: ReportableError {
}
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[derive(Clone)]
pub enum ErrorKind {
/// Wrong password, unknown endpoint, protocol violation, etc...
User,

/// Network error between user and proxy. Not necessarily user error
ClientDisconnect,
Disconnect,

/// Proxy self-imposed rate limits
RateLimit,
@@ -46,9 +46,6 @@ pub enum ErrorKind {
/// Error communicating with control plane
ControlPlane,

/// Postgres error
Postgres,

/// Error communicating with compute
Compute,
}
@@ -57,46 +54,11 @@ impl ErrorKind {
pub fn to_str(&self) -> &'static str {
match self {
ErrorKind::User => "request failed due to user error",
ErrorKind::ClientDisconnect => "client disconnected",
ErrorKind::Disconnect => "client disconnected",
ErrorKind::RateLimit => "request cancelled due to rate limit",
ErrorKind::Service => "internal service error",
ErrorKind::ControlPlane => "non-retryable control plane error",
ErrorKind::Postgres => "postgres error",
ErrorKind::Compute => {
"non-retryable compute connection error (or exhausted retry capacity)"
}
}
}

pub fn to_metric_label(&self) -> &'static str {
match self {
ErrorKind::User => "user",
ErrorKind::ClientDisconnect => "clientdisconnect",
ErrorKind::RateLimit => "ratelimit",
ErrorKind::Service => "service",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "postgres",
ErrorKind::Compute => "compute",
}
}
}

pub trait ReportableError: fmt::Display + Send + 'static {
fn get_error_kind(&self) -> ErrorKind;
}

impl ReportableError for tokio::time::error::Elapsed {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::RateLimit
}
}

impl ReportableError for tokio_postgres::error::Error {
fn get_error_kind(&self) -> ErrorKind {
if self.as_db_error().is_some() {
ErrorKind::Postgres
} else {
ErrorKind::Compute
ErrorKind::Compute => "non-retryable compute error (or exhausted retry capacity)",
}
}
}

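One side of this hunk introduces `ReportableError` so every error the proxy surfaces can be classified for metrics. A self-contained sketch of how a custom error type plugs into that machinery, with the trait and kinds reduced to their essentials:

use std::fmt;

#[derive(Debug)]
enum ErrorKind { User, Compute }

trait ReportableError: fmt::Display + Send + 'static {
    fn get_error_kind(&self) -> ErrorKind;
}

#[derive(Debug)]
struct MyError;

impl fmt::Display for MyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "compute connection failed")
    }
}

impl ReportableError for MyError {
    // every error maps onto exactly one metric label
    fn get_error_kind(&self) -> ErrorKind {
        ErrorKind::Compute
    }
}
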
@@ -1,10 +1,8 @@
use ::metrics::{
exponential_buckets, register_histogram, register_histogram_vec, register_hll_vec,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge,
register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec,
IntCounterVec, IntGauge, IntGaugeVec,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge_vec, Histogram,
HistogramVec, HyperLogLogVec, IntCounterPairVec, IntCounterVec, IntGaugeVec,
};
use metrics::{register_int_counter_pair, IntCounterPair};

use once_cell::sync::Lazy;
use tokio::time;
@@ -114,53 +112,6 @@ pub static ALLOWED_IPS_NUMBER: Lazy<Histogram> = Lazy::new(|| {
.unwrap()
});

pub static HTTP_CONTENT_LENGTH: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_http_conn_content_length_bytes",
"Time it took for proxy to establish a connection to the compute endpoint",
|
||||
// largest bucket = 3^16 * 0.05ms = 2.15s
|
||||
exponential_buckets(8.0, 2.0, 20).unwrap()
)
.unwrap()
});

pub static GC_LATENCY: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_http_pool_reclaimation_lag_seconds",
"Time it takes to reclaim unused connection pools",
// 1us -> 65ms
exponential_buckets(1e-6, 2.0, 16).unwrap(),
)
.unwrap()
});

pub static ENDPOINT_POOLS: Lazy<IntCounterPair> = Lazy::new(|| {
register_int_counter_pair!(
"proxy_http_pool_endpoints_registered_total",
"Number of endpoints we have registered pools for",
"proxy_http_pool_endpoints_unregistered_total",
"Number of endpoints we have unregistered pools for",
)
.unwrap()
});

pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy<IntGauge> = Lazy::new(|| {
register_int_gauge!(
"proxy_http_pool_opened_connections",
"Number of opened connections to a database.",
)
.unwrap()
});

pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_cancellation_requests_total",
"Number of cancellation requests (per found/not_found).",
&["source", "kind"],
)
.unwrap()
});

#[derive(Clone)]
pub struct LatencyTimer {
// time since the stopwatch was started
@@ -209,9 +160,8 @@ impl LatencyTimer {

pub fn success(&mut self) {
// stop the stopwatch and record the time that we have accumulated
if let Some(start) = self.start.take() {
self.accumulated += start.elapsed();
}
let start = self.start.take().expect("latency timer should be started");
self.accumulated += start.elapsed();

// success
self.outcome = "success";
@@ -284,22 +234,3 @@ pub static CONNECTING_ENDPOINTS: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
)
.unwrap()
});

pub static ERROR_BY_KIND: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_errors_total",
"Number of errors by a given classification",
&["type"],
)
.unwrap()
});

pub static ENDPOINT_ERRORS_BY_KIND: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
register_hll_vec!(
32,
"proxy_endpoints_affected_by_errors",
"Number of endpoints affected by errors of a given classification",
&["type"],
)
.unwrap()
});

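The `pause()` calls sprinkled through the auth code rely on an RAII guard: the timer stops accumulating when the guard is created and resumes when it drops, so time spent waiting on the client is excluded. A sketch of that mechanism, simplified from the real `LatencyTimer`:

use std::time::{Duration, Instant};

struct LatencyTimer {
    start: Option<Instant>, // None while paused
    accumulated: Duration,
}

struct PauseGuard<'a>(&'a mut LatencyTimer);

impl LatencyTimer {
    fn pause(&mut self) -> PauseGuard<'_> {
        // bank the time measured so far, then stop the stopwatch
        if let Some(start) = self.start.take() {
            self.accumulated += start.elapsed();
        }
        PauseGuard(self)
    }
}

impl Drop for PauseGuard<'_> {
    fn drop(&mut self) {
        // resume timing when the guard goes out of scope
        self.0.start = Some(Instant::now());
    }
}
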
@@ -2,7 +2,6 @@
|
||||
mod tests;
|
||||
|
||||
pub mod connect_compute;
|
||||
mod copy_bidirectional;
|
||||
pub mod handshake;
|
||||
pub mod passthrough;
|
||||
pub mod retry;
|
||||
@@ -10,14 +9,13 @@ pub mod wake_compute;
|
||||
|
||||
use crate::{
|
||||
auth,
|
||||
cancellation::{self, CancellationHandler},
|
||||
cancellation::{self, CancelMap},
|
||||
compute,
|
||||
config::{ProxyConfig, TlsConfig},
|
||||
context::RequestMonitoring,
|
||||
error::ReportableError,
|
||||
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
|
||||
protocol2::WithClientIp,
|
||||
proxy::handshake::{handshake, HandshakeData},
|
||||
proxy::{handshake::handshake, passthrough::proxy_pass},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
stream::{PqStream, Stream},
|
||||
EndpointCacheKey,
|
||||
@@ -30,17 +28,14 @@ use pq_proto::{BeMessage as Be, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use smol_str::{format_smolstr, SmolStr};
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, Instrument};
|
||||
|
||||
use self::{
|
||||
connect_compute::{connect_to_compute, TcpMechanism},
|
||||
passthrough::ProxyPassthrough,
|
||||
};
|
||||
use self::connect_compute::{connect_to_compute, TcpMechanism};
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
const ERR_PROTO_VIOLATION: &str = "protocol violation";
|
||||
|
||||
pub async fn run_until_cancelled<F: std::future::Future>(
|
||||
f: F,
|
||||
@@ -62,7 +57,6 @@ pub async fn task_main(
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
@@ -73,6 +67,7 @@ pub async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -80,7 +75,7 @@ pub async fn task_main(
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancel_map = Arc::clone(&cancel_map);
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
let session_span = info_span!(
|
||||
@@ -103,41 +98,22 @@ pub async fn task_main(
|
||||
bail!("missing required client IP");
|
||||
}
|
||||
|
||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
|
||||
|
||||
socket
|
||||
.inner
|
||||
.set_nodelay(true)
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
|
||||
|
||||
let res = handle_client(
|
||||
handle_client(
|
||||
config,
|
||||
&mut ctx,
|
||||
cancellation_handler,
|
||||
cancel_map,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Err(e) => {
|
||||
// todo: log and push to ctx the error kind
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
ctx.log();
|
||||
Err(e.into())
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
ctx.log();
|
||||
Ok(())
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log();
|
||||
p.proxy_pass().await
|
||||
}
|
||||
}
|
||||
.await
|
||||
}
|
||||
.unwrap_or_else(move |e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
@@ -163,14 +139,14 @@ pub enum ClientMode {
|
||||
|
||||
/// Abstracts the logic of handling TCP vs WS clients
|
||||
impl ClientMode {
|
||||
pub fn allow_cleartext(&self) -> bool {
|
||||
fn allow_cleartext(&self) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => false,
|
||||
ClientMode::Websockets { .. } => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
|
||||
fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => config.allow_self_signed_compute,
|
||||
ClientMode::Websockets { .. } => false,
|
||||
@@ -193,45 +169,14 @@ impl ClientMode {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
// almost all errors should be reported to the user, but there's a few cases where we cannot
|
||||
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
|
||||
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
|
||||
// we cannot be sure the client even understands our error message
|
||||
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
|
||||
pub enum ClientRequestError {
|
||||
#[error("{0}")]
|
||||
Cancellation(#[from] cancellation::CancelError),
|
||||
#[error("{0}")]
|
||||
Handshake(#[from] handshake::HandshakeError),
|
||||
#[error("{0}")]
|
||||
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
|
||||
#[error("{0}")]
|
||||
PrepareClient(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
ReportedError(#[from] crate::stream::ReportedError),
|
||||
}
|
||||
|
||||
impl ReportableError for ClientRequestError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ClientRequestError::Cancellation(e) => e.get_error_kind(),
|
||||
ClientRequestError::Handshake(e) => e.get_error_kind(),
|
||||
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
|
||||
ClientRequestError::ReportedError(e) => e.get_error_kind(),
|
||||
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
cancellation_handler: Arc<CancellationHandler>,
cancel_map: Arc<CancelMap>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
) -> anyhow::Result<()> {
info!(
protocol = ctx.protocol,
"handling interactive connection from client"
@@ -248,17 +193,11 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let tls = config.tls_config.as_ref();

let pause = ctx.latency_timer.pause();
let do_handshake = handshake(stream, mode.handshake_tls(tls));
let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id)
.await
.map(|()| None)?)
}
};
let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
drop(pause);

let hostname = mode.hostname(stream.get_ref());
@@ -282,12 +221,12 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await?;
.await;
}
}

let user = user_info.get_user().to_owned();
let user_info = match user_info
let (mut node_info, user_info) = match user_info
.authenticate(
ctx,
&mut stream,
@@ -302,20 +241,23 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);

return stream.throw_error(e).instrument(params_span).await?;
return stream.throw_error(e).instrument(params_span).await;
}
};

node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);

let aux = node_info.aux.clone();
let mut node = connect_to_compute(
ctx,
&TcpMechanism { params: &params },
node_info,
&user_info,
mode.allow_self_signed_compute(config),
)
.or_else(|e| stream.throw_error(e))
.await?;

let session = cancellation_handler.get_session();
let session = cancel_map.get_session();
prepare_client_connection(&node, &session, &mut stream).await?;

// Before proxy passing, forward to compute whatever data is left in the
@@ -325,14 +267,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;

Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
compute: node,
req: _request_gauge,
conn: _client_gauge,
cancel: session,
}))
proxy_pass(ctx, stream, node.stream, aux).await
}

/// Finish client connection initialization: confirm auth success, send params, etc.
@@ -341,7 +276,7 @@ async fn prepare_client_connection(
node: &compute::PostgresConnection,
session: &cancellation::Session,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<(), std::io::Error> {
) -> anyhow::Result<()> {
// Register compute's query cancellation token and produce a new, unique one.
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());

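The hunk above only shows the signature change of prepare_client_connection; its body is not part of this diff. As a minimal sketch of what such a finisher does — assuming pq_proto's BeMessage variants (AuthenticationOk, BackendKeyData, ReadyForQuery) and a chaining write_message_noflush API; names and exact calls are illustrative, not the file's verbatim code:

    // Sketch only: confirm auth success, hand the client its cancellation
    // token, then signal readiness for queries.
    async fn prepare_client_connection_sketch<S: AsyncRead + AsyncWrite + Unpin>(
        cancel_key_data: pq_proto::CancelKeyData,
        stream: &mut PqStream<S>,
    ) -> std::io::Result<()> {
        stream
            .write_message_noflush(&Be::AuthenticationOk)?
            .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
            .write_message(&Be::ReadyForQuery)
            .await?;
        Ok(())
    }
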
@@ -1,9 +1,8 @@
use crate::{
auth::backend::ComputeCredentialKeys,
auth,
compute::{self, PostgresConnection},
console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo},
console::{self, errors::WakeComputeError},
context::RequestMonitoring,
error::ReportableError,
metrics::NUM_CONNECTION_FAILURES,
proxy::{
retry::{retry_after, ShouldRetry},
@@ -21,7 +20,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
let is_cached = node_info.cached();
if is_cached {
warn!("invalidating stalled compute node info cache entry");
@@ -32,13 +31,28 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
};
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();

node_info.invalidate()
node_info.invalidate().config
}

/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", fields(pid = tracing::field::Empty), skip_all)]
async fn connect_to_compute_once(
ctx: &mut RequestMonitoring,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, compute::ConnectionError> {
let allow_self_signed_compute = node_info.allow_self_signed_compute;

node_info
.config
.connect(ctx, allow_self_signed_compute, timeout)
.await
}

#[async_trait]
pub trait ConnectMechanism {
type Connection;
type ConnectError: ReportableError;
type ConnectError;
type Error: From<Self::ConnectError>;
async fn connect_once(
&self,
@@ -50,16 +64,6 @@ pub trait ConnectMechanism {
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}

#[async_trait]
pub trait ComputeConnectBackend {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;

fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
}

pub struct TcpMechanism<'a> {
/// KV-dictionary with PostgreSQL connection params.
pub params: &'a StartupMessageParams,
@@ -71,14 +75,13 @@ impl ConnectMechanism for TcpMechanism<'_> {
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;

#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
async fn connect_once(
&self,
ctx: &mut RequestMonitoring,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
node_info.connect(ctx, timeout).await
connect_to_compute_once(ctx, node_info, timeout).await
}

fn update_connect_config(&self, config: &mut compute::ConnCfg) {
@@ -89,23 +92,16 @@ impl ConnectMechanism for TcpMechanism<'_> {
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
pub async fn connect_to_compute<M: ConnectMechanism>(
ctx: &mut RequestMonitoring,
mechanism: &M,
user_info: &B,
allow_self_signed_compute: bool,
mut node_info: console::CachedNodeInfo,
user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: ShouldRetry + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
let mut num_retries = 0;
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
if let Some(keys) = user_info.get_keys() {
node_info.set_keys(keys);
}
node_info.allow_self_signed_compute = allow_self_signed_compute;
// let mut node_info = credentials.get_node_info(ctx, user_info).await?;
mechanism.update_connect_config(&mut node_info.config);

// try once
@@ -122,30 +118,28 @@ where

error!(error = ?err, "could not connect to compute node");

let node_info = if !node_info.cached() {
// If we just received this from cplane and didn't get it from cache, we shouldn't retry.
// No need to retrieve a new node_info; just return the old one.
if !err.should_retry(num_retries) {
return Err(err.into());
}
node_info
} else {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
info!("compute node's state has likely changed; requesting a wake-up");
ctx.latency_timer.cache_miss();
let old_node_info = invalidate_cache(node_info);
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
node_info.reuse_settings(old_node_info);
let mut num_retries = 1;

mechanism.update_connect_config(&mut node_info.config);
node_info
match user_info {
auth::BackendType::Console(api, info) => {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
info!("compute node's state has likely changed; requesting a wake-up");

ctx.latency_timer.cache_miss();
let config = invalidate_cache(node_info);
node_info = wake_compute(&mut num_retries, ctx, api, info).await?;

node_info.config.reuse_password(&config);
mechanism.update_connect_config(&mut node_info.config);
}
// nothing to do?
auth::BackendType::Link(_) => {}
};

// now that we have a new node, try to connect to it repeatedly.
// this can error for a few reasons, for instance:
// * DNS connection settings haven't quite propagated yet
info!("wake_compute success. attempting to connect");
num_retries = 1;
loop {
match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)

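retry_after (imported near the top of this file) computes the sleep between attempts but is not shown in this diff. A plausible exponential-backoff shape, purely illustrative — the base delay, factor, and cap below are assumptions, not the repository's constants:

    use std::time::Duration;

    // Illustrative only: the real retry_after lives in proxy::retry and its
    // constants are not part of this diff.
    fn retry_after_sketch(num_retries: u32) -> Duration {
        const BASE_DELAY: Duration = Duration::from_millis(25);
        const BACKOFF_FACTOR: f64 = 2.0;
        // delay = base * factor^(attempts), capped to keep tail latency bounded
        let delay = BASE_DELAY.mul_f64(BACKOFF_FACTOR.powi(num_retries as i32));
        delay.min(Duration::from_secs(5))
    }
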
@@ -1,256 +0,0 @@
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use std::future::poll_fn;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

#[derive(Debug)]
enum TransferState {
Running(CopyBuffer),
ShuttingDown(u64),
Done(u64),
}

fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
r: &mut A,
w: &mut B,
) -> Poll<io::Result<u64>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut r = Pin::new(r);
let mut w = Pin::new(w);
loop {
match state {
TransferState::Running(buf) => {
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
ready!(w.as_mut().poll_shutdown(cx))?;
*state = TransferState::Done(*count);
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
}
}
}

pub(super) async fn copy_bidirectional<A, B>(
a: &mut A,
b: &mut B,
) -> Result<(u64, u64), std::io::Error>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut a_to_b = TransferState::Running(CopyBuffer::new());
let mut b_to_a = TransferState::Running(CopyBuffer::new());

poll_fn(|cx| {
let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;

// Early termination checks
if let TransferState::Done(_) = a_to_b {
if let TransferState::Running(buf) = &b_to_a {
// Initiate shutdown
b_to_a = TransferState::ShuttingDown(buf.amt);
b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
}
}
if let TransferState::Done(_) = b_to_a {
if let TransferState::Running(buf) = &a_to_b {
// Initiate shutdown
a_to_b = TransferState::ShuttingDown(buf.amt);
a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
}
}

// It is not a problem if ready! returns early ... (comment remains the same)
let a_to_b = ready!(a_to_b_result);
let b_to_a = ready!(b_to_a_result);

Poll::Ready(Ok((a_to_b, b_to_a)))
})
.await
}

#[derive(Debug)]
pub(super) struct CopyBuffer {
read_done: bool,
need_flush: bool,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}
const DEFAULT_BUF_SIZE: usize = 8 * 1024;

impl CopyBuffer {
pub(super) fn new() -> Self {
Self {
read_done: false,
need_flush: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
}
}

fn poll_fill_buf<R>(
&mut self,
cx: &mut Context<'_>,
reader: Pin<&mut R>,
) -> Poll<io::Result<()>>
where
R: AsyncRead + ?Sized,
{
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
buf.set_filled(me.cap);

let res = reader.poll_read(cx, &mut buf);
if let Poll::Ready(Ok(())) = res {
let filled_len = buf.filled().len();
me.read_done = me.cap == filled_len;
me.cap = filled_len;
}
res
}

fn poll_write_buf<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<usize>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
let me = &mut *self;
match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
Poll::Pending => {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
}
Poll::Pending
}
res => res,
}
}

pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<u64>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
self.pos = 0;
self.cap = 0;

match self.poll_fill_buf(cx, reader.as_mut()) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => {
// Try flushing when the reader has no progress to avoid deadlock
// when the reader depends on buffered writer.
if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))?;
self.need_flush = false;
}

return Poll::Pending;
}
}
}

// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
} else {
self.pos += i;
self.amt += i as u64;
self.need_flush = true;
}
}

// If pos larger than cap, this loop will never stop.
// In particular, user's wrong poll_write implementation returning
// incorrect written length may lead to thread blocking.
debug_assert!(
self.pos <= self.cap,
"writer returned length larger than input slice"
);

// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use tokio::io::AsyncWriteExt;

#[tokio::test]
async fn test_early_termination_a_to_d() {
let (mut a_mock, mut b_mock) = tokio::io::duplex(8); // Create a mock duplex stream
let (mut c_mock, mut d_mock) = tokio::io::duplex(32); // Create a mock duplex stream

// Simulate 'a' finishing while there's still data for 'b'
a_mock.write_all(b"hello").await.unwrap();
a_mock.shutdown().await.unwrap();
d_mock.write_all(b"Neon Serverless Postgres").await.unwrap();

let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();

// Assert correct transferred amounts
let (a_to_d_count, d_to_a_count) = result;
assert_eq!(a_to_d_count, 5); // 'hello' was transferred
assert!(d_to_a_count <= 8); // response only partially transferred or not at all
}

#[tokio::test]
async fn test_early_termination_d_to_a() {
let (mut a_mock, mut b_mock) = tokio::io::duplex(32); // Create a mock duplex stream
let (mut c_mock, mut d_mock) = tokio::io::duplex(8); // Create a mock duplex stream

// Simulate 'd' finishing while there's still data for 'a'
d_mock.write_all(b"hello").await.unwrap();
d_mock.shutdown().await.unwrap();
a_mock.write_all(b"Neon Serverless Postgres").await.unwrap();

let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();

// Assert correct transferred amounts
let (a_to_d_count, d_to_a_count) = result;
assert_eq!(d_to_a_count, 5); // 'hello' was transferred
assert!(a_to_d_count <= 8); // response only partially transferred or not at all
}
}
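
This removed module is essentially a vendored tokio::io::copy_bidirectional with early-termination tweaks: when one direction reaches EOF, the opposite direction is nudged into ShuttingDown instead of being left running. With the module gone, proxy_pass (later in this diff) calls the stock tokio helper directly. A minimal usage sketch of that replacement, using in-memory duplex streams for illustration:

    use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt};

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        let (mut client, mut client_peer) = tokio::io::duplex(64);
        let (mut compute, mut compute_peer) = tokio::io::duplex(64);

        // Pump bytes both ways until each side reaches EOF; returns the
        // byte counts for client->compute and compute->client.
        let pump = tokio::spawn(async move {
            copy_bidirectional(&mut client_peer, &mut compute_peer).await
        });

        client.write_all(b"ping").await?;
        client.shutdown().await?;

        let mut buf = [0u8; 4];
        compute.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"ping");
        compute.shutdown().await?;

        let (a_to_b, b_to_a) = pump.await.expect("pump task panicked")?;
        assert_eq!((a_to_b, b_to_a), (4, 0));
        Ok(())
    }
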
@@ -1,60 +1,15 @@
use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams};
use thiserror::Error;
use anyhow::{bail, Context};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;

use crate::{
cancellation::CancelMap,
config::TlsConfig,
error::ReportableError,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
stream::{PqStream, Stream},
};

#[derive(Error, Debug)]
pub enum HandshakeError {
#[error("data is sent before server replied with EncryptionResponse")]
EarlyData,

#[error("protocol violation")]
ProtocolViolation,

#[error("missing certificate")]
MissingCertificate,

#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),

#[error("{0}")]
Io(#[from] std::io::Error),

#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}

impl ReportableError for HandshakeError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
HandshakeError::EarlyData => crate::error::ErrorKind::User,
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
// This error should not happen, but will if we have no default certificate and
// the client sends no SNI extension.
// If they provide SNI then we can be sure there is a certificate that matches.
HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
},
HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
HandshakeError::ReportedError(e) => e.get_error_kind(),
}
}
}

pub enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
Cancel(CancelKeyData),
}

/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
@@ -63,7 +18,8 @@ pub enum HandshakeData<S> {
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
) -> Result<HandshakeData<S>, HandshakeError> {
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);

@@ -93,14 +49,14 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
return Err(HandshakeError::EarlyData);
bail!("data is sent before server replied with EncryptionResponse");
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;

let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.ok_or(HandshakeError::MissingCertificate)?;
.context("missing certificate")?;

stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
@@ -108,7 +64,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
});
}
}
_ => return Err(HandshakeError::ProtocolViolation),
_ => bail!(ERR_PROTO_VIOLATION),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
@@ -117,23 +73,23 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => return Err(HandshakeError::ProtocolViolation),
_ => bail!(ERR_PROTO_VIOLATION),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
return stream
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
.await?;
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}

info!(session_type = "normal", "successful handshake");
break Ok(HandshakeData::Startup(stream, params));
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;

info!(session_type = "cancellation", "successful handshake");
break Ok(HandshakeData::Cancel(cancel_key_data));
break Ok(None);
}
}
}

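For context on the dispatch above: SslRequest, GssEncRequest, and CancelRequest are distinguished on the wire by standard PostgreSQL magic codes, and Be::EncryptionResponse maps to the single-byte 'S'/'N' answer. The codes are protocol facts; the constant names below are mine:

    // Standard PostgreSQL startup request codes (protocol facts; names are mine).
    // Each negotiation packet is: Int32 length (8) followed by Int32 code.
    const CANCEL_REQUEST_CODE: u32 = 80877102; // 1234 << 16 | 5678
    const SSL_REQUEST_CODE: u32 = 80877103; // 1234 << 16 | 5679
    const GSSENC_REQUEST_CODE: u32 = 80877104; // 1234 << 16 | 5680

    // The server answers SslRequest/GssEncRequest with a single byte:
    // b'S' (accepted) or b'N' (refused) -- Be::EncryptionResponse(bool) above.
    fn encryption_response(accepted: bool) -> u8 {
        if accepted { b'S' } else { b'N' }
    }
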
@@ -1,12 +1,9 @@
use crate::{
cancellation,
compute::PostgresConnection,
console::messages::MetricsAuxInfo,
context::RequestMonitoring,
metrics::NUM_BYTES_PROXIED_COUNTER,
stream::Stream,
usage_metrics::{Ids, USAGE_METRICS},
};
use metrics::IntCounterPairGuard;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;
@@ -14,10 +11,14 @@ use utils::measured_stream::MeasuredStream;
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
ctx: &mut RequestMonitoring,
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
ctx.set_success();
ctx.log();

let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id.clone(),
branch_id: aux.branch_id.clone(),
@@ -46,23 +47,7 @@ pub async fn proxy_pass(

// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = crate::proxy::copy_bidirectional::copy_bidirectional(&mut client, &mut compute).await?;
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;

Ok(())
}

pub struct ProxyPassthrough<S> {
pub client: Stream<S>,
pub compute: PostgresConnection,
pub aux: MetricsAuxInfo,

pub req: IntCounterPairGuard,
pub conn: IntCounterPairGuard,
pub cancel: cancellation::Session,
}

impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
pub async fn proxy_pass(self) -> anyhow::Result<()> {
proxy_pass(self.client, self.compute.stream, self.aux).await
}
}

@@ -2,19 +2,13 @@

mod mitm;

use std::time::Duration;

use super::connect_compute::ConnectMechanism;
use super::retry::ShouldRetry;
use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
};
use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend};
use crate::config::CertResolver;
use crate::console::caches::NodeInfoCache;
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
@@ -150,7 +144,7 @@ impl TestAuth for Scram {
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
let outcome = auth::AuthFlow::new(stream)
.begin(auth::Scram(&self.0, &mut RequestMonitoring::test()))
.begin(auth::Scram(&self.0))
.await?
.authenticate()
.await?;
@@ -169,11 +163,11 @@ async fn dummy_proxy(
tls: Option<TlsConfig>,
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
let client = WithClientIp::new(client);
let mut stream = match handshake(client, tls.as_ref()).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map)
.await?
.context("handshake failed")?;

auth.authenticate(&mut stream).await?;

@@ -381,7 +375,6 @@ enum ConnectAction {
struct TestConnectMechanism {
counter: Arc<std::sync::Mutex<usize>>,
sequence: Vec<ConnectAction>,
cache: &'static NodeInfoCache,
}

impl TestConnectMechanism {
@@ -400,12 +393,6 @@ impl TestConnectMechanism {
Self {
counter: Arc::new(std::sync::Mutex::new(0)),
sequence,
cache: Box::leak(Box::new(NodeInfoCache::new(
"test",
1,
Duration::from_secs(100),
false,
))),
}
}
}
@@ -416,13 +403,6 @@ struct TestConnection;
#[derive(Debug)]
struct TestConnectError {
retryable: bool,
kind: crate::error::ErrorKind,
}

impl ReportableError for TestConnectError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
self.kind
}
}

impl std::fmt::Display for TestConnectError {
@@ -456,14 +436,8 @@ impl ConnectMechanism for TestConnectMechanism {
*counter += 1;
match action {
ConnectAction::Connect => Ok(TestConnection),
ConnectAction::Retry => Err(TestConnectError {
retryable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::Fail => Err(TestConnectError {
retryable: false,
kind: ErrorKind::Compute,
}),
ConnectAction::Retry => Err(TestConnectError { retryable: true }),
ConnectAction::Fail => Err(TestConnectError { retryable: false }),
x => panic!("expecting action {:?}, connect is called instead", x),
}
}
@@ -477,7 +451,7 @@ impl TestBackend for TestConnectMechanism {
let action = self.sequence[*counter];
*counter += 1;
match action {
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
ConnectAction::Wake => Ok(helper_create_cached_node_info()),
ConnectAction::WakeFail => {
let err = console::errors::ApiError::Console {
status: http::StatusCode::FORBIDDEN,
@@ -504,46 +478,39 @@ impl TestBackend for TestConnectMechanism {
{
unimplemented!("not used in tests")
}
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}
}

fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
fn helper_create_cached_node_info() -> CachedNodeInfo {
let node = NodeInfo {
config: compute::ConnCfg::new(),
aux: Default::default(),
allow_self_signed_compute: false,
};
let (_, node) = cache.insert("key".into(), node);
node
CachedNodeInfo::new_uncached(node)
}

fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> auth::BackendType<'static, ComputeCredentials, &()> {
) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) {
let cache = helper_create_cached_node_info();
let user_info = auth::BackendType::Console(
MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::Password("password".into()),
ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
);
user_info
(cache, user_info)
}

#[tokio::test]
async fn connect_to_compute_success() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
@@ -551,12 +518,11 @@ async fn connect_to_compute_success() {

#[tokio::test]
async fn connect_to_compute_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
@@ -565,12 +531,11 @@ async fn connect_to_compute_retry() {
/// Test that we don't retry if the error is not retryable.
#[tokio::test]
async fn connect_to_compute_non_retry_1() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap_err();
mechanism.verify();
@@ -579,12 +544,11 @@ async fn connect_to_compute_non_retry_1() {
/// Even for non-retryable errors, we should retry at least once.
#[tokio::test]
async fn connect_to_compute_non_retry_2() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
@@ -593,16 +557,15 @@ async fn connect_to_compute_non_retry_2() {
/// Retry for at most `NUM_RETRIES_CONNECT` times.
#[tokio::test]
async fn connect_to_compute_non_retry_3() {
let _ = env_logger::try_init();
assert_eq!(NUM_RETRIES_CONNECT, 16);
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![
Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap_err();
mechanism.verify();
@@ -611,12 +574,11 @@ async fn connect_to_compute_non_retry_3() {
/// Should retry wake compute.
#[tokio::test]
async fn wake_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
@@ -625,12 +587,11 @@ async fn wake_retry() {
/// Wake failed with a non-retryable error.
#[tokio::test]
async fn wake_non_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap_err();
mechanism.verify();

@@ -35,10 +35,12 @@ async fn proxy_mitm(
tokio::spawn(async move {
// begin handshake with end_server
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
let (end_client, startup) = match handshake(client1, Some(&server_config1)).await.unwrap() {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
};
// process handshake with end_client
let (end_client, startup) =
handshake(client1, Some(&server_config1), &CancelMap::default())
.await
.unwrap()
.unwrap();

let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
let (end_client, buf) = end_client.framed.into_inner();

@@ -1,4 +1,9 @@
use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo};
use crate::auth::backend::ComputeUserInfo;
use crate::console::{
errors::WakeComputeError,
provider::{CachedNodeInfo, ConsoleBackend},
Api,
};
use crate::context::RequestMonitoring;
use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES};
use crate::proxy::retry::retry_after;
@@ -6,16 +11,17 @@ use hyper::StatusCode;
use std::ops::ControlFlow;
use tracing::{error, warn};

use super::connect_compute::ComputeConnectBackend;
use super::retry::ShouldRetry;

pub async fn wake_compute<B: ComputeConnectBackend>(
/// wake a compute (or retrieve an existing compute session from cache)
pub async fn wake_compute(
num_retries: &mut u32,
ctx: &mut RequestMonitoring,
api: &B,
api: &ConsoleBackend,
info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
loop {
let wake_res = api.wake_compute(ctx).await;
let wake_res = api.wake_compute(ctx, info).await;
match handle_try_wake(wake_res, *num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");

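handle_try_wake is called above but its body is outside this diff. Given the ControlFlow import and the call site, its shape is plausibly the following — a sketch under the assumption that WakeComputeError implements ShouldRetry; not the repository's verbatim code:

    use std::ops::ControlFlow;

    // Sketch of the shape implied by the call site above; the real function
    // is not part of this diff.
    fn handle_try_wake_sketch(
        result: Result<CachedNodeInfo, WakeComputeError>,
        num_retries: u32,
    ) -> Result<ControlFlow<CachedNodeInfo, WakeComputeError>, WakeComputeError> {
        match result {
            Ok(node_info) => Ok(ControlFlow::Break(node_info)),
            // retryable errors bounce the loop; the caller sleeps via retry_after
            Err(e) if e.should_retry(num_retries) => Ok(ControlFlow::Continue(e)),
            Err(e) => Err(e),
        }
    }
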
@@ -4,4 +4,4 @@ mod limiter;
pub use aimd::Aimd;
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
pub use limiter::Limiter;
pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};
pub use limiter::{EndpointRateLimiter, RateBucketInfo};

@@ -22,44 +22,6 @@ use super::{
RateLimiterConfig,
};

pub struct RedisRateLimiter {
data: Vec<RateBucket>,
info: &'static [RateBucketInfo],
}

impl RedisRateLimiter {
pub fn new(info: &'static [RateBucketInfo]) -> Self {
Self {
data: vec![
RateBucket {
start: Instant::now(),
count: 0,
};
info.len()
],
info,
}
}

/// Check that number of connections is below `max_rps` rps.
pub fn check(&mut self) -> bool {
let now = Instant::now();

let should_allow_request = self
.data
.iter_mut()
.zip(self.info)
.all(|(bucket, info)| bucket.should_allow_request(info, now));

if should_allow_request {
// only increment the bucket counts if the request will actually be accepted
self.data.iter_mut().for_each(RateBucket::inc);
}

should_allow_request
}
}

// Simple per-endpoint rate limiter.
//
// Check that number of connections to the endpoint is below `max_rps` rps.

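RateBucket and RateBucketInfo are used above but defined elsewhere in limiter.rs. A minimal fixed-window sketch of the contract the all(...)/inc pair relies on — field names here are assumptions:

    use std::time::{Duration, Instant};

    // Minimal sketch; the real RateBucket/RateBucketInfo live elsewhere in
    // rate_limiter, and their exact fields are assumptions here.
    #[derive(Clone)]
    struct RateBucket {
        start: Instant,
        count: u32,
    }

    struct RateBucketInfo {
        interval: Duration,
        max_rpi: u32, // max requests per interval
    }

    impl RateBucket {
        fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant) -> bool {
            if now.duration_since(self.start) < info.interval {
                self.count < info.max_rpi
            } else {
                // the window expired: start a fresh one
                self.start = now;
                self.count = 0;
                true
            }
        }

        fn inc(&mut self) {
            self.count += 1;
        }
    }
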
@@ -1,2 +1 @@
pub mod notifications;
pub mod publisher;

@@ -1,44 +1,38 @@
use std::{convert::Infallible, sync::Arc};

use futures::StreamExt;
use pq_proto::CancelKeyData;
use redis::aio::PubSub;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use serde::Deserialize;

use crate::{
cache::project_info::ProjectInfoCache,
cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler},
intern::{ProjectIdInt, RoleNameInt},
};

const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);

struct RedisConsumerClient {
struct ConsoleRedisClient {
client: redis::Client,
}

impl RedisConsumerClient {
impl ConsoleRedisClient {
pub fn new(url: &str) -> anyhow::Result<Self> {
let client = redis::Client::open(url)?;
Ok(Self { client })
}
async fn try_connect(&self) -> anyhow::Result<PubSub> {
let mut conn = self.client.get_async_connection().await?.into_pubsub();
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
conn.subscribe(PROXY_CHANNEL_NAME).await?;
tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
conn.subscribe(CHANNEL_NAME).await?;
Ok(conn)
}
}

#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
pub(crate) enum Notification {
enum Notification {
#[serde(
rename = "/allowed_ips_updated",
deserialize_with = "deserialize_json_string"
@@ -51,25 +45,16 @@ pub(crate) enum Notification {
deserialize_with = "deserialize_json_string"
)]
PasswordUpdate { password_update: PasswordUpdate },
#[serde(rename = "/cancel_session")]
Cancel(CancelSession),
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedIpsUpdate {
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct PasswordUpdate {
project_id: ProjectIdInt,
role_name: RoleNameInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct CancelSession {
pub region_id: Option<String>,
pub cancel_key_data: CancelKeyData,
pub session_id: Uuid,
}

fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
T: for<'de2> serde::Deserialize<'de2>,
@@ -79,88 +64,6 @@ where
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
}

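The serde attributes above imply an envelope whose `data` field is itself a JSON-encoded string, which is why deserialize_json_string parses twice. An illustrative round-trip — the variant shape and the project id value are assumptions for the example:

    #[test]
    fn parse_allowed_ips_update_sketch() {
        // Note the doubly-encoded `data`: a JSON string containing JSON.
        let payload = r#"{
            "topic": "/allowed_ips_updated",
            "data": "{\"project_id\": \"proud-wave-123456\"}"
        }"#;
        let msg: Notification = serde_json::from_str(payload).unwrap();
        assert!(matches!(msg, Notification::AllowedIpsUpdate { .. }));
    }
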
struct MessageHandler<
C: ProjectInfoCache + Send + Sync + 'static,
H: NotificationsCancellationHandler + Send + Sync + 'static,
> {
cache: Arc<C>,
cancellation_handler: Arc<H>,
region_id: String,
}

impl<
C: ProjectInfoCache + Send + Sync + 'static,
H: NotificationsCancellationHandler + Send + Sync + 'static,
> MessageHandler<C, H>
{
pub fn new(cache: Arc<C>, cancellation_handler: Arc<H>, region_id: String) -> Self {
Self {
cache,
cancellation_handler,
region_id,
}
}
pub fn disable_ttl(&self) {
self.cache.disable_ttl();
}
pub fn enable_ttl(&self) {
self.cache.enable_ttl();
}
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
use Notification::*;
let payload: String = msg.get_payload()?;
tracing::debug!(?payload, "received a message payload");

let msg: Notification = match serde_json::from_str(&payload) {
Ok(msg) => msg,
Err(e) => {
tracing::error!("broken message: {e}");
return Ok(());
}
};
tracing::debug!(?msg, "received a message");
match msg {
Cancel(cancel_session) => {
tracing::Span::current().record(
"session_id",
&tracing::field::display(cancel_session.session_id),
);
if let Some(cancel_region) = cancel_session.region_id {
// If the message is not for this region, ignore it.
if cancel_region != self.region_id {
return Ok(());
}
}
// This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
match self
.cancellation_handler
.cancel_session_no_publish(cancel_session.cancel_key_data)
.await
{
Ok(()) => {}
Err(e) => {
tracing::error!("failed to cancel session: {e}");
}
}
}
_ => {
invalidate_cache(self.cache.clone(), msg.clone());
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
let cache = self.cache.clone();
tokio::spawn(async move {
tokio::time::sleep(INVALIDATION_LAG).await;
invalidate_cache(cache, msg);
});
}
}

Ok(())
}
}

fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
use Notification::*;
match msg {
@@ -171,33 +74,50 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
password_update.project_id,
password_update.role_name,
),
Cancel(_) => unreachable!("cancel message should be handled separately"),
}
}

#[tracing::instrument(skip(cache))]
fn handle_message<C>(msg: redis::Msg, cache: Arc<C>) -> anyhow::Result<()>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
let payload: String = msg.get_payload()?;
tracing::debug!(?payload, "received a message payload");

let msg: Notification = match serde_json::from_str(&payload) {
Ok(msg) => msg,
Err(e) => {
tracing::error!("broken message: {e}");
return Ok(());
}
};
tracing::debug!(?msg, "received a message");
invalidate_cache(cache.clone(), msg.clone());
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
tokio::spawn(async move {
tokio::time::sleep(INVALIDATION_LAG).await;
invalidate_cache(cache, msg.clone());
});

Ok(())
}

/// Handle console's invalidation messages.
#[tracing::instrument(name = "console_notifications", skip_all)]
pub async fn task_main<C>(
url: String,
cache: Arc<C>,
cancel_map: CancelMap,
region_id: String,
) -> anyhow::Result<Infallible>
pub async fn task_main<C>(url: String, cache: Arc<C>) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
cache.enable_ttl();
let handler = MessageHandler::new(
cache,
Arc::new(CancellationHandler::new(cancel_map, None)),
region_id,
);

loop {
let redis = RedisConsumerClient::new(&url)?;
let redis = ConsoleRedisClient::new(&url)?;
let conn = match redis.try_connect().await {
Ok(conn) => {
handler.disable_ttl();
cache.disable_ttl();
conn
}
Err(e) => {
@@ -210,7 +130,7 @@ where
};
let mut stream = conn.into_on_message();
while let Some(msg) = stream.next().await {
match handler.handle_message(msg).await {
match handle_message(msg, cache.clone()) {
Ok(()) => {}
Err(e) => {
tracing::error!("failed to handle message: {e}, will try to reconnect");
@@ -218,7 +138,7 @@ where
}
}
}
handler.enable_ttl();
cache.enable_ttl();
}
}

@@ -278,33 +198,6 @@ mod tests {
}
);

Ok(())
}
#[test]
fn parse_cancel_session() -> anyhow::Result<()> {
let cancel_key_data = CancelKeyData {
backend_pid: 42,
cancel_key: 41,
};
let uuid = uuid::Uuid::new_v4();
let msg = Notification::Cancel(CancelSession {
cancel_key_data,
region_id: None,
session_id: uuid,
});
let text = serde_json::to_string(&msg)?;
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(msg, result);

let msg = Notification::Cancel(CancelSession {
cancel_key_data,
region_id: Some("region".to_string()),
session_id: uuid,
});
let text = serde_json::to_string(&msg)?;
let result: Notification = serde_json::from_str(&text)?;
assert_eq!(msg, result);

Ok(())
}
}

@@ -1,80 +0,0 @@
use pq_proto::CancelKeyData;
use redis::AsyncCommands;
use uuid::Uuid;

use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};

use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};

pub struct RedisPublisherClient {
client: redis::Client,
publisher: Option<redis::aio::Connection>,
region_id: String,
limiter: RedisRateLimiter,
}

impl RedisPublisherClient {
pub fn new(
url: &str,
region_id: String,
info: &'static [RateBucketInfo],
) -> anyhow::Result<Self> {
let client = redis::Client::open(url)?;
Ok(Self {
client,
publisher: None,
region_id,
limiter: RedisRateLimiter::new(info),
})
}
pub async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping cancellation message");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.publish(cancel_key_data, session_id).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::error!("failed to publish a message: {e}");
self.publisher = None;
}
}
tracing::info!("Publisher is disconnected. Reconnecting...");
self.try_connect().await?;
self.publish(cancel_key_data, session_id).await
}

async fn publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
let conn = self
.publisher
.as_mut()
.ok_or_else(|| anyhow::anyhow!("not connected"))?;
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
region_id: Some(self.region_id.clone()),
cancel_key_data,
session_id,
}))?;
conn.publish(PROXY_CHANNEL_NAME, payload).await?;
Ok(())
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
match self.client.get_async_connection().await {
Ok(conn) => {
self.publisher = Some(conn);
}
Err(e) => {
tracing::error!("failed to connect to redis: {e}");
return Err(e.into());
}
}
Ok(())
}
}
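
An illustrative wiring of the deleted publisher above — the URL, region, and the leaked bucket config are placeholders, not values from this repository:

    // Publish a cancellation using the module's own API (sketch only).
    async fn publish_cancel_sketch() -> anyhow::Result<()> {
        let buckets: &'static [RateBucketInfo] =
            Box::leak(Box::new([/* bucket config elided */]));
        let mut publisher =
            RedisPublisherClient::new("redis://127.0.0.1:6379", "us-east-2".into(), buckets)?;
        publisher.try_connect().await?;

        let key = pq_proto::CancelKeyData { backend_pid: 42, cancel_key: 41 };
        // On a stale connection, try_publish drops it and reconnects once.
        publisher.try_publish(key, uuid::Uuid::new_v4()).await
    }
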
@@ -10,7 +10,7 @@ mod channel_binding;
|
||||
mod messages;
|
||||
mod stream;
|
||||
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::error::UserFacingError;
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -48,18 +48,6 @@ impl UserFacingError for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for Error {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
|
||||
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
|
||||
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
|
||||
Error::MissingBinding => crate::error::ErrorKind::Service,
|
||||
Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A convenient result type for SASL exchange.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
//!
//! Handles both SQL over HTTP and SQL over Websockets.

mod backend;
mod conn_pool;
mod json;
mod sql_over_http;
@@ -19,12 +18,12 @@ pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio_util::task::TaskTracker;

use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
use crate::{cancellation::CancelMap, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
    server::{
@@ -50,19 +49,17 @@ pub async fn task_main(
    ws_listener: TcpListener,
    cancellation_token: CancellationToken,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
    cancellation_handler: Arc<CancellationHandler>,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("websocket server has shut down");
    }

    let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
    {
        let conn_pool = Arc::clone(&conn_pool);
        tokio::spawn(async move {
            conn_pool.gc_worker(StdRng::from_entropy()).await;
        });
    }
    let conn_pool = conn_pool::GlobalConnPool::new(config);

    let conn_pool2 = Arc::clone(&conn_pool);
    tokio::spawn(async move {
        conn_pool2.gc_worker(StdRng::from_entropy()).await;
    });

    // shutdown the connection pool
    tokio::spawn({
@@ -76,11 +73,6 @@ pub async fn task_main(
        }
    });

    let backend = Arc::new(PoolingBackend {
        pool: Arc::clone(&conn_pool),
        config,
    });

    let tls_config = match config.tls_config.as_ref() {
        Some(config) => config,
        None => {
@@ -110,13 +102,14 @@ pub async fn task_main(

    let make_svc = hyper::service::make_service_fn(
        |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
            let (io, _) = stream.get_ref();
            let (io, tls) = stream.get_ref();
            let client_addr = io.client_addr();
            let remote_addr = io.inner.remote_addr();
            let backend = backend.clone();
            let sni_name = tls.server_name().map(|s| s.to_string());
            let conn_pool = conn_pool.clone();
            let ws_connections = ws_connections.clone();
            let endpoint_rate_limiter = endpoint_rate_limiter.clone();
            let cancellation_handler = cancellation_handler.clone();

            async move {
                let peer_addr = match client_addr {
                    Some(addr) => addr,
@@ -125,21 +118,24 @@ pub async fn task_main(
                };
                Ok(MetricService::new(hyper::service::service_fn(
                    move |req: Request<Body>| {
                        let backend = backend.clone();
                        let sni_name = sni_name.clone();
                        let conn_pool = conn_pool.clone();
                        let ws_connections = ws_connections.clone();
                        let endpoint_rate_limiter = endpoint_rate_limiter.clone();
                        let cancellation_handler = cancellation_handler.clone();

                        async move {
                            let cancel_map = Arc::new(CancelMap::default());
                            let session_id = uuid::Uuid::new_v4();

                            request_handler(
                                req,
                                config,
                                backend,
                                tls_config,
                                conn_pool,
                                ws_connections,
                                cancellation_handler,
                                cancel_map,
                                session_id,
                                sni_name,
                                peer_addr.ip(),
                                endpoint_rate_limiter,
                            )
@@ -204,10 +200,12 @@ where
async fn request_handler(
    mut request: Request<Body>,
    config: &'static ProxyConfig,
    backend: Arc<PoolingBackend>,
    tls: &'static TlsConfig,
    conn_pool: Arc<conn_pool::GlobalConnPool>,
    ws_connections: TaskTracker,
    cancellation_handler: Arc<CancellationHandler>,
    cancel_map: Arc<CancelMap>,
    session_id: uuid::Uuid,
    sni_hostname: Option<String>,
    peer_addr: IpAddr,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Body>, ApiError> {
@@ -227,13 +225,13 @@ async fn request_handler(

        ws_connections.spawn(
            async move {
                let ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
                let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);

                if let Err(e) = websocket::serve_websocket(
                    config,
                    ctx,
                    &mut ctx,
                    websocket,
                    cancellation_handler,
                    cancel_map,
                    host,
                    endpoint_rate_limiter,
                )
@@ -248,9 +246,17 @@ async fn request_handler(
        // Return the response so the spawned future can continue.
        Ok(response)
    } else if request.uri().path() == "/sql" && request.method() == Method::POST {
        let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
        let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);

        sql_over_http::handle(config, ctx, request, backend).await
        sql_over_http::handle(
            tls,
            &config.http_config,
            &mut ctx,
            request,
            sni_hostname,
            conn_pool,
        )
        .await
    } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
        Response::builder()
            .header("Allow", "OPTIONS, POST")

@@ -1,160 +0,0 @@
use std::{sync::Arc, time::Duration};

use async_trait::async_trait;
use tracing::{field::display, info};

use crate::{
    auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
    compute,
    config::ProxyConfig,
    console::{
        errors::{GetAuthInfoError, WakeComputeError},
        CachedNodeInfo,
    },
    context::RequestMonitoring,
    proxy::connect_compute::ConnectMechanism,
};

use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};

pub struct PoolingBackend {
    pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
    pub config: &'static ProxyConfig,
}

impl PoolingBackend {
    pub async fn authenticate(
        &self,
        ctx: &mut RequestMonitoring,
        conn_info: &ConnInfo,
    ) -> Result<ComputeCredentials, AuthError> {
        let user_info = conn_info.user_info.clone();
        let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
        let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
        if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
            return Err(AuthError::ip_address_not_allowed());
        }
        let cached_secret = match maybe_secret {
            Some(secret) => secret,
            None => backend.get_role_secret(ctx).await?,
        };

        let secret = match cached_secret.value.clone() {
            Some(secret) => secret,
            None => {
                // If we don't have an authentication secret, for the http flow we can just return an error.
                info!("authentication info not found");
                return Err(AuthError::auth_failed(&*user_info.user));
            }
        };
        let auth_outcome =
            crate::auth::validate_password_and_exchange(&conn_info.password, secret)?;
        let res = match auth_outcome {
            crate::sasl::Outcome::Success(key) => Ok(key),
            crate::sasl::Outcome::Failure(reason) => {
                info!("auth backend failed with an error: {reason}");
                Err(AuthError::auth_failed(&*conn_info.user_info.user))
            }
        };
        res.map(|key| ComputeCredentials {
            info: user_info,
            keys: key,
        })
    }

    // Wake up the destination if needed. Code here is a bit involved because
    // we reuse the code from the usual proxy and we need to prepare a few structures
    // that this code expects.
    #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
    pub async fn connect_to_compute(
        &self,
        ctx: &mut RequestMonitoring,
        conn_info: ConnInfo,
        keys: ComputeCredentials,
        force_new: bool,
    ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
        let maybe_client = if !force_new {
            info!("pool: looking for an existing connection");
            self.pool.get(ctx, &conn_info).await?
        } else {
            info!("pool: pool is disabled");
            None
        };

        if let Some(client) = maybe_client {
            return Ok(client);
        }
        let conn_id = uuid::Uuid::new_v4();
        tracing::Span::current().record("conn_id", display(conn_id));
        info!(%conn_id, "pool: opening a new connection '{conn_info}'");
        let backend = self.config.auth_backend.as_ref().map(|_| keys);
        crate::proxy::connect_compute::connect_to_compute(
            ctx,
            &TokioMechanism {
                conn_id,
                conn_info,
                pool: self.pool.clone(),
            },
            &backend,
            false, // do not allow self signed compute for http flow
        )
        .await
    }
}
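
For orientation, the intended call pattern for this backend is a two-step authenticate-then-connect flow, mirroring what handle_inner() does further down in this diff. A minimal sketch, assuming `ctx`, `conn_info`, and `allow_pool` are already derived from the incoming HTTP request (they are not part of this file):

    // Sketch only (not part of the diff): `!allow_pool` forces a fresh
    // connection whenever pooling was not opted into for this request.
    let keys = backend.authenticate(&mut ctx, &conn_info).await?;
    let client = backend
        .connect_to_compute(&mut ctx, conn_info, keys, !allow_pool)
        .await?;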

#[derive(Debug, thiserror::Error)]
pub enum HttpConnError {
    #[error("pooled connection closed in an inconsistent state")]
    ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
    #[error("could not connect to compute")]
    ConnectionError(#[from] tokio_postgres::Error),

    #[error("could not get auth info")]
    GetAuthInfo(#[from] GetAuthInfoError),
    #[error("user not authenticated")]
    AuthError(#[from] AuthError),
    #[error("wake_compute returned error")]
    WakeCompute(#[from] WakeComputeError),
}

struct TokioMechanism {
    pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
    conn_info: ConnInfo,
    conn_id: uuid::Uuid,
}

#[async_trait]
impl ConnectMechanism for TokioMechanism {
    type Connection = Client<tokio_postgres::Client>;
    type ConnectError = tokio_postgres::Error;
    type Error = HttpConnError;

    async fn connect_once(
        &self,
        ctx: &mut RequestMonitoring,
        node_info: &CachedNodeInfo,
        timeout: Duration,
    ) -> Result<Self::Connection, Self::ConnectError> {
        let mut config = (*node_info.config).clone();
        let config = config
            .user(&self.conn_info.user_info.user)
            .password(&*self.conn_info.password)
            .dbname(&self.conn_info.dbname)
            .connect_timeout(timeout);

        let (client, connection) = config.connect(tokio_postgres::NoTls).await?;

        tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
        Ok(poll_client(
            self.pool.clone(),
            ctx,
            self.conn_info.clone(),
            client,
            connection,
            self.conn_id,
            node_info.aux.clone(),
        ))
    }

    fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}

File diff suppressed because it is too large
@@ -9,23 +9,23 @@ use tokio_postgres::Row;
// as parameters.
//
pub fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
    json.iter().map(json_value_to_pg_text).collect()
}
    json.iter()
        .map(|value| {
            match value {
                // special care for nulls
                Value::Null => None,

fn json_value_to_pg_text(value: &Value) -> Option<String> {
    match value {
        // special care for nulls
        Value::Null => None,
                // convert to text with escaping
                v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),

        // convert to text with escaping
        v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
                // avoid escaping here, as we pass this as a parameter
                Value::String(s) => Some(s.to_string()),

        // avoid escaping here, as we pass this as a parameter
        Value::String(s) => Some(s.to_string()),

                // special care for arrays
                Value::Array(_) => json_array_to_pg_array(value),
            }
        // special care for arrays
        Value::Array(_) => json_array_to_pg_array(value),
    }
        })
        .collect()
}
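
To make the mapping above concrete, here is a hedged sketch of what json_to_pg_text produces for a few inputs (the pg-array rendering comes from json_array_to_pg_array, whose body is elided in this diff, so treat that line as an assumption):

    // Sketch only: None becomes a SQL NULL parameter, everything else is text.
    let params = vec![
        serde_json::json!(null),        // -> None
        serde_json::json!(true),        // -> Some("true")
        serde_json::json!("it's"),      // -> Some("it's"), passed through unescaped
        serde_json::json!([1, [2, 3]]), // -> Some("{1,{2,3}}"), a pg array literal (assumed)
    ];
    let as_text = json_to_pg_text(params);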

//
@@ -60,20 +60,6 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
    }
}

#[derive(Debug, thiserror::Error)]
pub enum JsonConversionError {
    #[error("internal error: compute returned invalid data: {0}")]
    AsTextError(tokio_postgres::Error),
    #[error("parse int error: {0}")]
    ParseIntError(#[from] std::num::ParseIntError),
    #[error("parse float error: {0}")]
    ParseFloatError(#[from] std::num::ParseFloatError),
    #[error("parse json error: {0}")]
    ParseJsonError(#[from] serde_json::Error),
    #[error("unbalanced array")]
    UnbalancedArray,
}

//
// Convert postgres row with text-encoded values to JSON object
//
@@ -82,7 +68,7 @@ pub fn pg_text_row_to_json(
    columns: &[Type],
    raw_output: bool,
    array_mode: bool,
) -> Result<Value, JsonConversionError> {
) -> Result<Value, anyhow::Error> {
    let iter = row
        .columns()
        .iter()
@@ -90,7 +76,7 @@ pub fn pg_text_row_to_json(
        .enumerate()
        .map(|(i, (column, typ))| {
            let name = column.name();
            let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?;
            let pg_value = row.as_text(i)?;
            let json_value = if raw_output {
                match pg_value {
                    Some(v) => Value::String(v.to_string()),
@@ -106,10 +92,10 @@ pub fn pg_text_row_to_json(
        // drop keys and aggregate into array
        let arr = iter
            .map(|r| r.map(|(_key, val)| val))
            .collect::<Result<Vec<Value>, JsonConversionError>>()?;
            .collect::<Result<Vec<Value>, anyhow::Error>>()?;
        Ok(Value::Array(arr))
    } else {
        let obj = iter.collect::<Result<Map<String, Value>, JsonConversionError>>()?;
        let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
        Ok(Value::Object(obj))
    }
}
@@ -117,7 +103,7 @@ pub fn pg_text_row_to_json(
//
// Convert postgres text-encoded value to JSON value
//
fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, JsonConversionError> {
fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
    if let Some(val) = pg_value {
        if let Kind::Array(elem_type) = pg_type.kind() {
            return pg_array_parse(val, elem_type);
@@ -156,7 +142,7 @@ fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, Json
// values. Unlike postgres we don't check that all nested arrays have the same
// dimensions, we just return them as is.
//
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, JsonConversionError> {
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
    _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
}
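
For context, pg_array_parse consumes Postgres' text encoding of arrays, e.g. "{{1,2},{3,NULL}}" for a two-dimensional int4 array. A hedged usage sketch (Type::INT4 is the tokio_postgres type constant; the exact JSON shape follows from pg_text_to_json above):

    // Sketch only: nested braces become nested JSON arrays, and an unquoted
    // NULL entry becomes JSON null.
    let v = pg_array_parse("{{1,2},{3,NULL}}", &tokio_postgres::types::Type::INT4)?;
    assert_eq!(v, serde_json::json!([[1, 2], [3, null]]));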

@@ -164,7 +150,7 @@ fn _pg_array_parse(
    pg_array: &str,
    elem_type: &Type,
    nested: bool,
) -> Result<(Value, usize), JsonConversionError> {
) -> Result<(Value, usize), anyhow::Error> {
    let mut pg_array_chr = pg_array.char_indices();
    let mut level = 0;
    let mut quote = false;
@@ -184,7 +170,7 @@ fn _pg_array_parse(
        entry: &mut String,
        entries: &mut Vec<Value>,
        elem_type: &Type,
    ) -> Result<(), JsonConversionError> {
    ) -> Result<(), anyhow::Error> {
        if !entry.is_empty() {
            // While in usual postgres response we get nulls as None and everything else
            // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
@@ -248,7 +234,7 @@ fn _pg_array_parse(
    }

    if level != 0 {
        return Err(JsonConversionError::UnbalancedArray);
        return Err(anyhow::anyhow!("unbalanced array"));
    }

    Ok((Value::Array(entries), 0))

@@ -1,6 +1,7 @@
use std::sync::Arc;

use anyhow::bail;
use anyhow::Context;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
@@ -12,7 +13,6 @@ use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use serde_json::json;
use serde_json::Value;
use tokio::join;
use tokio_postgres::error::DbError;
use tokio_postgres::error::ErrorPosition;
use tokio_postgres::GenericClient;
@@ -20,7 +20,6 @@ use tokio_postgres::IsolationLevel;
use tokio_postgres::ReadyForQueryStatus;
use tokio_postgres::Transaction;
use tracing::error;
use tracing::info;
use tracing::instrument;
use url::Url;
use utils::http::error::ApiError;
@@ -28,31 +27,22 @@ use utils::http::json::json_response;

use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::ProxyConfig;
use crate::config::HttpConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::HTTP_CONTENT_LENGTH;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::DbName;
use crate::RoleName;

use super::backend::PoolingBackend;
use super::conn_pool::ConnInfo;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::conn_pool::GlobalConnPool;
use super::json::{json_to_pg_text, pg_text_row_to_json};
use super::SERVERLESS_DRIVER_SNI;

#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
    query: String,
    #[serde(deserialize_with = "bytes_to_pg_text")]
    params: Vec<Option<String>>,
    #[serde(default)]
    array_mode: Option<bool>,
    params: Vec<serde_json::Value>,
}

#[derive(serde::Deserialize)]
@@ -79,86 +69,67 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab

static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");

fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    // TODO: consider avoiding the allocation here.
    let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
    Ok(json_to_pg_text(json))
}
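
In effect, on the release-48 side of this diff the params conversion happens during deserialization, so the request body already holds Vec<Option<String>> by the time the handler sees it. A minimal sketch (the body literal is illustrative, not taken from the diff):

    // Sketch only: serde_json drives bytes_to_pg_text via `deserialize_with`,
    // and `rename_all = "camelCase"` maps array_mode to "arrayMode".
    let body = br#"{"query":"select $1::int","params":[1],"arrayMode":true}"#;
    let q: QueryData = serde_json::from_slice(body)?;
    assert_eq!(q.params, vec![Some("1".to_string())]);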

#[derive(Debug, thiserror::Error)]
pub enum ConnInfoError {
    #[error("invalid header: {0}")]
    InvalidHeader(&'static str),
    #[error("invalid connection string: {0}")]
    UrlParseError(#[from] url::ParseError),
    #[error("incorrect scheme")]
    IncorrectScheme,
    #[error("missing database name")]
    MissingDbName,
    #[error("invalid database name")]
    InvalidDbName,
    #[error("missing username")]
    MissingUsername,
    #[error("invalid username: {0}")]
    InvalidUsername(#[from] std::string::FromUtf8Error),
    #[error("missing password")]
    MissingPassword,
    #[error("missing hostname")]
    MissingHostname,
    #[error("invalid hostname: {0}")]
    InvalidEndpoint(#[from] ComputeUserInfoParseError),
    #[error("malformed endpoint")]
    MalformedEndpoint,
}

fn get_conn_info(
    ctx: &mut RequestMonitoring,
    headers: &HeaderMap,
    sni_hostname: Option<String>,
    tls: &TlsConfig,
) -> Result<ConnInfo, ConnInfoError> {
    // HTTP only uses cleartext (for now and likely always)
    ctx.set_auth_method(crate::context::AuthMethod::Cleartext);

) -> Result<ConnInfo, anyhow::Error> {
    let connection_string = headers
        .get("Neon-Connection-String")
        .ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
        .to_str()
        .map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
        .ok_or(anyhow::anyhow!("missing connection string"))?
        .to_str()?;

    let connection_url = Url::parse(connection_string)?;

    let protocol = connection_url.scheme();
    if protocol != "postgres" && protocol != "postgresql" {
        return Err(ConnInfoError::IncorrectScheme);
        return Err(anyhow::anyhow!(
            "connection string must start with postgres: or postgresql:"
        ));
    }

    let mut url_path = connection_url
        .path_segments()
        .ok_or(ConnInfoError::MissingDbName)?;
        .ok_or(anyhow::anyhow!("missing database name"))?;

    let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into();
    ctx.set_dbname(dbname.clone());
    let dbname = url_path
        .next()
        .ok_or(anyhow::anyhow!("invalid database name"))?;

    let username = RoleName::from(urlencoding::decode(connection_url.username())?);
    let username = RoleName::from(connection_url.username());
    if username.is_empty() {
        return Err(ConnInfoError::MissingUsername);
        return Err(anyhow::anyhow!("missing username"));
    }
    ctx.set_user(username.clone());

    let password = connection_url
        .password()
        .ok_or(ConnInfoError::MissingPassword)?;
    let password = urlencoding::decode_binary(password.as_bytes());
        .ok_or(anyhow::anyhow!("no password"))?;

    // TLS certificate selector now based on SNI hostname, so if we are running here
    // we are sure that SNI hostname is set to one of the configured domain names.
    let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;

    let hostname = connection_url
        .host_str()
        .ok_or(ConnInfoError::MissingHostname)?;
        .ok_or(anyhow::anyhow!("no host"))?;

    let endpoint =
        endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?;
    let host_header = headers
        .get("host")
        .and_then(|h| h.to_str().ok())
        .and_then(|h| h.split(':').next());

    // sni_hostname has to be either the same as hostname or the one used in the serverless driver.
    if !check_matches(&sni_hostname, hostname)? {
        return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
    } else if let Some(h) = host_header {
        if h != sni_hostname {
            return Err(anyhow::anyhow!("mismatched host header and hostname"));
        }
    }

    let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
    ctx.set_endpoint_id(endpoint.clone());

    let pairs = connection_url.query_pairs();
@@ -180,35 +151,42 @@ fn get_conn_info(

    Ok(ConnInfo {
        user_info,
        dbname,
        password: match password {
            std::borrow::Cow::Borrowed(b) => b.into(),
            std::borrow::Cow::Owned(b) => b.into(),
        },
        dbname: dbname.into(),
        password: password.into(),
    })
}
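
As a worked example of the parsing above (the endpoint host is hypothetical), a connection string such as postgres://alice:secret@ep-example-123456.us-east-2.aws.neon.tech/neondb splits into role, password, host, and database like this:

    // Sketch only, using the url crate as get_conn_info does.
    let url = Url::parse("postgres://alice:secret@ep-example-123456.us-east-2.aws.neon.tech/neondb")?;
    assert_eq!(url.username(), "alice");
    assert_eq!(url.password(), Some("secret"));
    assert_eq!(url.path_segments().unwrap().next(), Some("neondb"));
    // The endpoint id is then recovered from the hostname via endpoint_sni().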

fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Error> {
    if sni_hostname == hostname {
        return Ok(true);
    }
    let (sni_hostname_first, sni_hostname_rest) = sni_hostname
        .split_once('.')
        .ok_or_else(|| anyhow::anyhow!("Unexpected sni format."))?;
    let (_, hostname_rest) = hostname
        .split_once('.')
        .ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
    Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
}
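
In other words, the SNI name may differ from the connection-string hostname only in its first label, and then only when that label is the serverless driver marker. Assuming SERVERLESS_DRIVER_SNI is something like "api" (its value is not shown in this diff), the behaviour sketches as:

    // Sketch only; hostnames are hypothetical.
    assert!(check_matches("ep-a.us-east-2.aws.neon.tech", "ep-a.us-east-2.aws.neon.tech")?); // exact match
    assert!(check_matches("api.us-east-2.aws.neon.tech", "ep-a.us-east-2.aws.neon.tech")?);  // driver SNI
    assert!(!check_matches("ep-b.us-east-2.aws.neon.tech", "ep-a.us-east-2.aws.neon.tech")?); // rejected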

// TODO: return different http error codes
pub async fn handle(
    config: &'static ProxyConfig,
    mut ctx: RequestMonitoring,
    tls: &'static TlsConfig,
    config: &'static HttpConfig,
    ctx: &mut RequestMonitoring,
    request: Request<Body>,
    backend: Arc<PoolingBackend>,
    sni_hostname: Option<String>,
    conn_pool: Arc<GlobalConnPool>,
) -> Result<Response<Body>, ApiError> {
    let result = tokio::time::timeout(
        config.http_config.request_timeout,
        handle_inner(config, &mut ctx, request, backend),
        config.request_timeout,
        handle_inner(tls, config, ctx, request, sni_hostname, conn_pool),
    )
    .await;
    let mut response = match result {
        Ok(r) => match r {
            Ok(r) => {
                ctx.set_success();
                r
            }
            Ok(r) => r,
            Err(e) => {
                // TODO: ctx.set_error_kind(e.get_error_type());

                let mut message = format!("{:?}", e);
                let db_error = e
                    .downcast_ref::<tokio_postgres::Error>()
@@ -284,12 +262,10 @@ pub async fn handle(
                )?
            }
        },
        Err(e) => {
            ctx.set_error_kind(e.get_error_kind());

        Err(_) => {
            let message = format!(
                "HTTP-Connection timed out, execution time exceeded {} seconds",
                config.http_config.request_timeout.as_secs()
                config.request_timeout.as_secs()
            );
            error!(message);
            json_response(
@@ -298,7 +274,6 @@ pub async fn handle(
            )?
        }
    };

    response.headers_mut().insert(
        "Access-Control-Allow-Origin",
        hyper::http::HeaderValue::from_static("*"),
@@ -306,49 +281,34 @@ pub async fn handle(
    Ok(response)
}
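
For reference, a client exercises this endpoint by POSTing JSON to /sql with the connection string in a header. A hedged reqwest sketch (requires reqwest's "json" feature; only the Neon-Connection-String header name is confirmed by this diff, and the endpoint host is hypothetical):

    // Sketch only: the response body carries {"results": ...} on success.
    let resp = reqwest::Client::new()
        .post("https://ep-example-123456.us-east-2.aws.neon.tech/sql")
        .header(
            "Neon-Connection-String",
            "postgres://alice:secret@ep-example-123456.us-east-2.aws.neon.tech/neondb",
        )
        .json(&serde_json::json!({
            "query": "select $1::int as one",
            "params": [1],
        }))
        .send()
        .await?;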

#[instrument(
    name = "sql-over-http",
    skip_all,
    fields(
        pid = tracing::field::Empty,
        conn_id = tracing::field::Empty
    )
)]
#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
async fn handle_inner(
    config: &'static ProxyConfig,
    tls: &'static TlsConfig,
    config: &'static HttpConfig,
    ctx: &mut RequestMonitoring,
    request: Request<Body>,
    backend: Arc<PoolingBackend>,
    sni_hostname: Option<String>,
    conn_pool: Arc<GlobalConnPool>,
) -> anyhow::Result<Response<Body>> {
    let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
        .with_label_values(&[ctx.protocol])
        .with_label_values(&["http"])
        .guard();
    info!(
        protocol = ctx.protocol,
        "handling interactive connection from client"
    );

    //
    // Determine the destination and connection params
    //
    let headers = request.headers();
    // TLS config should be there.
    let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?;
    info!(
        user = conn_info.user_info.user.as_str(),
        project = conn_info.user_info.endpoint.as_str(),
        "credentials"
    );
    let conn_info = get_conn_info(ctx, headers, sni_hostname, tls)?;

    // Determine the output options. Default behaviour is 'false'. Anything that is not
    // strictly 'true' is assumed to be false.
    let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
    let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
    let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);

    // Allow connection pooling only if explicitly requested
    // or if we have decided that http pool is no longer opt-in
    let allow_pool = !config.http_config.pool_options.opt_in
        || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
    let allow_pool =
        !config.pool_options.opt_in || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);

    // isolation level, read only and deferrable

@@ -367,12 +327,12 @@ async fn handle_inner(
    let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
    let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);

    let paused = ctx.latency_timer.pause();
    let request_content_length = match request.body().size_hint().upper() {
        Some(v) => v,
        None => MAX_REQUEST_SIZE + 1,
    };
    info!(request_content_length, "request size in bytes");
    HTTP_CONTENT_LENGTH.observe(request_content_length as f64);
    drop(paused);

    // we don't have streaming request support yet so this is to prevent OOM
    // from a malicious user sending an extremely large request body
@@ -382,33 +342,13 @@ async fn handle_inner(
        ));
    }

    let fetch_and_process_request = async {
        let body = hyper::body::to_bytes(request.into_body())
            .await
            .map_err(anyhow::Error::from)?;
        info!(length = body.len(), "request payload read");
        let payload: Payload = serde_json::from_slice(&body)?;
        Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
    };
    //
    // Read the query and query params from the request body
    //
    let body = hyper::body::to_bytes(request.into_body()).await?;
    let payload: Payload = serde_json::from_slice(&body)?;

    let authenticate_and_connect = async {
        let keys = backend.authenticate(ctx, &conn_info).await?;
        let client = backend
            .connect_to_compute(ctx, conn_info, keys, !allow_pool)
            .await?;
        // not strictly necessary to mark success here,
        // but it's just insurance for if we forget it somewhere else
        ctx.latency_timer.success();
        Ok::<_, HttpConnError>(client)
    };

    // Run both operations in parallel
    let (payload_result, auth_and_connect_result) =
        join!(fetch_and_process_request, authenticate_and_connect,);

    // Handle the results
    let payload = payload_result?; // Handle errors appropriately
    let mut client = auth_and_connect_result?; // Handle errors appropriately
    let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?;

    let mut response = Response::builder()
        .status(StatusCode::OK)
@@ -418,91 +358,86 @@ async fn handle_inner(
    // Now execute the query and return the result
    //
    let mut size = 0;
    let result = match payload {
        Payload::Single(stmt) => {
            let (status, results) =
                query_to_json(&*client, stmt, &mut 0, raw_output, default_array_mode)
                    .await
                    .map_err(|e| {
                        client.discard();
                        e
                    })?;
            client.check_idle(status);
            results
        }
        Payload::Batch(statements) => {
            info!("starting transaction");
            let (inner, mut discard) = client.inner();
            let mut builder = inner.build_transaction();
            if let Some(isolation_level) = txn_isolation_level {
                builder = builder.isolation_level(isolation_level);
    let result =
        match payload {
            Payload::Single(stmt) => {
                let (status, results) =
                    query_to_json(&*client, stmt, &mut 0, raw_output, array_mode)
                        .await
                        .map_err(|e| {
                            client.discard();
                            e
                        })?;
                client.check_idle(status);
                results
            }
            if txn_read_only {
                builder = builder.read_only(true);
            }
            if txn_deferrable {
                builder = builder.deferrable(true);
            }

            let transaction = builder.start().await.map_err(|e| {
                // if we cannot start a transaction, we should return immediately
                // and not return to the pool. connection is clearly broken
                discard.discard();
                e
            })?;

            let results = match query_batch(
                &transaction,
                statements,
                &mut size,
                raw_output,
                default_array_mode,
            )
            .await
            {
                Ok(results) => {
                    info!("commit");
                    let status = transaction.commit().await.map_err(|e| {
                        // if we cannot commit - for now don't return connection to pool
                        // TODO: get a query status from the error
                        discard.discard();
                        e
                    })?;
                    discard.check_idle(status);
                    results
            Payload::Batch(statements) => {
                let (inner, mut discard) = client.inner();
                let mut builder = inner.build_transaction();
                if let Some(isolation_level) = txn_isolation_level {
                    builder = builder.isolation_level(isolation_level);
                }
                Err(err) => {
                    info!("rollback");
                    let status = transaction.rollback().await.map_err(|e| {
                        // if we cannot rollback - for now don't return connection to pool
                        // TODO: get a query status from the error
                        discard.discard();
                        e
                    })?;
                    discard.check_idle(status);
                    return Err(err);
                if txn_read_only {
                    builder = builder.read_only(true);
                }
                if txn_deferrable {
                    builder = builder.deferrable(true);
                }
            };

            if txn_read_only {
                response = response.header(
                    TXN_READ_ONLY.clone(),
                    HeaderValue::try_from(txn_read_only.to_string())?,
                );
            }
            if txn_deferrable {
                response = response.header(
                    TXN_DEFERRABLE.clone(),
                    HeaderValue::try_from(txn_deferrable.to_string())?,
                );
            }
            if let Some(txn_isolation_level) = txn_isolation_level_raw {
                response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
            }
            json!({ "results": results })
        }
    };
                let transaction = builder.start().await.map_err(|e| {
                    // if we cannot start a transaction, we should return immediately
                    // and not return to the pool. connection is clearly broken
                    discard.discard();
                    e
                })?;

                let results =
                    match query_batch(&transaction, statements, &mut size, raw_output, array_mode)
                        .await
                    {
                        Ok(results) => {
                            let status = transaction.commit().await.map_err(|e| {
                                // if we cannot commit - for now don't return connection to pool
                                // TODO: get a query status from the error
                                discard.discard();
                                e
                            })?;
                            discard.check_idle(status);
                            results
                        }
                        Err(err) => {
                            let status = transaction.rollback().await.map_err(|e| {
                                // if we cannot rollback - for now don't return connection to pool
                                // TODO: get a query status from the error
                                discard.discard();
                                e
                            })?;
                            discard.check_idle(status);
                            return Err(err);
                        }
                    };

                if txn_read_only {
                    response = response.header(
                        TXN_READ_ONLY.clone(),
                        HeaderValue::try_from(txn_read_only.to_string())?,
                    );
                }
                if txn_deferrable {
                    response = response.header(
                        TXN_DEFERRABLE.clone(),
                        HeaderValue::try_from(txn_deferrable.to_string())?,
                    );
                }
                if let Some(txn_isolation_level) = txn_isolation_level_raw {
                    response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
                }
                json!({ "results": results })
            }
        };

    ctx.set_success();
    ctx.log();
    let metrics = client.metrics();

    // how could this possibly fail
@@ -545,12 +480,10 @@ async fn query_to_json<T: GenericClient>(
    data: QueryData,
    current_size: &mut usize,
    raw_output: bool,
    default_array_mode: bool,
    array_mode: bool,
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
    info!("executing query");
    let query_params = data.params;
    let query_params = json_to_pg_text(data.params);
    let row_stream = client.query_raw_txt(&data.query, query_params).await?;
    info!("finished executing query");

    // Manually drain the stream into a vector to leave row_stream hanging
    // around to get a command tag. Also check that the response is not too
@@ -585,13 +518,6 @@ async fn query_to_json<T: GenericClient>(
        }
        .and_then(|s| s.parse::<i64>().ok());

    info!(
        rows = rows.len(),
        ?ready,
        command_tag,
        "finished reading rows"
    );

    let mut fields = vec![];
    let mut columns = vec![];

@@ -608,8 +534,6 @@ async fn query_to_json<T: GenericClient>(
        columns.push(client.get_type(c.type_oid()).await?);
    }

    let array_mode = data.array_mode.unwrap_or(default_array_mode);

    // convert rows to JSON
    let rows = rows
        .iter()

@@ -1,8 +1,8 @@
use crate::{
    cancellation::CancellationHandler,
    cancellation::CancelMap,
    config::ProxyConfig,
    context::RequestMonitoring,
    error::{io_error, ReportableError},
    error::io_error,
    proxy::{handle_client, ClientMode},
    rate_limiter::EndpointRateLimiter,
};
@@ -131,41 +131,23 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {

pub async fn serve_websocket(
    config: &'static ProxyConfig,
    mut ctx: RequestMonitoring,
    ctx: &mut RequestMonitoring,
    websocket: HyperWebsocket,
    cancellation_handler: Arc<CancellationHandler>,
    cancel_map: Arc<CancelMap>,
    hostname: Option<String>,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
    let websocket = websocket.await?;
    let res = handle_client(
    handle_client(
        config,
        &mut ctx,
        cancellation_handler,
        ctx,
        cancel_map,
        WebSocketRw::new(websocket),
        ClientMode::Websockets { hostname },
        endpoint_rate_limiter,
    )
    .await;

    match res {
        Err(e) => {
            // todo: log and push to ctx the error kind
            ctx.set_error_kind(e.get_error_kind());
            ctx.log();
            Err(e.into())
        }
        Ok(None) => {
            ctx.set_success();
            ctx.log();
            Ok(())
        }
        Ok(Some(p)) => {
            ctx.set_success();
            ctx.log();
            p.proxy_pass().await
        }
    }
    .await?;
    Ok(())
}

#[cfg(test)]

@@ -1,5 +1,6 @@
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::error::UserFacingError;
use anyhow::bail;
use bytes::BytesMut;

use pq_proto::framed::{ConnectionError, Framed};
@@ -72,30 +73,6 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
    }
}

#[derive(Debug)]
pub struct ReportedError {
    source: anyhow::Error,
    error_kind: ErrorKind,
}

impl std::fmt::Display for ReportedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.source.fmt(f)
    }
}

impl std::error::Error for ReportedError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        self.source.source()
    }
}

impl ReportableError for ReportedError {
    fn get_error_kind(&self) -> ErrorKind {
        self.error_kind
    }
}

impl<S: AsyncWrite + Unpin> PqStream<S> {
    /// Write the message into an internal buffer, but don't flush the underlying stream.
    pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
@@ -121,52 +98,24 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
    /// Write the error message using [`Self::write_message`], then re-throw it.
    /// Allowing string literals is safe under the assumption they might not contain any runtime info.
    /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
    pub async fn throw_error_str<T>(
        &mut self,
        msg: &'static str,
        error_kind: ErrorKind,
    ) -> Result<T, ReportedError> {
        tracing::info!(
            kind = error_kind.to_metric_label(),
            msg,
            "forwarding error to user"
        );

        // already error case, ignore client IO error
        let _: Result<_, std::io::Error> = self
            .write_message(&BeMessage::ErrorResponse(msg, None))
            .await;

        Err(ReportedError {
            source: anyhow::anyhow!(msg),
            error_kind,
        })
    pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
        tracing::info!("forwarding error to user: {error}");
        self.write_message(&BeMessage::ErrorResponse(error, None))
            .await?;
        bail!(error)
    }

    /// Write the error message using [`Self::write_message`], then re-throw it.
    /// Trait [`UserFacingError`] acts as an allowlist for error types.
    pub async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
    pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
    where
        E: UserFacingError + Into<anyhow::Error>,
    {
        let error_kind = error.get_error_kind();
        let msg = error.to_string_client();
        tracing::info!(
            kind = error_kind.to_metric_label(),
            error = %error,
            msg,
            "forwarding error to user"
        );

        // already error case, ignore client IO error
        let _: Result<_, std::io::Error> = self
            .write_message(&BeMessage::ErrorResponse(&msg, None))
            .await;

        Err(ReportedError {
            source: anyhow::anyhow!(error),
            error_kind,
        })
        tracing::info!("forwarding error to user: {msg}");
        self.write_message(&BeMessage::ErrorResponse(&msg, None))
            .await?;
        bail!(error)
    }
}
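
The practical difference between the two sides of this hunk: on release-48, callers pass an ErrorKind so the failure is classified for metrics before being re-thrown as a ReportedError, while on this branch the methods simply bail with anyhow. A hedged call-site sketch for the release-48 signature (the message text is illustrative):

    // Sketch only: writes the message to the client, logs it with its kind
    // label, then always returns Err(ReportedError { .. }).
    return stream
        .throw_error_str("password authentication failed", ErrorKind::User)
        .await;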

@@ -1,5 +1,5 @@
[toolchain]
channel = "1.76.0"
channel = "1.75.0"
profile = "default"
# The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy.
# https://rust-lang.github.io/rustup/concepts/profiles.html

@@ -10,7 +10,6 @@ use utils::id::NodeId;

use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::num::NonZeroU32;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
@@ -547,10 +546,6 @@ pub async fn delete_timeline(ttid: &TenantTimelineId) -> Result<()> {
    let ttid_path = Utf8Path::new(&ttid.tenant_id.to_string()).join(ttid.timeline_id.to_string());
    let remote_path = RemotePath::new(&ttid_path)?;

    // see DEFAULT_MAX_KEYS_PER_LIST_RESPONSE
    // const Option unwrap is not stable, otherwise it would be const.
    let batch_size: NonZeroU32 = NonZeroU32::new(1000).unwrap();

    // A backoff::retry is used here for two reasons:
    // - To provide a backoff rather than busy-polling the API on errors
    // - To absorb transient 429/503 conditions without hitting our error
@@ -562,26 +557,8 @@ pub async fn delete_timeline(ttid: &TenantTimelineId) -> Result<()> {
    let token = CancellationToken::new(); // not really used
    backoff::retry(
        || async {
            // Do list-delete in batch_size batches to make progress even if there are a lot of files.
            // Alternatively we could make list_files return an iterator, but it is more complicated and
            // I'm not sure deleting while iterating is expected in s3.
            loop {
                let files = storage
                    .list_files(Some(&remote_path), Some(batch_size))
                    .await?;
                if files.is_empty() {
                    return Ok(()); // done
                }
                // (at least) s3 results are sorted, so we can log min/max:
                // "List results are always returned in UTF-8 binary order."
                info!(
                    "deleting batch of {} WAL segments [{}-{}]",
                    files.len(),
                    files.first().unwrap().object_name().unwrap_or(""),
                    files.last().unwrap().object_name().unwrap_or("")
                );
                storage.delete_objects(&files).await?;
            }
            let files = storage.list_files(Some(&remote_path)).await?;
            storage.delete_objects(&files).await
        },
        |_| false,
        3,
@@ -617,7 +594,7 @@ pub async fn copy_s3_segments(

    let remote_path = RemotePath::new(&relative_dst_path)?;

    let files = storage.list_files(Some(&remote_path), None).await?;
    let files = storage.list_files(Some(&remote_path)).await?;
    let uploaded_segments = &files
        .iter()
        .filter_map(|file| file.object_name().map(ToOwned::to_owned))

@@ -23,7 +23,7 @@ from itertools import chain, product
from pathlib import Path
from types import TracebackType
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
from urllib.parse import quote, urlparse
from urllib.parse import urlparse

import asyncpg
import backoff
@@ -1400,6 +1400,7 @@ class AbstractNeonCli(abc.ABC):

        args = [bin_neon] + arguments
        log.info('Running command "{}"'.format(" ".join(args)))
        log.info(f'Running in "{self.env.repo_dir}"')

        env_vars = os.environ.copy()
        env_vars["NEON_REPO_DIR"] = str(self.env.repo_dir)
@@ -1815,7 +1816,6 @@ class NeonCli(AbstractNeonCli):
        endpoint_id: str,
        destroy=False,
        check_return_code=True,
        mode: Optional[str] = None,
    ) -> "subprocess.CompletedProcess[str]":
        args = [
            "endpoint",
@@ -1823,8 +1823,6 @@ class NeonCli(AbstractNeonCli):
        ]
        if destroy:
            args.append("--destroy")
        if mode is not None:
            args.append(f"--mode={mode}")
        if endpoint_id is not None:
            args.append(endpoint_id)

@@ -1951,15 +1949,6 @@ class NeonAttachmentService:

        return headers

    def ready(self) -> bool:
        resp = self.request("GET", f"{self.env.attachment_service_api}/ready")
        if resp.status_code == 503:
            return False
        elif resp.status_code == 200:
            return True
        else:
            raise RuntimeError(f"Unexpected status {resp.status_code} from readiness endpoint")
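
Since ready() just maps 200/503 onto a bool, a test would typically poll it until the attachment service comes up. A hedged sketch (assumes `time` is imported; the helper name is invented for illustration):

    def wait_attachment_service_ready(svc: NeonAttachmentService, timeout_s: float = 10.0):
        # Poll the readiness endpoint until it reports 200 or we give up.
        deadline = time.time() + timeout_s
        while time.time() < deadline:
            if svc.ready():
                return
            time.sleep(0.5)
        raise RuntimeError("attachment service did not become ready")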

    def attach_hook_issue(
        self, tenant_shard_id: Union[TenantId, TenantShardId], pageserver_id: int
    ) -> int:
@@ -2458,7 +2447,6 @@ def pg_bin(test_output_dir: Path, pg_distrib_dir: Path, pg_version: PgVersion) -
    return PgBin(test_output_dir, pg_distrib_dir, pg_version)


# TODO make port an optional argument
class VanillaPostgres(PgProtocol):
    def __init__(self, pgdatadir: Path, pg_bin: PgBin, port: int, init: bool = True):
        super().__init__(host="localhost", port=port, dbname="postgres")
@@ -2822,8 +2810,8 @@ class NeonProxy(PgProtocol):

    def http_query(self, query, args, **kwargs):
        # TODO maybe use default values if not provided
        user = quote(kwargs["user"])
        password = quote(kwargs["password"])
        user = kwargs["user"]
        password = kwargs["password"]
        expected_code = kwargs.get("expected_code")

        connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres"
@@ -3165,7 +3153,7 @@ class Endpoint(PgProtocol):
        with open(remote_extensions_spec_path, "w") as file:
            json.dump(spec, file, indent=4)

    def stop(self, mode: str = "fast") -> "Endpoint":
    def stop(self) -> "Endpoint":
        """
        Stop the Postgres instance if it's running.
        Returns self.
@@ -3174,13 +3162,13 @@ class Endpoint(PgProtocol):
        if self.running:
            assert self.endpoint_id is not None
            self.env.neon_cli.endpoint_stop(
                self.endpoint_id, check_return_code=self.check_stop_result, mode=mode
                self.endpoint_id, check_return_code=self.check_stop_result
            )
            self.running = False

        return self

    def stop_and_destroy(self, mode: str = "immediate") -> "Endpoint":
    def stop_and_destroy(self) -> "Endpoint":
        """
        Stop the Postgres instance, then destroy the endpoint.
        Returns self.
@@ -3188,7 +3176,7 @@ class Endpoint(PgProtocol):

        assert self.endpoint_id is not None
        self.env.neon_cli.endpoint_stop(
            self.endpoint_id, True, check_return_code=self.check_stop_result, mode=mode
            self.endpoint_id, True, check_return_code=self.check_stop_result
        )
        self.endpoint_id = None
        self.running = False
@@ -3967,27 +3955,24 @@ def list_files_to_compare(pgdata_dir: Path) -> List[str]:

# pg is the existing and running compute node that we want to compare with a basebackup
def check_restored_datadir_content(test_output_dir: Path, env: NeonEnv, endpoint: Endpoint):
    pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version)

    # Get the timeline ID. We need it for the 'basebackup' command
    timeline_id = TimelineId(endpoint.safe_psql("SHOW neon.timeline_id")[0][0])

    # many tests already checkpoint, but do it just in case
    with closing(endpoint.connect()) as conn:
        with conn.cursor() as cur:
            cur.execute("CHECKPOINT")

    # wait for pageserver to catch up
    wait_for_last_flush_lsn(env, endpoint, endpoint.tenant_id, timeline_id)
    # stop postgres to ensure that files won't change
    endpoint.stop()

    # Read the shutdown checkpoint's LSN
    pg_controldata_path = os.path.join(pg_bin.pg_bin_path, "pg_controldata")
    cmd = f"{pg_controldata_path} -D {endpoint.pgdata_dir}"
    result = subprocess.run(cmd, capture_output=True, text=True, shell=True)
    checkpoint_lsn = re.findall(
        "Latest checkpoint location:\\s+([0-9A-F]+/[0-9A-F]+)", result.stdout
    )[0]
    log.debug(f"last checkpoint at {checkpoint_lsn}")

    # Take a basebackup from pageserver
    restored_dir_path = env.repo_dir / f"{endpoint.endpoint_id}_restored_datadir"
    restored_dir_path.mkdir(exist_ok=True)

    pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version)
    psql_path = os.path.join(pg_bin.pg_bin_path, "psql")

    pageserver_id = env.attachment_service.locate(endpoint.tenant_id)[0]["node_id"]
@@ -3995,7 +3980,7 @@ def check_restored_datadir_content(test_output_dir: Path, env: NeonEnv, endpoint
        {psql_path} \
            --no-psqlrc \
            postgres://localhost:{env.get_pageserver(pageserver_id).service_port.pg} \
            -c 'basebackup {endpoint.tenant_id} {timeline_id} {checkpoint_lsn}' \
            -c 'basebackup {endpoint.tenant_id} {timeline_id}' \
        | tar -x -C {restored_dir_path}
    """

@@ -4069,7 +4054,7 @@ def logical_replication_sync(subscriber: VanillaPostgres, publisher: Endpoint) -


def tenant_get_shards(
    env: NeonEnv, tenant_id: TenantId, pageserver_id: Optional[int] = None
    env: NeonEnv, tenant_id: TenantId, pageserver_id: Optional[int]
) -> list[tuple[TenantShardId, NeonPageserver]]:
    """
    Helper for when you want to talk to one or more pageservers, and the

@@ -45,6 +45,7 @@ def test_ancestor_branch(neon_env_builder: NeonEnvBuilder):
    # Create branch1.
    env.neon_cli.create_branch("branch1", "main", tenant_id=tenant, ancestor_start_lsn=lsn_100)
    endpoint_branch1 = env.endpoints.create_start("branch1", tenant_id=tenant)
    log.info("postgres is running on 'branch1' branch")

    branch1_cur = endpoint_branch1.connect().cursor()
    branch1_timeline = TimelineId(query_scalar(branch1_cur, "SHOW neon.timeline_id"))
@@ -67,6 +68,7 @@ def test_ancestor_branch(neon_env_builder: NeonEnvBuilder):
    # Create branch2.
    env.neon_cli.create_branch("branch2", "branch1", tenant_id=tenant, ancestor_start_lsn=lsn_200)
    endpoint_branch2 = env.endpoints.create_start("branch2", tenant_id=tenant)
    log.info("postgres is running on 'branch2' branch")
    branch2_cur = endpoint_branch2.connect().cursor()

    branch2_timeline = TimelineId(query_scalar(branch2_cur, "SHOW neon.timeline_id"))

Some files were not shown because too many files have changed in this diff