Compare commits


1 Commit

Author: Konstantin Knizhnik
SHA1: 13b1699518
Message: Specify statement_timeout within safe_pgsql command
Date: 2024-02-06 21:51:12 +02:00
167 changed files with 1982 additions and 5477 deletions


@@ -179,12 +179,6 @@ runs:
aws s3 rm "s3://${BUCKET}/${LOCK_FILE}"
fi
- name: Cache poetry deps
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: v2-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }}
- name: Store Allure test stat in the DB (new)
if: ${{ !cancelled() && inputs.store-test-results-into-db == 'true' }}
shell: bash -euxo pipefail {0}


@@ -86,10 +86,11 @@ runs:
fetch-depth: 1
- name: Cache poetry deps
id: cache_poetry
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: v2-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }}
key: v1-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }}
- name: Install Python deps
shell: bash -euxo pipefail {0}


@@ -93,7 +93,6 @@ jobs:
--body-file "body.md" \
--head "${BRANCH}" \
--base "main" \
--label "run-e2e-tests-in-draft" \
--draft
fi


@@ -22,7 +22,7 @@ env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
# A concurrency group that we use for e2e-tests runs, matches `concurrency.group` above with `github.repository` as a prefix
E2E_CONCURRENCY_GROUP: ${{ github.repository }}-e2e-tests-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
E2E_CONCURRENCY_GROUP: ${{ github.repository }}-${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
jobs:
check-permissions:
@@ -112,10 +112,11 @@ jobs:
fetch-depth: 1
- name: Cache poetry deps
id: cache_poetry
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: v2-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }}
key: v1-codestyle-python-deps-${{ hashFiles('poetry.lock') }}
- name: Install Python deps
run: ./scripts/pysync
@@ -692,10 +693,50 @@ jobs:
})
trigger-e2e-tests:
if: ${{ !github.event.pull_request.draft || contains( github.event.pull_request.labels.*.name, 'run-e2e-tests-in-draft') || github.ref_name == 'main' || github.ref_name == 'release' }}
needs: [ check-permissions, promote-images, tag ]
uses: ./.github/workflows/trigger-e2e-tests.yml
secrets: inherit
runs-on: [ self-hosted, gen3, small ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned
options: --init
steps:
- name: Set PR's status to pending and request a remote CI test
run: |
# For pull requests, GH Actions sets the "github.sha" variable to point at a fake merge commit,
# but we need the real sha of the latest commit in the PR's branch for the e2e job,
# so that we can post a job run status update to it later.
COMMIT_SHA=${{ github.event.pull_request.head.sha }}
# For non-PR kinds of runs, the above will produce an empty variable, pick the original sha value for those
COMMIT_SHA=${COMMIT_SHA:-${{ github.sha }}}
REMOTE_REPO="${{ github.repository_owner }}/cloud"
curl -f -X POST \
https://api.github.com/repos/${{ github.repository }}/statuses/$COMMIT_SHA \
-H "Accept: application/vnd.github.v3+json" \
--user "${{ secrets.CI_ACCESS_TOKEN }}" \
--data \
"{
\"state\": \"pending\",
\"context\": \"neon-cloud-e2e\",
\"description\": \"[$REMOTE_REPO] Remote CI job is about to start\"
}"
curl -f -X POST \
https://api.github.com/repos/$REMOTE_REPO/actions/workflows/testing.yml/dispatches \
-H "Accept: application/vnd.github.v3+json" \
--user "${{ secrets.CI_ACCESS_TOKEN }}" \
--data \
"{
\"ref\": \"main\",
\"inputs\": {
\"ci_job_name\": \"neon-cloud-e2e\",
\"commit_hash\": \"$COMMIT_SHA\",
\"remote_repo\": \"${{ github.repository }}\",
\"storage_image_tag\": \"${{ needs.tag.outputs.build-tag }}\",
\"compute_image_tag\": \"${{ needs.tag.outputs.build-tag }}\",
\"concurrency_group\": \"${{ env.E2E_CONCURRENCY_GROUP }}\"
}
}"
neon-image:
needs: [ check-permissions, build-buildtools-image, tag ]
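
The "Set PR's status to pending and request a remote CI test" step above drives the GitHub REST API with two raw curl calls. As a point of reference only, here is a minimal Rust sketch of the first call (creating a "pending" commit status) using reqwest, which is already a workspace dependency; the token, repository and commit SHA are placeholders, and reqwest's "json" feature is assumed:

use reqwest::Client;
use serde_json::json;

// Hedged sketch: create a "pending" commit status, mirroring the first curl call above.
// The token, repository and SHA are placeholder inputs, not values from the workflow.
async fn set_pending_status(token: &str, repo: &str, sha: &str) -> reqwest::Result<()> {
    Client::new()
        .post(format!("https://api.github.com/repos/{repo}/statuses/{sha}"))
        .header("Accept", "application/vnd.github.v3+json")
        .header("User-Agent", "neon-e2e-trigger-sketch") // the GitHub API rejects requests without one
        .bearer_auth(token)
        .json(&json!({
            "state": "pending",
            "context": "neon-cloud-e2e",
            "description": "Remote CI job is about to start",
        }))
        .send()
        .await?
        .error_for_status()?;
    Ok(())
}

The second curl call (the workflow_dispatch request) has the same shape, POSTed to the target repository's /actions/workflows/testing.yml/dispatches endpoint.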


@@ -38,10 +38,11 @@ jobs:
uses: snok/install-poetry@v1
- name: Cache poetry deps
id: cache_poetry
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: v2-${{ runner.os }}-python-deps-ubunutu-latest-${{ hashFiles('poetry.lock') }}
key: v1-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }}
- name: Install Python deps
shell: bash -euxo pipefail {0}


@@ -1,118 +0,0 @@
name: Trigger E2E Tests
on:
pull_request:
types:
- ready_for_review
workflow_call:
defaults:
run:
shell: bash -euxo pipefail {0}
env:
# A concurrency group that we use for e2e-tests runs, matches `concurrency.group` above with `github.repository` as a prefix
E2E_CONCURRENCY_GROUP: ${{ github.repository }}-e2e-tests-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
jobs:
cancel-previous-e2e-tests:
if: github.event_name == 'pull_request'
runs-on: ubuntu-latest
steps:
- name: Cancel previous e2e-tests runs for this PR
env:
GH_TOKEN: ${{ secrets.CI_ACCESS_TOKEN }}
run: |
gh workflow --repo neondatabase/cloud \
run cancel-previous-in-concurrency-group.yml \
--field concurrency_group="${{ env.E2E_CONCURRENCY_GROUP }}"
tag:
runs-on: [ ubuntu-latest ]
outputs:
build-tag: ${{ steps.build-tag.outputs.tag }}
steps:
- name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Get build tag
env:
GH_TOKEN: ${{ secrets.CI_ACCESS_TOKEN }}
CURRENT_BRANCH: ${{ github.head_ref || github.ref_name }}
CURRENT_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
run: |
if [[ "$GITHUB_REF_NAME" == "main" ]]; then
echo "tag=$(git rev-list --count HEAD)" | tee -a $GITHUB_OUTPUT
elif [[ "$GITHUB_REF_NAME" == "release" ]]; then
echo "tag=release-$(git rev-list --count HEAD)" | tee -a $GITHUB_OUTPUT
else
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release'"
BUILD_AND_TEST_RUN_ID=$(gh run list -b $CURRENT_BRANCH -c $CURRENT_SHA -w 'Build and Test' -L 1 --json databaseId --jq '.[].databaseId')
echo "tag=$BUILD_AND_TEST_RUN_ID" | tee -a $GITHUB_OUTPUT
fi
id: build-tag
trigger-e2e-tests:
needs: [ tag ]
runs-on: [ self-hosted, gen3, small ]
env:
TAG: ${{ needs.tag.outputs.build-tag }}
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned
options: --init
steps:
- name: Check if ECR images are present
run: |
for REPO in neon compute-tools compute-node-v14 vm-compute-node-v14 compute-node-v15 vm-compute-node-v15 compute-node-v16 vm-compute-node-v16; do
OUTPUT=$(aws ecr describe-images --repository-name ${REPO} --region eu-central-1 --query "imageDetails[?imageTags[?contains(@, '${TAG}')]]" --output text)
if [ "$OUTPUT" == "" ]; then
echo "$REPO with image tag $TAG not found" >> $GITHUB_OUTPUT
exit 1
fi
done
- name: Set PR's status to pending and request a remote CI test
run: |
# For pull requests, GH Actions sets the "github.sha" variable to point at a fake merge commit,
# but we need the real sha of the latest commit in the PR's branch for the e2e job,
# so that we can post a job run status update to it later.
COMMIT_SHA=${{ github.event.pull_request.head.sha }}
# For non-PR kinds of runs, the above will produce an empty variable, pick the original sha value for those
COMMIT_SHA=${COMMIT_SHA:-${{ github.sha }}}
REMOTE_REPO="${{ github.repository_owner }}/cloud"
curl -f -X POST \
https://api.github.com/repos/${{ github.repository }}/statuses/$COMMIT_SHA \
-H "Accept: application/vnd.github.v3+json" \
--user "${{ secrets.CI_ACCESS_TOKEN }}" \
--data \
"{
\"state\": \"pending\",
\"context\": \"neon-cloud-e2e\",
\"description\": \"[$REMOTE_REPO] Remote CI job is about to start\"
}"
curl -f -X POST \
https://api.github.com/repos/$REMOTE_REPO/actions/workflows/testing.yml/dispatches \
-H "Accept: application/vnd.github.v3+json" \
--user "${{ secrets.CI_ACCESS_TOKEN }}" \
--data \
"{
\"ref\": \"main\",
\"inputs\": {
\"ci_job_name\": \"neon-cloud-e2e\",
\"commit_hash\": \"$COMMIT_SHA\",
\"remote_repo\": \"${{ github.repository }}\",
\"storage_image_tag\": \"${TAG}\",
\"compute_image_tag\": \"${TAG}\",
\"concurrency_group\": \"${{ env.E2E_CONCURRENCY_GROUP }}\"
}
}"


@@ -54,9 +54,6 @@ _An instruction for maintainers_
- If and only if it looks **safe** (i.e. it doesn't contain any malicious code which could expose secrets or harm the CI), then:
- Press the "Approve and run" button in GitHub UI
- Add the `approved-for-ci-run` label to the PR
- Currently, draft PRs skip the e2e tests (only for internal contributors). After the PR is marked 'Ready for review', CI will trigger the e2e tests
- Add the `run-e2e-tests-in-draft` label to run the e2e tests on a draft PR (overrides the behaviour above)
- The `approved-for-ci-run` workflow adds `run-e2e-tests-in-draft` automatically, so that the e2e tests run for external contributors
Repeat all steps after any change to the PR.
- When the changes are ready to be merged, merge the original PR (not the internal one)

Cargo.lock (generated)

@@ -289,7 +289,6 @@ dependencies = [
"pageserver_api",
"pageserver_client",
"postgres_connection",
"r2d2",
"reqwest",
"serde",
"serde_json",
@@ -1329,6 +1328,8 @@ dependencies = [
"clap",
"comfy-table",
"compute_api",
"diesel",
"diesel_migrations",
"futures",
"git-version",
"hex",
@@ -1650,7 +1651,6 @@ dependencies = [
"diesel_derives",
"itoa",
"pq-sys",
"r2d2",
"serde_json",
]
@@ -2247,11 +2247,11 @@ dependencies = [
[[package]]
name = "hashlink"
version = "0.8.4"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7"
checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa"
dependencies = [
"hashbrown 0.14.0",
"hashbrown 0.13.2",
]
[[package]]
@@ -2867,7 +2867,6 @@ dependencies = [
"chrono",
"libc",
"once_cell",
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr",
@@ -3936,7 +3935,6 @@ dependencies = [
"pin-project-lite",
"postgres-protocol",
"rand 0.8.5",
"serde",
"thiserror",
"tokio",
"tracing",
@@ -3986,8 +3984,6 @@ checksum = "b1de8dacb0873f77e6aefc6d71e044761fcc68060290f5b1089fcdf84626bb69"
dependencies = [
"bitflags 1.3.2",
"byteorder",
"chrono",
"flate2",
"hex",
"lazy_static",
"rustix 0.36.16",
@@ -4078,7 +4074,6 @@ dependencies = [
"clap",
"consumption_metrics",
"dashmap",
"env_logger",
"futures",
"git-version",
"hashbrown 0.13.2",
@@ -4126,7 +4121,6 @@ dependencies = [
"serde",
"serde_json",
"sha2",
"smallvec",
"smol_str",
"socket2 0.5.5",
"sync_wrapper",
@@ -4145,7 +4139,6 @@ dependencies = [
"tracing-subscriber",
"tracing-utils",
"url",
"urlencoding",
"utils",
"uuid",
"walkdir",
@@ -4173,17 +4166,6 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r2d2"
version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93"
dependencies = [
"log",
"parking_lot 0.12.1",
"scheduled-thread-pool",
]
[[package]]
name = "rand"
version = "0.7.3"
@@ -4897,15 +4879,6 @@ dependencies = [
"windows-sys 0.42.0",
]
[[package]]
name = "scheduled-thread-pool"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19"
dependencies = [
"parking_lot 0.12.1",
]
[[package]]
name = "scopeguard"
version = "1.1.0"
@@ -5741,7 +5714,7 @@ dependencies = [
[[package]]
name = "tokio-epoll-uring"
version = "0.1.0"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#0e1af4ccddf2f01805cfc9eaefa97ee13c04b52d"
dependencies = [
"futures",
"nix 0.26.4",
@@ -6266,7 +6239,7 @@ dependencies = [
[[package]]
name = "uring-common"
version = "0.1.0"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#0e1af4ccddf2f01805cfc9eaefa97ee13c04b52d"
dependencies = [
"io-uring",
"libc",
@@ -6833,6 +6806,7 @@ dependencies = [
"clap",
"clap_builder",
"crossbeam-utils",
"diesel",
"either",
"fail",
"futures-channel",


@@ -80,7 +80,7 @@ futures-core = "0.3"
futures-util = "0.3"
git-version = "0.3"
hashbrown = "0.13"
hashlink = "0.8.4"
hashlink = "0.8.1"
hdrhistogram = "7.5.2"
hex = "0.4"
hex-literal = "0.4"
@@ -113,7 +113,6 @@ parquet = { version = "49.0.0", default-features = false, features = ["zstd"] }
parquet_derive = "49.0.0"
pbkdf2 = { version = "0.12.1", features = ["simple", "std"] }
pin-project-lite = "0.2"
procfs = "0.14"
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
prost = "0.11"
rand = "0.8"
@@ -171,7 +170,6 @@ tracing-opentelemetry = "0.20.0"
tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
twox-hash = { version = "1.6.3", default-features = false }
url = "2.2"
urlencoding = "2.1"
uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"
webpki-roots = "0.25"


@@ -100,11 +100,6 @@ RUN mkdir -p /data/.neon/ && chown -R neon:neon /data/.neon/ \
-c "listen_pg_addr='0.0.0.0:6400'" \
-c "listen_http_addr='0.0.0.0:9898'"
# When running a binary that links with libpq, default to using our most recent postgres version. Binaries
# that want a particular postgres version will select it explicitly: this is just a default.
ENV LD_LIBRARY_PATH /usr/local/v16/lib
VOLUME ["/data"]
USER neon
EXPOSE 6400


@@ -111,7 +111,7 @@ USER nonroot:nonroot
WORKDIR /home/nonroot
# Python
ENV PYTHON_VERSION=3.9.18 \
ENV PYTHON_VERSION=3.9.2 \
PYENV_ROOT=/home/nonroot/.pyenv \
PATH=/home/nonroot/.pyenv/shims:/home/nonroot/.pyenv/bin:/home/nonroot/.poetry/bin:$PATH
RUN set -e \
@@ -135,7 +135,7 @@ WORKDIR /home/nonroot
# Rust
# Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`)
ENV RUSTC_VERSION=1.76.0
ENV RUSTC_VERSION=1.75.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \


@@ -639,8 +639,8 @@ FROM build-deps AS pg-anon-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ENV PATH "/usr/local/pgsql/bin/:$PATH"
RUN wget https://github.com/neondatabase/postgresql_anonymizer/archive/refs/tags/neon_1.1.1.tar.gz -O pg_anon.tar.gz && \
echo "321ea8d5c1648880aafde850a2c576e4a9e7b9933a34ce272efc839328999fa9 pg_anon.tar.gz" | sha256sum --check && \
RUN wget https://gitlab.com/dalibo/postgresql_anonymizer/-/archive/1.1.0/postgresql_anonymizer-1.1.0.tar.gz -O pg_anon.tar.gz && \
echo "08b09d2ff9b962f96c60db7e6f8e79cf7253eb8772516998fc35ece08633d3ad pg_anon.tar.gz" | sha256sum --check && \
mkdir pg_anon-src && cd pg_anon-src && tar xvzf ../pg_anon.tar.gz --strip-components=1 -C . && \
find /usr/local/pgsql -type f | sed 's|^/usr/local/pgsql/||' > /before.txt &&\
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -809,7 +809,6 @@ COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-semver-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-embedding-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=wal2json-pg-build /usr/local/pgsql /usr/local/pgsql
COPY --from=pg-anon-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) \

NOTICE

@@ -1,5 +1,5 @@
Neon
Copyright 2022 - 2024 Neon Inc.
Copyright 2022 Neon Inc.
The PostgreSQL submodules in vendor/ are licensed under the PostgreSQL license.
See vendor/postgres-vX/COPYRIGHT for details.


@@ -765,12 +765,7 @@ impl ComputeNode {
handle_roles(spec, &mut client)?;
handle_databases(spec, &mut client)?;
handle_role_deletions(spec, connstr.as_str(), &mut client)?;
handle_grants(
spec,
&mut client,
connstr.as_str(),
self.has_feature(ComputeFeature::AnonExtension),
)?;
handle_grants(spec, &mut client, connstr.as_str())?;
handle_extensions(spec, &mut client)?;
handle_extension_neon(&mut client)?;
create_availability_check_data(&mut client)?;
@@ -778,11 +773,12 @@ impl ComputeNode {
// 'Close' connection
drop(client);
// Run migrations separately to not hold up cold starts
thread::spawn(move || {
let mut client = Client::connect(connstr.as_str(), NoTls)?;
handle_migrations(&mut client)
});
if self.has_feature(ComputeFeature::Migrations) {
thread::spawn(move || {
let mut client = Client::connect(connstr.as_str(), NoTls)?;
handle_migrations(&mut client)
});
}
Ok(())
}
@@ -844,12 +840,7 @@ impl ComputeNode {
handle_roles(&spec, &mut client)?;
handle_databases(&spec, &mut client)?;
handle_role_deletions(&spec, self.connstr.as_str(), &mut client)?;
handle_grants(
&spec,
&mut client,
self.connstr.as_str(),
self.has_feature(ComputeFeature::AnonExtension),
)?;
handle_grants(&spec, &mut client, self.connstr.as_str())?;
handle_extensions(&spec, &mut client)?;
handle_extension_neon(&mut client)?;
// We can skip handle_migrations here because a new migration can only appear


@@ -264,10 +264,9 @@ pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> {
// case we miss some events for some reason. Not strictly necessary, but
// better safe than sorry.
let (tx, rx) = std::sync::mpsc::channel();
let watcher_res = notify::recommended_watcher(move |res| {
let (mut watcher, rx): (Box<dyn Watcher>, _) = match notify::recommended_watcher(move |res| {
let _ = tx.send(res);
});
let (mut watcher, rx): (Box<dyn Watcher>, _) = match watcher_res {
}) {
Ok(watcher) => (Box::new(watcher), rx),
Err(e) => {
match e.kind {
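
For context, the pattern being refactored above is: create an mpsc channel first, hand its sender to notify::recommended_watcher so filesystem events are forwarded into the channel, and fall back on the error path when the recommended watcher is unavailable. A minimal, self-contained sketch of the happy path with the notify crate follows; the directory argument and timeout are illustrative assumptions, not values from this function:

use std::{path::Path, sync::mpsc, time::Duration};
use notify::{RecursiveMode, Watcher};

// Hedged sketch: forward filesystem events into an mpsc channel and wait for the first one.
fn wait_for_first_event(dir: &Path) -> notify::Result<()> {
    let (tx, rx) = mpsc::channel();
    let mut watcher = notify::recommended_watcher(move |res| {
        // Ignore send errors: the receiver may have been dropped already.
        let _ = tx.send(res);
    })?;
    watcher.watch(dir, RecursiveMode::NonRecursive)?;
    // Wait for the first event; the timeout is an illustrative choice.
    match rx.recv_timeout(Duration::from_secs(10)) {
        Ok(_res) => Ok(()),
        Err(_timed_out) => Ok(()), // the caller would re-check state instead of failing hard
    }
}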


@@ -581,12 +581,7 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
/// Grant CREATE ON DATABASE to the database owner and do some other alters and grants
/// to allow users to create trusted extensions and re-create the `public` schema, for example.
#[instrument(skip_all)]
pub fn handle_grants(
spec: &ComputeSpec,
client: &mut Client,
connstr: &str,
enable_anon_extension: bool,
) -> Result<()> {
pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) -> Result<()> {
info!("modifying database permissions");
let existing_dbs = get_existing_dbs(client)?;
@@ -683,11 +678,6 @@ pub fn handle_grants(
inlinify(&grant_query)
);
db_client.simple_query(&grant_query)?;
// it is important to run this after all grants
if enable_anon_extension {
handle_extension_anon(spec, &db.owner, &mut db_client, false)?;
}
}
Ok(())
@@ -776,7 +766,6 @@ BEGIN
END IF;
END
$$;"#,
"GRANT pg_monitor TO neon_superuser WITH ADMIN OPTION",
];
let mut query = "CREATE SCHEMA IF NOT EXISTS neon_migration";
@@ -820,125 +809,5 @@ $$;"#,
"Ran {} migrations",
(migrations.len() - starting_migration_id)
);
Ok(())
}
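
Only the tail of handle_migrations is visible above, but the visible lines imply its shape: migrations is an ordered list of SQL statements, the id of the last applied one is tracked under the neon_migration schema, and only the remainder is applied on each run. Below is a rough, heavily hedged sketch of that pattern; the tracking table and column names are assumptions for illustration, not taken from the source:

use postgres::Client;

// Rough sketch of the numbered-migration pattern: apply each pending statement and persist
// the index of the last applied one. Table/column names are illustrative assumptions.
fn run_numbered_migrations(client: &mut Client, migrations: &[&str]) -> anyhow::Result<()> {
    client.simple_query("CREATE SCHEMA IF NOT EXISTS neon_migration")?;
    client.simple_query(
        "CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT PRIMARY KEY, id BIGINT NOT NULL)",
    )?;
    let row = client.query_one(
        "SELECT COALESCE(MAX(id), 0) FROM neon_migration.migration_id",
        &[],
    )?;
    let starting_migration_id = row.get::<_, i64>(0) as usize;
    for (i, stmt) in migrations.iter().enumerate().skip(starting_migration_id) {
        client.simple_query(stmt)?;
        client.execute(
            "INSERT INTO neon_migration.migration_id (key, id) VALUES (0, $1)
             ON CONFLICT (key) DO UPDATE SET id = EXCLUDED.id",
            &[&((i + 1) as i64)],
        )?;
    }
    println!("Ran {} migrations", migrations.len() - starting_migration_id);
    Ok(())
}
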
/// Connect to the database as superuser and pre-create anon extension
/// if it is present in shared_preload_libraries
#[instrument(skip_all)]
pub fn handle_extension_anon(
spec: &ComputeSpec,
db_owner: &str,
db_client: &mut Client,
grants_only: bool,
) -> Result<()> {
info!("handle extension anon");
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
if libs.contains("anon") {
if !grants_only {
// check if extension is already initialized using anon.is_initialized()
let query = "SELECT anon.is_initialized()";
match db_client.query(query, &[]) {
Ok(rows) => {
if !rows.is_empty() {
let is_initialized: bool = rows[0].get(0);
if is_initialized {
info!("anon extension is already initialized");
return Ok(());
}
}
}
Err(e) => {
warn!(
"anon extension is_installed check failed with expected error: {}",
e
);
}
};
// Create anon extension if this compute needs it
// Users cannot create it themselves, because superuser is required.
let mut query = "CREATE EXTENSION IF NOT EXISTS anon CASCADE";
info!("creating anon extension with query: {}", query);
match db_client.query(query, &[]) {
Ok(_) => {}
Err(e) => {
error!("anon extension creation failed with error: {}", e);
return Ok(());
}
}
// check that extension is installed
query = "SELECT extname FROM pg_extension WHERE extname = 'anon'";
let rows = db_client.query(query, &[])?;
if rows.is_empty() {
error!("anon extension is not installed");
return Ok(());
}
// Initialize anon extension
// This also requires superuser privileges, so users cannot do it themselves.
query = "SELECT anon.init()";
match db_client.query(query, &[]) {
Ok(_) => {}
Err(e) => {
error!("anon.init() failed with error: {}", e);
return Ok(());
}
}
}
// check that extension is installed, if not bail early
let query = "SELECT extname FROM pg_extension WHERE extname = 'anon'";
match db_client.query(query, &[]) {
Ok(rows) => {
if rows.is_empty() {
error!("anon extension is not installed");
return Ok(());
}
}
Err(e) => {
error!("anon extension check failed with error: {}", e);
return Ok(());
}
};
let query = format!("GRANT ALL ON SCHEMA anon TO {}", db_owner);
info!("granting anon extension permissions with query: {}", query);
db_client.simple_query(&query)?;
// Grant permissions to db_owner to use anon extension functions
let query = format!("GRANT ALL ON ALL FUNCTIONS IN SCHEMA anon TO {}", db_owner);
info!("granting anon extension permissions with query: {}", query);
db_client.simple_query(&query)?;
// This is needed, because some functions are defined as SECURITY DEFINER.
// In Postgres SECURITY DEFINER functions are executed with the privileges
// of the owner.
// In the anon extension this is needed to access some GUCs, which are only accessible to
// superuser. But we've patched postgres to allow db_owner to access them as well.
// So we need to change owner of these functions to db_owner.
let query = format!("
SELECT 'ALTER FUNCTION '||nsp.nspname||'.'||p.proname||'('||pg_get_function_identity_arguments(p.oid)||') OWNER TO {};'
from pg_proc p
join pg_namespace nsp ON p.pronamespace = nsp.oid
where nsp.nspname = 'anon';", db_owner);
info!("change anon extension functions owner to db owner");
db_client.simple_query(&query)?;
// affects views as well
let query = format!("GRANT ALL ON ALL TABLES IN SCHEMA anon TO {}", db_owner);
info!("granting anon extension permissions with query: {}", query);
db_client.simple_query(&query)?;
let query = format!("GRANT ALL ON ALL SEQUENCES IN SCHEMA anon TO {}", db_owner);
info!("granting anon extension permissions with query: {}", query);
db_client.simple_query(&query)?;
}
}
Ok(())
}


@@ -10,6 +10,8 @@ async-trait.workspace = true
camino.workspace = true
clap.workspace = true
comfy-table.workspace = true
diesel = { version = "2.1.4", features = ["postgres"]}
diesel_migrations = { version = "2.1.0", features = ["postgres"]}
futures.workspace = true
git-version.workspace = true
nix.workspace = true


@@ -24,9 +24,8 @@ tokio.workspace = true
tokio-util.workspace = true
tracing.workspace = true
diesel = { version = "2.1.4", features = ["serde_json", "postgres", "r2d2"] }
diesel = { version = "2.1.4", features = ["serde_json", "postgres"] }
diesel_migrations = { version = "2.1.0" }
r2d2 = { version = "0.8.10" }
utils = { path = "../../libs/utils/" }
metrics = { path = "../../libs/metrics/" }


@@ -7,7 +7,6 @@ CREATE TABLE tenant_shards (
generation INTEGER NOT NULL,
generation_pageserver BIGINT NOT NULL,
placement_policy VARCHAR NOT NULL,
splitting SMALLINT NOT NULL,
-- config is JSON encoded, opaque to the database.
config TEXT NOT NULL
);


@@ -170,7 +170,7 @@ impl ComputeHook {
reconfigure_request: &ComputeHookNotifyRequest,
cancel: &CancellationToken,
) -> Result<(), NotifyError> {
let req = client.request(Method::PUT, url);
let req = client.request(Method::POST, url);
let req = if let Some(value) = &self.authorization_header {
req.header(reqwest::header::AUTHORIZATION, value)
} else {
@@ -240,7 +240,7 @@ impl ComputeHook {
let client = reqwest::Client::new();
backoff::retry(
|| self.do_notify_iteration(&client, url, &reconfigure_request, cancel),
|e| matches!(e, NotifyError::Fatal(_) | NotifyError::Unexpected(_)),
|e| matches!(e, NotifyError::Fatal(_)),
3,
10,
"Send compute notification",


@@ -3,8 +3,7 @@ use crate::service::{Service, STARTUP_RECONCILE_TIMEOUT};
use hyper::{Body, Request, Response};
use hyper::{StatusCode, Uri};
use pageserver_api::models::{
TenantCreateRequest, TenantLocationConfigRequest, TenantShardSplitRequest,
TimelineCreateRequest,
TenantCreateRequest, TenantLocationConfigRequest, TimelineCreateRequest,
};
use pageserver_api::shard::TenantShardId;
use pageserver_client::mgmt_api;
@@ -42,7 +41,7 @@ pub struct HttpState {
impl HttpState {
pub fn new(service: Arc<crate::service::Service>, auth: Option<Arc<SwappableJwtAuth>>) -> Self {
let allowlist_routes = ["/status", "/ready", "/metrics"]
let allowlist_routes = ["/status"]
.iter()
.map(|v| v.parse().unwrap())
.collect::<Vec<_>>();
@@ -280,12 +279,6 @@ async fn handle_node_list(req: Request<Body>) -> Result<Response<Body>, ApiError
json_response(StatusCode::OK, state.service.node_list().await?)
}
async fn handle_node_drop(req: Request<Body>) -> Result<Response<Body>, ApiError> {
let state = get_state(&req);
let node_id: NodeId = parse_request_param(&req, "node_id")?;
json_response(StatusCode::OK, state.service.node_drop(node_id).await?)
}
async fn handle_node_configure(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let node_id: NodeId = parse_request_param(&req, "node_id")?;
let config_req = json_request::<NodeConfigureRequest>(&mut req).await?;
@@ -299,19 +292,6 @@ async fn handle_node_configure(mut req: Request<Body>) -> Result<Response<Body>,
json_response(StatusCode::OK, state.service.node_configure(config_req)?)
}
async fn handle_tenant_shard_split(
service: Arc<Service>,
mut req: Request<Body>,
) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
let split_req = json_request::<TenantShardSplitRequest>(&mut req).await?;
json_response(
StatusCode::OK,
service.tenant_shard_split(tenant_id, split_req).await?,
)
}
async fn handle_tenant_shard_migrate(
service: Arc<Service>,
mut req: Request<Body>,
@@ -326,29 +306,11 @@ async fn handle_tenant_shard_migrate(
)
}
async fn handle_tenant_drop(req: Request<Body>) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
let state = get_state(&req);
json_response(StatusCode::OK, state.service.tenant_drop(tenant_id).await?)
}
/// Status endpoint is just used for checking that our HTTP listener is up
async fn handle_status(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, ())
}
/// Readiness endpoint indicates when we're done doing startup I/O (e.g. reconciling
/// with remote pageserver nodes). This is intended for use as a kubernetes readiness probe.
async fn handle_ready(req: Request<Body>) -> Result<Response<Body>, ApiError> {
let state = get_state(&req);
if state.service.startup_complete.is_ready() {
json_response(StatusCode::OK, ())
} else {
json_response(StatusCode::SERVICE_UNAVAILABLE, ())
}
}
impl From<ReconcileError> for ApiError {
fn from(value: ReconcileError) -> Self {
ApiError::Conflict(format!("Reconciliation error: {}", value))
@@ -404,7 +366,6 @@ pub fn make_router(
.data(Arc::new(HttpState::new(service, auth)))
// Non-prefixed generic endpoints (status, metrics)
.get("/status", |r| request_span(r, handle_status))
.get("/ready", |r| request_span(r, handle_ready))
// Upcalls for the pageserver: point the pageserver's `control_plane_api` config to this prefix
.post("/upcall/v1/re-attach", |r| {
request_span(r, handle_re_attach)
@@ -415,12 +376,6 @@ pub fn make_router(
request_span(r, handle_attach_hook)
})
.post("/debug/v1/inspect", |r| request_span(r, handle_inspect))
.post("/debug/v1/tenant/:tenant_id/drop", |r| {
request_span(r, handle_tenant_drop)
})
.post("/debug/v1/node/:node_id/drop", |r| {
request_span(r, handle_node_drop)
})
.get("/control/v1/tenant/:tenant_id/locate", |r| {
tenant_service_handler(r, handle_tenant_locate)
})
@@ -436,9 +391,6 @@ pub fn make_router(
.put("/control/v1/tenant/:tenant_shard_id/migrate", |r| {
tenant_service_handler(r, handle_tenant_shard_migrate)
})
.put("/control/v1/tenant/:tenant_id/shard_split", |r| {
tenant_service_handler(r, handle_tenant_shard_split)
})
// Tenant operations
// The ^/v1/ endpoints act as a "Virtual Pageserver", enabling shard-naive clients to call into
// this service to manage tenants that actually consist of many tenant shards, as if they are a single entity.


@@ -170,7 +170,6 @@ impl Secrets {
}
}
/// Execute the diesel migrations that are built into this binary
async fn migration_run(database_url: &str) -> anyhow::Result<()> {
use diesel::PgConnection;
use diesel_migrations::{HarnessWithOutput, MigrationHarness};
@@ -184,18 +183,8 @@ async fn migration_run(database_url: &str) -> anyhow::Result<()> {
Ok(())
}
fn main() -> anyhow::Result<()> {
tokio::runtime::Builder::new_current_thread()
// We use spawn_blocking for database operations, so require approximately
// as many blocking threads as we will open database connections.
.max_blocking_threads(Persistence::MAX_CONNECTIONS as usize)
.enable_all()
.build()
.unwrap()
.block_on(async_main())
}
async fn async_main() -> anyhow::Result<()> {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let launch_ts = Box::leak(Box::new(LaunchTimestamp::generate()));
logging::init(
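
migration_run itself is only partially visible here, but its imports (PgConnection, HarnessWithOutput, MigrationHarness) indicate the standard diesel_migrations embedded-migration pattern. A minimal sketch of that pattern follows; the embed_migrations! path and the constant name are assumptions, not taken from this diff:

use diesel::{Connection, PgConnection};
use diesel_migrations::{embed_migrations, EmbeddedMigrations, HarnessWithOutput, MigrationHarness};

// Migrations baked into the binary at compile time; the directory path is an assumption.
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");

fn run_embedded_migrations(database_url: &str) -> anyhow::Result<()> {
    let conn = PgConnection::establish(database_url)?;
    // Wrap the connection so each applied migration is reported on stdout.
    let mut harness = HarnessWithOutput::write_to_stdout(conn);
    harness
        .run_pending_migrations(MIGRATIONS)
        .map_err(|e| anyhow::anyhow!("failed to run diesel migrations: {e}"))?;
    Ok(())
}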


@@ -1,9 +1,6 @@
pub(crate) mod split_state;
use std::collections::HashMap;
use std::str::FromStr;
use std::time::Duration;
use self::split_state::SplitState;
use camino::Utf8Path;
use camino::Utf8PathBuf;
use control_plane::attachment_service::{NodeAvailability, NodeSchedulingPolicy};
@@ -47,7 +44,7 @@ use crate::PlacementPolicy;
/// updated, and reads of nodes are always from memory, not the database. We only require that
/// we can UPDATE a node's scheduling mode reasonably quickly to mark a bad node offline.
pub struct Persistence {
connection_pool: diesel::r2d2::Pool<diesel::r2d2::ConnectionManager<PgConnection>>,
database_url: String,
// In test environments, we support loading+saving a JSON file. This is temporary, for the benefit of
// test_compatibility.py, so that we don't have to commit to making the database contents fully backward/forward
@@ -67,8 +64,6 @@ pub(crate) enum DatabaseError {
Query(#[from] diesel::result::Error),
#[error(transparent)]
Connection(#[from] diesel::result::ConnectionError),
#[error(transparent)]
ConnectionPool(#[from] r2d2::Error),
#[error("Logical error: {0}")]
Logical(String),
}
@@ -76,31 +71,9 @@ pub(crate) enum DatabaseError {
pub(crate) type DatabaseResult<T> = Result<T, DatabaseError>;
impl Persistence {
// The default postgres connection limit is 100. We use up to 99, to leave one free for a human admin under
// normal circumstances. This assumes we have exclusive use of the database cluster to which we connect.
pub const MAX_CONNECTIONS: u32 = 99;
// We don't want to keep a lot of connections alive: close them down promptly if they aren't being used.
const IDLE_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(60);
pub fn new(database_url: String, json_path: Option<Utf8PathBuf>) -> Self {
let manager = diesel::r2d2::ConnectionManager::<PgConnection>::new(database_url);
// We will use a connection pool: this is primarily to _limit_ our connection count, rather than to optimize time
// to execute queries (database queries are not generally on latency-sensitive paths).
let connection_pool = diesel::r2d2::Pool::builder()
.max_size(Self::MAX_CONNECTIONS)
.max_lifetime(Some(Self::MAX_CONNECTION_LIFETIME))
.idle_timeout(Some(Self::IDLE_CONNECTION_TIMEOUT))
// Always keep at least one connection ready to go
.min_idle(Some(1))
.test_on_check_out(true)
.build(manager)
.expect("Could not build connection pool");
Self {
connection_pool,
database_url,
json_path,
}
}
@@ -111,10 +84,14 @@ impl Persistence {
F: Fn(&mut PgConnection) -> DatabaseResult<R> + Send + 'static,
R: Send + 'static,
{
let mut conn = self.connection_pool.get()?;
tokio::task::spawn_blocking(move || -> DatabaseResult<R> { func(&mut conn) })
.await
.expect("Task panic")
let database_url = self.database_url.clone();
tokio::task::spawn_blocking(move || -> DatabaseResult<R> {
// TODO: connection pooling, such as via diesel::r2d2
let mut conn = PgConnection::establish(&database_url)?;
func(&mut conn)
})
.await
.expect("Task panic")
}
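
With the r2d2 pool removed (a database_url string replaces the pool field in the struct above), every query goes through this with_conn helper: establish a fresh PgConnection inside spawn_blocking and run the blocking diesel closure there. A generic, hedged sketch of that shape as a free function, with names and error handling simplified relative to the code above:

use diesel::{Connection, PgConnection};

// Hedged sketch of the with_conn pattern: run a blocking diesel closure on tokio's
// blocking thread pool against a freshly established connection.
async fn with_fresh_conn<F, R>(database_url: &str, func: F) -> anyhow::Result<R>
where
    F: FnOnce(&mut PgConnection) -> anyhow::Result<R> + Send + 'static,
    R: Send + 'static,
{
    let database_url = database_url.to_owned();
    tokio::task::spawn_blocking(move || {
        let mut conn = PgConnection::establish(&database_url)?;
        func(&mut conn)
    })
    .await
    .expect("blocking task panicked")
}

Whether connections come from an r2d2 pool or are established per call (both variants appear in this diff), the closure-on-a-blocking-thread shape stays the same.
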
/// When a node is first registered, persist it before using it for anything
@@ -260,6 +237,7 @@ impl Persistence {
/// Ordering: call this _after_ deleting the tenant on pageservers, but _before_ dropping state for
/// the tenant from memory on this server.
#[allow(unused)]
pub(crate) async fn delete_tenant(&self, del_tenant_id: TenantId) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
@@ -272,18 +250,6 @@ impl Persistence {
.await
}
pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
diesel::delete(nodes)
.filter(node_id.eq(del_node_id.0 as i64))
.execute(conn)?;
Ok(())
})
.await
}
/// When a tenant invokes the /re-attach API, this function is responsible for doing an efficient
/// batched increment of the generations of all tenants whose generation_pageserver is equal to
/// the node that called /re-attach.
@@ -376,107 +342,19 @@ impl Persistence {
Ok(())
}
// When we start shard splitting, we must durably mark the tenant so that
// on restart, we know that we must go through recovery.
//
// We create the child shards here, so that they will be available for increment_generation calls
// if some pageserver holding a child shard needs to restart before the overall tenant split is complete.
// TODO: when we start shard splitting, we must durably mark the tenant so that
// on restart, we know that we must go through recovery (list shards that exist
// and pick up where we left off and/or revert to parent shards).
#[allow(dead_code)]
pub(crate) async fn begin_shard_split(
&self,
old_shard_count: ShardCount,
split_tenant_id: TenantId,
parent_to_children: Vec<(TenantShardId, Vec<TenantShardPersistence>)>,
) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
conn.transaction(|conn| -> DatabaseResult<()> {
// Mark parent shards as splitting
let expect_parent_records = std::cmp::max(1, old_shard_count.0);
let updated = diesel::update(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.filter(shard_count.eq(old_shard_count.0 as i32))
.set((splitting.eq(1),))
.execute(conn)?;
if u8::try_from(updated)
.map_err(|_| DatabaseError::Logical(
format!("Overflow existing shard count {} while splitting", updated))
)? != expect_parent_records {
// Perhaps a deletion or another split raced with this attempt to split, mutating
// the parent shards that we intend to split. In this case the split request should fail.
return Err(DatabaseError::Logical(
format!("Unexpected existing shard count {updated} when preparing tenant for split (expected {expect_parent_records})")
));
}
// FIXME: spurious clone to sidestep closure move rules
let parent_to_children = parent_to_children.clone();
// Insert child shards
for (parent_shard_id, children) in parent_to_children {
let mut parent = crate::schema::tenant_shards::table
.filter(tenant_id.eq(parent_shard_id.tenant_id.to_string()))
.filter(shard_number.eq(parent_shard_id.shard_number.0 as i32))
.filter(shard_count.eq(parent_shard_id.shard_count.0 as i32))
.load::<TenantShardPersistence>(conn)?;
let parent = if parent.len() != 1 {
return Err(DatabaseError::Logical(format!(
"Parent shard {parent_shard_id} not found"
)));
} else {
parent.pop().unwrap()
};
for mut shard in children {
// Carry the parent's generation into the child
shard.generation = parent.generation;
debug_assert!(shard.splitting == SplitState::Splitting);
diesel::insert_into(tenant_shards)
.values(shard)
.execute(conn)?;
}
}
Ok(())
})?;
Ok(())
})
.await
pub(crate) async fn begin_shard_split(&self, _tenant_id: TenantId) -> anyhow::Result<()> {
todo!();
}
// When we finish shard splitting, we must atomically clean up the old shards
// TODO: when we finish shard splitting, we must atomically clean up the old shards
// and insert the new shards, and clear the splitting marker.
#[allow(dead_code)]
pub(crate) async fn complete_shard_split(
&self,
split_tenant_id: TenantId,
old_shard_count: ShardCount,
) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
conn.transaction(|conn| -> QueryResult<()> {
// Drop parent shards
diesel::delete(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.filter(shard_count.eq(old_shard_count.0 as i32))
.execute(conn)?;
// Clear sharding flag
let updated = diesel::update(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.set((splitting.eq(0),))
.execute(conn)?;
debug_assert!(updated > 0);
Ok(())
})?;
Ok(())
})
.await
pub(crate) async fn complete_shard_split(&self, _tenant_id: TenantId) -> anyhow::Result<()> {
todo!();
}
}
@@ -504,8 +382,6 @@ pub(crate) struct TenantShardPersistence {
#[serde(default)]
pub(crate) placement_policy: String,
#[serde(default)]
pub(crate) splitting: SplitState,
#[serde(default)]
pub(crate) config: String,
}


@@ -1,46 +0,0 @@
use diesel::pg::{Pg, PgValue};
use diesel::{
deserialize::FromSql, deserialize::FromSqlRow, expression::AsExpression, serialize::ToSql,
sql_types::Int2,
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, FromSqlRow, AsExpression)]
#[diesel(sql_type = SplitStateSQLRepr)]
#[derive(Deserialize, Serialize)]
pub enum SplitState {
Idle = 0,
Splitting = 1,
}
impl Default for SplitState {
fn default() -> Self {
Self::Idle
}
}
type SplitStateSQLRepr = Int2;
impl ToSql<SplitStateSQLRepr, Pg> for SplitState {
fn to_sql<'a>(
&'a self,
out: &'a mut diesel::serialize::Output<Pg>,
) -> diesel::serialize::Result {
let raw_value: i16 = *self as i16;
let mut new_out = out.reborrow();
ToSql::<SplitStateSQLRepr, Pg>::to_sql(&raw_value, &mut new_out)
}
}
impl FromSql<SplitStateSQLRepr, Pg> for SplitState {
fn from_sql(pg_value: PgValue) -> diesel::deserialize::Result<Self> {
match FromSql::<SplitStateSQLRepr, Pg>::from_sql(pg_value).map(|v| match v {
0 => Some(Self::Idle),
1 => Some(Self::Splitting),
_ => None,
})? {
Some(v) => Ok(v),
None => Err(format!("Invalid SplitState value, was: {:?}", pg_value.as_bytes()).into()),
}
}
}


@@ -20,7 +20,6 @@ diesel::table! {
generation -> Int4,
generation_pageserver -> Int8,
placement_policy -> Varchar,
splitting -> Int2,
config -> Text,
}
}


@@ -1,6 +1,5 @@
use std::{
cmp::Ordering,
collections::{BTreeMap, HashMap, HashSet},
collections::{BTreeMap, HashMap},
str::FromStr,
sync::Arc,
time::{Duration, Instant},
@@ -24,14 +23,13 @@ use pageserver_api::{
models::{
LocationConfig, LocationConfigMode, ShardParameters, TenantConfig, TenantCreateRequest,
TenantLocationConfigRequest, TenantLocationConfigResponse, TenantShardLocation,
TenantShardSplitRequest, TenantShardSplitResponse, TimelineCreateRequest, TimelineInfo,
TimelineCreateRequest, TimelineInfo,
},
shard::{ShardCount, ShardIdentity, ShardNumber, ShardStripeSize, TenantShardId},
};
use pageserver_client::mgmt_api;
use tokio_util::sync::CancellationToken;
use utils::{
backoff,
completion::Barrier,
generation::Generation,
http::error::ApiError,
@@ -42,11 +40,7 @@ use utils::{
use crate::{
compute_hook::{self, ComputeHook},
node::Node,
persistence::{
split_state::SplitState, DatabaseError, NodePersistence, Persistence,
TenantShardPersistence,
},
reconciler::attached_location_conf,
persistence::{DatabaseError, NodePersistence, Persistence, TenantShardPersistence},
scheduler::Scheduler,
tenant_state::{
IntentState, ObservedState, ObservedStateLocation, ReconcileResult, ReconcileWaitError,
@@ -109,9 +103,7 @@ impl From<DatabaseError> for ApiError {
match err {
DatabaseError::Query(e) => ApiError::InternalServerError(e.into()),
// FIXME: ApiError doesn't have an Unavailable variant, but ShuttingDown maps to 503.
DatabaseError::Connection(_) | DatabaseError::ConnectionPool(_) => {
ApiError::ShuttingDown
}
DatabaseError::Connection(_e) => ApiError::ShuttingDown,
DatabaseError::Logical(reason) => {
ApiError::InternalServerError(anyhow::anyhow!(reason))
}
@@ -151,71 +143,31 @@ impl Service {
// indeterminate, same as in [`ObservedStateLocation`])
let mut observed = HashMap::new();
let mut nodes_online = HashSet::new();
// TODO: give Service a cancellation token for clean shutdown
let cancel = CancellationToken::new();
let nodes = {
let locked = self.inner.read().unwrap();
locked.nodes.clone()
};
// TODO: issue these requests concurrently
{
let nodes = {
let locked = self.inner.read().unwrap();
locked.nodes.clone()
};
for node in nodes.values() {
let http_client = reqwest::ClientBuilder::new()
.timeout(Duration::from_secs(5))
.build()
.expect("Failed to construct HTTP client");
let client = mgmt_api::Client::from_client(
http_client,
node.base_url(),
self.config.jwt_token.as_deref(),
);
for node in nodes.values() {
let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref());
fn is_fatal(e: &mgmt_api::Error) -> bool {
use mgmt_api::Error::*;
match e {
ReceiveBody(_) | ReceiveErrorBody(_) => false,
ApiError(StatusCode::SERVICE_UNAVAILABLE, _)
| ApiError(StatusCode::GATEWAY_TIMEOUT, _)
| ApiError(StatusCode::REQUEST_TIMEOUT, _) => false,
ApiError(_, _) => true,
}
tracing::info!("Scanning shards on node {}...", node.id);
match client.list_location_config().await {
Err(e) => {
tracing::warn!("Could not contact pageserver {} ({e})", node.id);
// TODO: be more tolerant, apply a generous 5-10 second timeout with retries, in case
// pageserver is being restarted at the same time as we are
}
Ok(listing) => {
tracing::info!(
"Received {} shard statuses from pageserver {}, setting it to Active",
listing.tenant_shards.len(),
node.id
);
let list_response = backoff::retry(
|| client.list_location_config(),
is_fatal,
1,
5,
"Location config listing",
&cancel,
)
.await;
let Some(list_response) = list_response else {
tracing::info!("Shutdown during startup_reconcile");
return;
};
tracing::info!("Scanning shards on node {}...", node.id);
match list_response {
Err(e) => {
tracing::warn!("Could not contact pageserver {} ({e})", node.id);
// TODO: be more tolerant, do some retries, in case
// pageserver is being restarted at the same time as we are
}
Ok(listing) => {
tracing::info!(
"Received {} shard statuses from pageserver {}, setting it to Active",
listing.tenant_shards.len(),
node.id
);
nodes_online.insert(node.id);
for (tenant_shard_id, conf_opt) in listing.tenant_shards {
observed.insert(tenant_shard_id, (node.id, conf_opt));
}
for (tenant_shard_id, conf_opt) in listing.tenant_shards {
observed.insert(tenant_shard_id, (node.id, conf_opt));
}
}
}
@@ -226,19 +178,8 @@ impl Service {
let mut compute_notifications = Vec::new();
// Populate intent and observed states for all tenants, based on reported state on pageservers
let (shard_count, nodes) = {
let shard_count = {
let mut locked = self.inner.write().unwrap();
// Mark nodes online if they responded to us: nodes are offline by default after a restart.
let mut nodes = (*locked.nodes).clone();
for (node_id, node) in nodes.iter_mut() {
if nodes_online.contains(node_id) {
node.availability = NodeAvailability::Active;
}
}
locked.nodes = Arc::new(nodes);
let nodes = locked.nodes.clone();
for (tenant_shard_id, (node_id, observed_loc)) in observed {
let Some(tenant_state) = locked.tenants.get_mut(&tenant_shard_id) else {
cleanup.push((tenant_shard_id, node_id));
@@ -270,7 +211,7 @@ impl Service {
}
}
(locked.tenants.len(), nodes)
locked.tenants.len()
};
// TODO: if any tenant's intent now differs from its loaded generation_pageserver, we should clear that
@@ -331,8 +272,9 @@ impl Service {
let stream = futures::stream::iter(compute_notifications.into_iter())
.map(|(tenant_shard_id, node_id)| {
let compute_hook = compute_hook.clone();
let cancel = cancel.clone();
async move {
// TODO: give Service a cancellation token for clean shutdown
let cancel = CancellationToken::new();
if let Err(e) = compute_hook.notify(tenant_shard_id, node_id, &cancel).await {
tracing::error!(
tenant_shard_id=%tenant_shard_id,
@@ -438,7 +380,7 @@ impl Service {
))),
config,
persistence,
startup_complete: startup_complete.clone(),
startup_complete,
});
let result_task_this = this.clone();
@@ -532,7 +474,6 @@ impl Service {
generation_pageserver: i64::MAX,
placement_policy: serde_json::to_string(&PlacementPolicy::default()).unwrap(),
config: serde_json::to_string(&TenantConfig::default()).unwrap(),
splitting: SplitState::default(),
};
match self.persistence.insert_tenant_shards(vec![tsp]).await {
@@ -775,7 +716,6 @@ impl Service {
generation_pageserver: i64::MAX,
placement_policy: serde_json::to_string(&placement_policy).unwrap(),
config: serde_json::to_string(&create_req.config).unwrap(),
splitting: SplitState::default(),
})
.collect();
self.persistence
@@ -1035,10 +975,6 @@ impl Service {
}
};
// TODO: if we timeout/fail on reconcile, we should still succeed this request,
// because otherwise a broken compute hook causes a feedback loop where
// location_config returns 500 and gets retried forever.
if let Some(create_req) = maybe_create {
let create_resp = self.tenant_create(create_req).await?;
result.shards = create_resp
@@ -1051,15 +987,7 @@ impl Service {
.collect();
} else {
// This was an update, wait for reconciliation
if let Err(e) = self.await_waiters(waiters).await {
// Do not treat a reconcile error as fatal: we have already applied any requested
// Intent changes, and the reconcile can fail for external reasons like unavailable
// compute notification API. In these cases, it is important that we do not
// cause the cloud control plane to retry forever on this API.
tracing::warn!(
"Failed to reconcile after /location_config: {e}, returning success anyway"
);
}
self.await_waiters(waiters).await?;
}
Ok(result)
@@ -1162,7 +1090,6 @@ impl Service {
self.ensure_attached_wait(tenant_id).await?;
// TODO: refuse to do this if shard splitting is in progress
// (https://github.com/neondatabase/neon/issues/6676)
let targets = {
let locked = self.inner.read().unwrap();
let mut targets = Vec::new();
@@ -1243,7 +1170,6 @@ impl Service {
self.ensure_attached_wait(tenant_id).await?;
// TODO: refuse to do this if shard splitting is in progress
// (https://github.com/neondatabase/neon/issues/6676)
let targets = {
let locked = self.inner.read().unwrap();
let mut targets = Vec::new();
@@ -1416,326 +1342,6 @@ impl Service {
})
}
pub(crate) async fn tenant_shard_split(
&self,
tenant_id: TenantId,
split_req: TenantShardSplitRequest,
) -> Result<TenantShardSplitResponse, ApiError> {
let mut policy = None;
let mut shard_ident = None;
// TODO: put a cancellation token on Service for clean shutdown
let cancel = CancellationToken::new();
// A parent shard which will be split
struct SplitTarget {
parent_id: TenantShardId,
node: Node,
child_ids: Vec<TenantShardId>,
}
// Validate input, and calculate which shards we will create
let (old_shard_count, targets, compute_hook) = {
let locked = self.inner.read().unwrap();
let pageservers = locked.nodes.clone();
let mut targets = Vec::new();
// In case this is a retry, count how many already-split shards we found
let mut children_found = Vec::new();
let mut old_shard_count = None;
for (tenant_shard_id, shard) in
locked.tenants.range(TenantShardId::tenant_range(tenant_id))
{
match shard.shard.count.0.cmp(&split_req.new_shard_count) {
Ordering::Equal => {
// Already split this
children_found.push(*tenant_shard_id);
continue;
}
Ordering::Greater => {
return Err(ApiError::BadRequest(anyhow::anyhow!(
"Requested count {} but already have shards at count {}",
split_req.new_shard_count,
shard.shard.count.0
)));
}
Ordering::Less => {
// Fall through: this shard has lower count than requested,
// is a candidate for splitting.
}
}
match old_shard_count {
None => old_shard_count = Some(shard.shard.count),
Some(old_shard_count) => {
if old_shard_count != shard.shard.count {
// We may hit this case if a caller asked for two splits to
// different sizes, before the first one is complete.
// e.g. 1->2, 2->4, where the 4 call comes while we have a mixture
// of shard_count=1 and shard_count=2 shards in the map.
return Err(ApiError::Conflict(
"Cannot split, currently mid-split".to_string(),
));
}
}
}
if policy.is_none() {
policy = Some(shard.policy.clone());
}
if shard_ident.is_none() {
shard_ident = Some(shard.shard);
}
if tenant_shard_id.shard_count == ShardCount(split_req.new_shard_count) {
tracing::info!(
"Tenant shard {} already has shard count {}",
tenant_shard_id,
split_req.new_shard_count
);
continue;
}
let node_id =
shard
.intent
.attached
.ok_or(ApiError::BadRequest(anyhow::anyhow!(
"Cannot split a tenant that is not attached"
)))?;
let node = pageservers
.get(&node_id)
.expect("Pageservers may not be deleted while referenced");
// TODO: if any reconciliation is currently in progress for this shard, wait for it.
targets.push(SplitTarget {
parent_id: *tenant_shard_id,
node: node.clone(),
child_ids: tenant_shard_id.split(ShardCount(split_req.new_shard_count)),
});
}
if targets.is_empty() {
if children_found.len() == split_req.new_shard_count as usize {
return Ok(TenantShardSplitResponse {
new_shards: children_found,
});
} else {
// No shards found to split, and no existing children found: the
// tenant doesn't exist at all.
return Err(ApiError::NotFound(
anyhow::anyhow!("Tenant {} not found", tenant_id).into(),
));
}
}
(old_shard_count, targets, locked.compute_hook.clone())
};
// unwrap safety: we would have returned above if we didn't find at least one shard to split
let old_shard_count = old_shard_count.unwrap();
let shard_ident = shard_ident.unwrap();
let policy = policy.unwrap();
// FIXME: we have dropped self.inner lock, and not yet written anything to the database: another
// request could occur here, deleting or mutating the tenant. begin_shard_split checks that the
// parent shards exist as expected, but it would be neater to do the above pre-checks within the
// same database transaction rather than pre-check in-memory and then maybe-fail the database write.
// (https://github.com/neondatabase/neon/issues/6676)
// Before creating any new child shards in memory or on the pageservers, persist them: this
// enables us to ensure that we will always be able to clean up if something goes wrong. This also
// acts as the protection against two concurrent attempts to split: one of them will get a database
// error trying to insert the child shards.
let mut child_tsps = Vec::new();
for target in &targets {
let mut this_child_tsps = Vec::new();
for child in &target.child_ids {
let mut child_shard = shard_ident;
child_shard.number = child.shard_number;
child_shard.count = child.shard_count;
this_child_tsps.push(TenantShardPersistence {
tenant_id: child.tenant_id.to_string(),
shard_number: child.shard_number.0 as i32,
shard_count: child.shard_count.0 as i32,
shard_stripe_size: shard_ident.stripe_size.0 as i32,
// Note: this generation is a placeholder, [`Persistence::begin_shard_split`] will
// populate the correct generation as part of its transaction, to protect us
// against racing with changes in the state of the parent.
generation: 0,
generation_pageserver: target.node.id.0 as i64,
placement_policy: serde_json::to_string(&policy).unwrap(),
// TODO: get the config out of the map
config: serde_json::to_string(&TenantConfig::default()).unwrap(),
splitting: SplitState::Splitting,
});
}
child_tsps.push((target.parent_id, this_child_tsps));
}
if let Err(e) = self
.persistence
.begin_shard_split(old_shard_count, tenant_id, child_tsps)
.await
{
match e {
DatabaseError::Query(diesel::result::Error::DatabaseError(
DatabaseErrorKind::UniqueViolation,
_,
)) => {
// Inserting a child shard violated a unique constraint: we raced with another call to
// this function
tracing::warn!("Conflicting attempt to split {tenant_id}: {e}");
return Err(ApiError::Conflict("Tenant is already splitting".into()));
}
_ => return Err(ApiError::InternalServerError(e.into())),
}
}
// FIXME: we have now committed the shard split state to the database, so any subsequent
// failure needs to roll it back. We will later wrap this function in logic to roll back
// the split if it fails.
// (https://github.com/neondatabase/neon/issues/6676)
// TODO: issue split calls concurrently (this only matters once we're splitting
// N>1 shards into M shards -- initially we're usually splitting 1 shard into N).
for target in &targets {
let SplitTarget {
parent_id,
node,
child_ids,
} = target;
let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref());
let response = client
.tenant_shard_split(
*parent_id,
TenantShardSplitRequest {
new_shard_count: split_req.new_shard_count,
},
)
.await
.map_err(|e| ApiError::Conflict(format!("Failed to split {}: {}", parent_id, e)))?;
tracing::info!(
"Split {} into {}",
parent_id,
response
.new_shards
.iter()
.map(|s| format!("{:?}", s))
.collect::<Vec<_>>()
.join(",")
);
if &response.new_shards != child_ids {
// This should never happen: the pageserver should agree with us on how shard splits work.
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"Splitting shard {} resulted in unexpected IDs: {:?} (expected {:?})",
parent_id,
response.new_shards,
child_ids
)));
}
}
// TODO: if the pageserver restarted concurrently with our split API call,
// the actual generation of the child shard might differ from the generation
// we expect it to have. In order for our in-database generation to end up
// correct, we should carry the child generation back in the response and apply it here
// in complete_shard_split (and apply the correct generation in memory)
// (or, we can carry generation in the request and reject the request if
// it doesn't match, but that requires more retry logic on this side)
self.persistence
.complete_shard_split(tenant_id, old_shard_count)
.await?;
// Replace all the shards we just split with their children
let mut response = TenantShardSplitResponse {
new_shards: Vec::new(),
};
let mut child_locations = Vec::new();
{
let mut locked = self.inner.write().unwrap();
for target in targets {
let SplitTarget {
parent_id,
node: _node,
child_ids,
} = target;
let (pageserver, generation, config) = {
let old_state = locked
.tenants
.remove(&parent_id)
.expect("It was present, we just split it");
(
old_state.intent.attached.unwrap(),
old_state.generation,
old_state.config.clone(),
)
};
locked.tenants.remove(&parent_id);
for child in child_ids {
let mut child_shard = shard_ident;
child_shard.number = child.shard_number;
child_shard.count = child.shard_count;
let mut child_observed: HashMap<NodeId, ObservedStateLocation> = HashMap::new();
child_observed.insert(
pageserver,
ObservedStateLocation {
conf: Some(attached_location_conf(generation, &child_shard, &config)),
},
);
let mut child_state = TenantState::new(child, child_shard, policy.clone());
child_state.intent = IntentState::single(Some(pageserver));
child_state.observed = ObservedState {
locations: child_observed,
};
child_state.generation = generation;
child_state.config = config.clone();
child_locations.push((child, pageserver));
locked.tenants.insert(child, child_state);
response.new_shards.push(child);
}
}
}
// Send compute notifications for all the new shards
let mut failed_notifications = Vec::new();
for (child_id, child_ps) in child_locations {
if let Err(e) = compute_hook.notify(child_id, child_ps, &cancel).await {
tracing::warn!("Failed to update compute of {}->{} during split, proceeding anyway to complete split ({e})",
child_id, child_ps);
failed_notifications.push(child_id);
}
}
// If we failed any compute notifications, make a note to retry later.
if !failed_notifications.is_empty() {
let mut locked = self.inner.write().unwrap();
for failed in failed_notifications {
if let Some(shard) = locked.tenants.get_mut(&failed) {
shard.pending_compute_notification = true;
}
}
}
Ok(response)
}
pub(crate) async fn tenant_shard_migrate(
&self,
tenant_shard_id: TenantShardId,
@@ -1804,45 +1410,6 @@ impl Service {
Ok(TenantShardMigrateResponse {})
}
/// This is for debug/support only: we simply drop all state for a tenant, without
/// detaching or deleting it on pageservers.
pub(crate) async fn tenant_drop(&self, tenant_id: TenantId) -> Result<(), ApiError> {
self.persistence.delete_tenant(tenant_id).await?;
let mut locked = self.inner.write().unwrap();
let mut shards = Vec::new();
for (tenant_shard_id, _) in locked.tenants.range(TenantShardId::tenant_range(tenant_id)) {
shards.push(*tenant_shard_id);
}
for shard in shards {
locked.tenants.remove(&shard);
}
Ok(())
}
/// This is for debug/support only: we simply drop all state for a node, without
/// detaching its tenants on the pageserver. We do not try to re-schedule any
/// tenants that were on this node.
///
/// TODO: proper node deletion API that unhooks things more gracefully
pub(crate) async fn node_drop(&self, node_id: NodeId) -> Result<(), ApiError> {
self.persistence.delete_node(node_id).await?;
let mut locked = self.inner.write().unwrap();
for shard in locked.tenants.values_mut() {
shard.deref_node(node_id);
}
let mut nodes = (*locked.nodes).clone();
nodes.remove(&node_id);
locked.nodes = Arc::new(nodes);
Ok(())
}
pub(crate) async fn node_list(&self) -> Result<Vec<NodePersistence>, ApiError> {
// It is convenient to avoid taking the big lock and converting Node to a serializable
// structure by fetching from storage instead of reading in-memory state.

View File

@@ -193,13 +193,6 @@ impl IntentState {
result
}
pub(crate) fn single(node_id: Option<NodeId>) -> Self {
Self {
attached: node_id,
secondary: vec![],
}
}
/// When a node goes offline, we update intents to avoid using it
/// as their attached pageserver.
///
@@ -293,9 +286,6 @@ impl TenantState {
// self.intent refers to pageservers that are offline, and pick other
// pageservers if so.
// TODO: respect the splitting bit on tenants: if they are currently splitting then we may not
// change their attach location.
// Build the set of pageservers already in use by this tenant, to avoid scheduling
// more work on the same pageservers we're already using.
let mut used_pageservers = self.intent.all_pageservers();
@@ -534,18 +524,4 @@ impl TenantState {
seq: self.sequence,
})
}
// If we had any state at all referring to this node ID, drop it. Does not
// attempt to reschedule.
pub(crate) fn deref_node(&mut self, node_id: NodeId) {
if self.intent.attached == Some(node_id) {
self.intent.attached = None;
}
self.intent.secondary.retain(|n| n != &node_id);
self.observed.locations.remove(&node_id);
debug_assert!(!self.intent.all_pageservers().contains(&node_id));
}
}

View File

@@ -1,17 +1,20 @@
use crate::{background_process, local_env::LocalEnv};
use camino::{Utf8Path, Utf8PathBuf};
use diesel::{
backend::Backend,
query_builder::{AstPass, QueryFragment, QueryId},
Connection, PgConnection, QueryResult, RunQueryDsl,
};
use diesel_migrations::{HarnessWithOutput, MigrationHarness};
use hyper::Method;
use pageserver_api::{
models::{
ShardParameters, TenantCreateRequest, TenantShardSplitRequest, TenantShardSplitResponse,
TimelineCreateRequest, TimelineInfo,
},
models::{ShardParameters, TenantCreateRequest, TimelineCreateRequest, TimelineInfo},
shard::TenantShardId,
};
use pageserver_client::mgmt_api::ResponseErrorMessageExt;
use postgres_backend::AuthType;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::str::FromStr;
use std::{env, str::FromStr};
use tokio::process::Command;
use tracing::instrument;
use url::Url;
@@ -267,6 +270,37 @@ impl AttachmentService {
.expect("non-Unicode path")
}
/// In order to access database migrations, we need to find the Neon source tree
async fn find_source_root(&self) -> anyhow::Result<Utf8PathBuf> {
// We assume that either our working directory or our binary is in the source tree. The
// former is usually true for automated test runners, the latter is usually true for
// developer workstations. Often both are true, which is fine.
let candidate_start_points = [
// Current working directory
Utf8PathBuf::from_path_buf(std::env::current_dir()?).unwrap(),
// Directory containing the binary we're running inside
Utf8PathBuf::from_path_buf(env::current_exe()?.parent().unwrap().to_owned()).unwrap(),
];
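// Illustrative example (not part of this change): for a binary built into
// <repo>/target/debug, the ancestor walk below visits target/debug, target and
// finally the repo root, which is the first ancestor containing control_plane/.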
// For each candidate start point, search through ancestors looking for a neon.git source tree root
for start_point in &candidate_start_points {
// Start from the build dir: assumes we are running out of a built neon source tree
for path in start_point.ancestors() {
// A crude approximation: the root of the source tree is whatever contains a "control_plane"
// subdirectory.
let control_plane = path.join("control_plane");
if tokio::fs::try_exists(&control_plane).await? {
return Ok(path.to_owned());
}
}
}
// Fall-through
Err(anyhow::anyhow!(
"Could not find control_plane src dir, after searching ancestors of {candidate_start_points:?}"
))
}
/// Find the directory containing postgres binaries, such as `initdb` and `pg_ctl`
///
/// This usually uses ATTACHMENT_SERVICE_POSTGRES_VERSION of postgres, but will fall back
@@ -306,32 +340,69 @@ impl AttachmentService {
///
/// Returns the database url
pub async fn setup_database(&self) -> anyhow::Result<String> {
const DB_NAME: &str = "attachment_service";
let database_url = format!("postgresql://localhost:{}/{DB_NAME}", self.postgres_port);
let database_url = format!(
"postgresql://localhost:{}/attachment_service",
self.postgres_port
);
println!("Running attachment service database setup...");
fn change_database_of_url(database_url: &str, default_database: &str) -> (String, String) {
let base = ::url::Url::parse(database_url).unwrap();
let database = base.path_segments().unwrap().last().unwrap().to_owned();
let mut new_url = base.join(default_database).unwrap();
new_url.set_query(base.query());
(database, new_url.into())
}
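// Illustrative example (not part of this change): given
//   "postgresql://localhost:1234/attachment_service"
// this returns ("attachment_service", "postgresql://localhost:1234/postgres"),
// because Url::join replaces the last path segment when the base path has no
// trailing slash.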
let pg_bin_dir = self.get_pg_bin_dir().await?;
let createdb_path = pg_bin_dir.join("createdb");
let output = Command::new(&createdb_path)
.args([
"-h",
"localhost",
"-p",
&format!("{}", self.postgres_port),
&DB_NAME,
])
.output()
.await
.expect("Failed to spawn createdb");
#[derive(Debug, Clone)]
pub struct CreateDatabaseStatement {
db_name: String,
}
if !output.status.success() {
let stderr = String::from_utf8(output.stderr).expect("Non-UTF8 output from createdb");
if stderr.contains("already exists") {
tracing::info!("Database {DB_NAME} already exists");
} else {
anyhow::bail!("createdb failed with status {}: {stderr}", output.status);
impl CreateDatabaseStatement {
pub fn new(db_name: &str) -> Self {
CreateDatabaseStatement {
db_name: db_name.to_owned(),
}
}
}
impl<DB: Backend> QueryFragment<DB> for CreateDatabaseStatement {
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
out.push_sql("CREATE DATABASE ");
out.push_identifier(&self.db_name)?;
Ok(())
}
}
impl<Conn> RunQueryDsl<Conn> for CreateDatabaseStatement {}
impl QueryId for CreateDatabaseStatement {
type QueryId = ();
const HAS_STATIC_QUERY_ID: bool = false;
}
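// Usage sketch (mirrors the call below): CreateDatabaseStatement::new("attachment_service")
// executed via RunQueryDsl renders as `CREATE DATABASE "attachment_service"`, with the
// name quoted through push_identifier.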
if PgConnection::establish(&database_url).is_err() {
let (database, postgres_url) = change_database_of_url(&database_url, "postgres");
println!("Creating database: {database}");
let mut conn = PgConnection::establish(&postgres_url)?;
CreateDatabaseStatement::new(&database).execute(&mut conn)?;
}
let mut conn = PgConnection::establish(&database_url)?;
let migrations_dir = self
.find_source_root()
.await?
.join("control_plane/attachment_service/migrations");
let migrations = diesel_migrations::FileBasedMigrations::from_path(migrations_dir)?;
println!("Running migrations in {}", migrations.path().display());
HarnessWithOutput::write_to_stdout(&mut conn)
.run_pending_migrations(migrations)
.map(|_| ())
.map_err(|e| anyhow::anyhow!(e))?;
println!("Migrations complete");
Ok(database_url)
}
@@ -577,7 +648,7 @@ impl AttachmentService {
) -> anyhow::Result<TenantShardMigrateResponse> {
self.dispatch(
Method::PUT,
format!("control/v1/tenant/{tenant_shard_id}/migrate"),
format!("tenant/{tenant_shard_id}/migrate"),
Some(TenantShardMigrateRequest {
tenant_shard_id,
node_id,
@@ -586,20 +657,6 @@ impl AttachmentService {
.await
}
#[instrument(skip(self), fields(%tenant_id, %new_shard_count))]
pub async fn tenant_split(
&self,
tenant_id: TenantId,
new_shard_count: u8,
) -> anyhow::Result<TenantShardSplitResponse> {
self.dispatch(
Method::PUT,
format!("control/v1/tenant/{tenant_id}/shard_split"),
Some(TenantShardSplitRequest { new_shard_count }),
)
.await
}
#[instrument(skip_all, fields(node_id=%req.node_id))]
pub async fn node_register(&self, req: NodeRegisterRequest) -> anyhow::Result<()> {
self.dispatch::<_, ()>(Method::POST, "control/v1/node".to_string(), Some(req))

View File

@@ -72,6 +72,7 @@ where
let log_path = datadir.join(format!("{process_name}.log"));
let process_log_file = fs::OpenOptions::new()
.create(true)
.write(true)
.append(true)
.open(&log_path)
.with_context(|| {

View File

@@ -575,26 +575,6 @@ async fn handle_tenant(
println!("{tenant_table}");
println!("{shard_table}");
}
Some(("shard-split", matches)) => {
let tenant_id = get_tenant_id(matches, env)?;
let shard_count: u8 = matches.get_one::<u8>("shard-count").cloned().unwrap_or(0);
let attachment_service = AttachmentService::from_env(env);
let result = attachment_service
.tenant_split(tenant_id, shard_count)
.await?;
println!(
"Split tenant {} into shards {}",
tenant_id,
result
.new_shards
.iter()
.map(|s| format!("{:?}", s))
.collect::<Vec<_>>()
.join(",")
);
}
Some((sub_name, _)) => bail!("Unexpected tenant subcommand '{}'", sub_name),
None => bail!("no tenant subcommand provided"),
}
@@ -1014,13 +994,12 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
.get_one::<String>("endpoint_id")
.ok_or_else(|| anyhow!("No endpoint ID was provided to stop"))?;
let destroy = sub_args.get_flag("destroy");
let mode = sub_args.get_one::<String>("mode").expect("has a default");
let endpoint = cplane
.endpoints
.get(endpoint_id.as_str())
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
endpoint.stop(mode, destroy)?;
endpoint.stop(destroy)?;
}
_ => bail!("Unexpected endpoint subcommand '{sub_name}'"),
@@ -1304,7 +1283,7 @@ async fn try_stop_all(env: &local_env::LocalEnv, immediate: bool) {
match ComputeControlPlane::load(env.clone()) {
Ok(cplane) => {
for (_k, node) in cplane.endpoints {
if let Err(e) = node.stop(if immediate { "immediate" } else { "fast " }, false) {
if let Err(e) = node.stop(false) {
eprintln!("postgres stop failed: {e:#}");
}
}
@@ -1545,11 +1524,6 @@ fn cli() -> Command {
.subcommand(Command::new("status")
.about("Human readable summary of the tenant's shards and attachment locations")
.arg(tenant_id_arg.clone()))
.subcommand(Command::new("shard-split")
.about("Increase the number of shards in the tenant")
.arg(tenant_id_arg.clone())
.arg(Arg::new("shard-count").value_parser(value_parser!(u8)).long("shard-count").action(ArgAction::Set).help("Number of shards in the new tenant (default 1)"))
)
)
.subcommand(
Command::new("pageserver")
@@ -1653,16 +1627,7 @@ fn cli() -> Command {
.long("destroy")
.action(ArgAction::SetTrue)
.required(false)
)
.arg(
Arg::new("mode")
.help("Postgres shutdown mode, passed to \"pg_ctl -m <mode>\"")
.long("mode")
.action(ArgAction::Set)
.required(false)
.value_parser(["smart", "fast", "immediate"])
.default_value("fast")
)
)
)
)

View File

@@ -761,8 +761,22 @@ impl Endpoint {
}
}
pub fn stop(&self, mode: &str, destroy: bool) -> Result<()> {
self.pg_ctl(&["-m", mode, "stop"], &None)?;
pub fn stop(&self, destroy: bool) -> Result<()> {
// If we are going to destroy the data directory,
// use immediate shutdown mode; otherwise,
// shut down gracefully to leave the data directory in a sane state.
//
// Postgres is always started from scratch, so stopping
// without destroy is only used for testing and debugging.
//
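// With no "-m" flag, pg_ctl performs a fast shutdown by default; "-m immediate"
// skips the shutdown checkpoint entirely, which is acceptable only because the
// data directory is about to be removed.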
self.pg_ctl(
if destroy {
&["-m", "immediate", "stop"]
} else {
&["stop"]
},
&None,
)?;
// Also wait for the compute_ctl process to die. It might have some
// cleanup work to do after postgres stops, like syncing safekeepers,

View File

@@ -90,8 +90,8 @@ pub enum ComputeFeature {
/// track short-lived connections as user activity.
ActivityMonitorExperimental,
/// Pre-install and initialize anon extension for every database in the cluster
AnonExtension,
/// Enable running migrations
Migrations,
/// This is a special feature flag that is used to represent unknown feature flags.
/// Basically all unknown to enum flags are represented as this one. See unit test

View File

@@ -13,9 +13,6 @@ twox-hash.workspace = true
workspace_hack.workspace = true
[target.'cfg(target_os = "linux")'.dependencies]
procfs.workspace = true
[dev-dependencies]
rand = "0.8"
rand_distr = "0.4.3"

View File

@@ -31,8 +31,6 @@ pub use wrappers::{CountedReader, CountedWriter};
mod hll;
pub mod metric_vec_duration;
pub use hll::{HyperLogLog, HyperLogLogVec};
#[cfg(target_os = "linux")]
pub mod more_process_metrics;
pub type UIntGauge = GenericGauge<AtomicU64>;
pub type UIntGaugeVec = GenericGaugeVec<AtomicU64>;

View File

@@ -1,54 +0,0 @@
//! process metrics that the [`::prometheus`] crate doesn't provide.
// This module has heavy inspiration from the prometheus crate's `process_collector.rs`.
use crate::UIntGauge;
pub struct Collector {
descs: Vec<prometheus::core::Desc>,
vmlck: crate::UIntGauge,
}
const NMETRICS: usize = 1;
impl prometheus::core::Collector for Collector {
fn desc(&self) -> Vec<&prometheus::core::Desc> {
self.descs.iter().collect()
}
fn collect(&self) -> Vec<prometheus::proto::MetricFamily> {
let Ok(myself) = procfs::process::Process::myself() else {
return vec![];
};
let mut mfs = Vec::with_capacity(NMETRICS);
if let Ok(status) = myself.status() {
if let Some(vmlck) = status.vmlck {
self.vmlck.set(vmlck);
mfs.extend(self.vmlck.collect())
}
}
mfs
}
}
impl Collector {
pub fn new() -> Self {
let mut descs = Vec::new();
let vmlck =
UIntGauge::new("libmetrics_process_status_vmlck", "/proc/self/status vmlck").unwrap();
descs.extend(
prometheus::core::Collector::desc(&vmlck)
.into_iter()
.cloned(),
);
Self { descs, vmlck }
}
}
impl Default for Collector {
fn default() -> Self {
Self::new()
}
}

View File

@@ -192,16 +192,6 @@ pub struct TimelineCreateRequest {
pub pg_version: Option<u32>,
}
#[derive(Serialize, Deserialize)]
pub struct TenantShardSplitRequest {
pub new_shard_count: u8,
}
#[derive(Serialize, Deserialize)]
pub struct TenantShardSplitResponse {
pub new_shards: Vec<TenantShardId>,
}
/// Parameters that apply to all shards in a tenant. Used during tenant creation.
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
@@ -659,27 +649,6 @@ pub struct WalRedoManagerStatus {
pub pid: Option<u32>,
}
pub mod virtual_file {
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Hash,
strum_macros::EnumString,
strum_macros::Display,
serde_with::DeserializeFromStr,
serde_with::SerializeDisplay,
Debug,
)]
#[strum(serialize_all = "kebab-case")]
pub enum IoEngineKind {
StdFs,
#[cfg(target_os = "linux")]
TokioEpollUring,
}
}
// Wrapped in libpq CopyData
#[derive(PartialEq, Eq, Debug)]
pub enum PagestreamFeMessage {

View File

@@ -88,36 +88,12 @@ impl TenantShardId {
pub fn is_unsharded(&self) -> bool {
self.shard_number == ShardNumber(0) && self.shard_count == ShardCount(0)
}
/// Convenience for dropping the tenant_id and just getting the ShardIndex: this
/// is useful when logging from code that is already in a span that includes tenant ID, to
/// keep messages reasonably terse.
pub fn to_index(&self) -> ShardIndex {
ShardIndex {
shard_number: self.shard_number,
shard_count: self.shard_count,
}
}
/// Calculate the children of this TenantShardId when splitting the overall tenant into
/// the given number of shards.
pub fn split(&self, new_shard_count: ShardCount) -> Vec<TenantShardId> {
let effective_old_shard_count = std::cmp::max(self.shard_count.0, 1);
let mut child_shards = Vec::new();
for shard_number in 0..ShardNumber(new_shard_count.0).0 {
// Key mapping is round-robin on the key hash modulo shard count,
// so our child shards are the ones that the same keys would map to.
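// For example, a parent with number=1, count=2 splitting to count=8 keeps the
// odd-numbered children 1, 3, 5 and 7 (see the shard_id_split test below).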
if shard_number % effective_old_shard_count == self.shard_number.0 {
child_shards.push(TenantShardId {
tenant_id: self.tenant_id,
shard_number: ShardNumber(shard_number),
shard_count: new_shard_count,
})
}
}
child_shards
}
}
/// Formatting helper
@@ -817,108 +793,4 @@ mod tests {
let shard = key_to_shard_number(ShardCount(10), DEFAULT_STRIPE_SIZE, &key);
assert_eq!(shard, ShardNumber(8));
}
#[test]
fn shard_id_split() {
let tenant_id = TenantId::generate();
let parent = TenantShardId::unsharded(tenant_id);
// Unsharded into 2
assert_eq!(
parent.split(ShardCount(2)),
vec![
TenantShardId {
tenant_id,
shard_count: ShardCount(2),
shard_number: ShardNumber(0)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(2),
shard_number: ShardNumber(1)
}
]
);
// Unsharded into 4
assert_eq!(
parent.split(ShardCount(4)),
vec![
TenantShardId {
tenant_id,
shard_count: ShardCount(4),
shard_number: ShardNumber(0)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(4),
shard_number: ShardNumber(1)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(4),
shard_number: ShardNumber(2)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(4),
shard_number: ShardNumber(3)
}
]
);
// count=1 into 2 (check this works the same as unsharded.)
let parent = TenantShardId {
tenant_id,
shard_count: ShardCount(1),
shard_number: ShardNumber(0),
};
assert_eq!(
parent.split(ShardCount(2)),
vec![
TenantShardId {
tenant_id,
shard_count: ShardCount(2),
shard_number: ShardNumber(0)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(2),
shard_number: ShardNumber(1)
}
]
);
// count=2 into count=8
let parent = TenantShardId {
tenant_id,
shard_count: ShardCount(2),
shard_number: ShardNumber(1),
};
assert_eq!(
parent.split(ShardCount(8)),
vec![
TenantShardId {
tenant_id,
shard_count: ShardCount(8),
shard_number: ShardNumber(1)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(8),
shard_number: ShardNumber(3)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(8),
shard_number: ShardNumber(5)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(8),
shard_number: ShardNumber(7)
},
]
);
}
}

View File

@@ -13,6 +13,5 @@ rand.workspace = true
tokio.workspace = true
tracing.workspace = true
thiserror.workspace = true
serde.workspace = true
workspace_hack.workspace = true

View File

@@ -7,7 +7,6 @@ pub mod framed;
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap, fmt, io, str};
// re-export for use in utils pageserver_feedback.rs
@@ -124,7 +123,7 @@ impl StartupMessageParams {
}
}
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct CancelKeyData {
pub backend_pid: i32,
pub cancel_key: i32,

View File

@@ -191,7 +191,6 @@ impl RemoteStorage for AzureBlobStorage {
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
) -> anyhow::Result<Listing, DownloadError> {
// get the passed prefix or if it is not set use prefix_in_bucket value
let list_prefix = prefix
@@ -224,8 +223,6 @@ impl RemoteStorage for AzureBlobStorage {
let mut response = builder.into_stream();
let mut res = Listing::default();
// NonZeroU32 doesn't support subtraction apparently
let mut max_keys = max_keys.map(|mk| mk.get());
while let Some(l) = response.next().await {
let entry = l.map_err(to_download_error)?;
let prefix_iter = entry
@@ -238,18 +235,7 @@ impl RemoteStorage for AzureBlobStorage {
.blobs
.blobs()
.map(|k| self.name_to_relative_path(&k.name));
for key in blob_iter {
res.keys.push(key);
if let Some(mut mk) = max_keys {
assert!(mk > 0);
mk -= 1;
if mk == 0 {
return Ok(res); // limit reached
}
max_keys = Some(mk);
}
}
res.keys.extend(blob_iter);
}
Ok(res)
}

View File

@@ -13,15 +13,9 @@ mod azure_blob;
mod local_fs;
mod s3_bucket;
mod simulate_failures;
mod support;
use std::{
collections::HashMap,
fmt::Debug,
num::{NonZeroU32, NonZeroUsize},
pin::Pin,
sync::Arc,
time::SystemTime,
collections::HashMap, fmt::Debug, num::NonZeroUsize, pin::Pin, sync::Arc, time::SystemTime,
};
use anyhow::{bail, Context};
@@ -160,7 +154,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
prefix: Option<&RemotePath>,
) -> Result<Vec<RemotePath>, DownloadError> {
let result = self
.list(prefix, ListingMode::WithDelimiter, None)
.list(prefix, ListingMode::WithDelimiter)
.await?
.prefixes;
Ok(result)
@@ -176,17 +170,8 @@ pub trait RemoteStorage: Send + Sync + 'static {
/// whereas,
/// list_prefixes("foo/bar/") = ["cat", "dog"]
/// See `test_real_s3.rs` for more details.
///
/// max_keys limits max number of keys returned; None means unlimited.
async fn list_files(
&self,
prefix: Option<&RemotePath>,
max_keys: Option<NonZeroU32>,
) -> Result<Vec<RemotePath>, DownloadError> {
let result = self
.list(prefix, ListingMode::NoDelimiter, max_keys)
.await?
.keys;
async fn list_files(&self, prefix: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
let result = self.list(prefix, ListingMode::NoDelimiter).await?.keys;
Ok(result)
}
@@ -194,8 +179,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
&self,
prefix: Option<&RemotePath>,
_mode: ListingMode,
max_keys: Option<NonZeroU32>,
) -> Result<Listing, DownloadError>;
) -> anyhow::Result<Listing, DownloadError>;
/// Streams the local file contents into remote into the remote storage entry.
async fn upload(
@@ -285,19 +269,6 @@ impl std::fmt::Display for DownloadError {
impl std::error::Error for DownloadError {}
impl DownloadError {
/// Returns true if the error should not be retried with backoff
pub fn is_permanent(&self) -> bool {
use DownloadError::*;
match self {
BadInput(_) => true,
NotFound => true,
Cancelled => true,
Other(_) => false,
}
}
}
#[derive(Debug)]
pub enum TimeTravelError {
/// Validation or other error happened due to user input.
@@ -353,31 +324,24 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
) -> anyhow::Result<Listing, DownloadError> {
match self {
Self::LocalFs(s) => s.list(prefix, mode, max_keys).await,
Self::AwsS3(s) => s.list(prefix, mode, max_keys).await,
Self::AzureBlob(s) => s.list(prefix, mode, max_keys).await,
Self::Unreliable(s) => s.list(prefix, mode, max_keys).await,
Self::LocalFs(s) => s.list(prefix, mode).await,
Self::AwsS3(s) => s.list(prefix, mode).await,
Self::AzureBlob(s) => s.list(prefix, mode).await,
Self::Unreliable(s) => s.list(prefix, mode).await,
}
}
// A function for listing all the files in a "directory"
// Example:
// list_files("foo/bar") = ["foo/bar/a.txt", "foo/bar/b.txt"]
//
// max_keys limits max number of keys returned; None means unlimited.
pub async fn list_files(
&self,
folder: Option<&RemotePath>,
max_keys: Option<NonZeroU32>,
) -> Result<Vec<RemotePath>, DownloadError> {
pub async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
match self {
Self::LocalFs(s) => s.list_files(folder, max_keys).await,
Self::AwsS3(s) => s.list_files(folder, max_keys).await,
Self::AzureBlob(s) => s.list_files(folder, max_keys).await,
Self::Unreliable(s) => s.list_files(folder, max_keys).await,
Self::LocalFs(s) => s.list_files(folder).await,
Self::AwsS3(s) => s.list_files(folder).await,
Self::AzureBlob(s) => s.list_files(folder).await,
Self::Unreliable(s) => s.list_files(folder).await,
}
}

View File

@@ -4,9 +4,7 @@
//! This storage is used in tests, but can also be used in cases when a certain persistent
//! volume is mounted to the local FS.
use std::{
borrow::Cow, future::Future, io::ErrorKind, num::NonZeroU32, pin::Pin, time::SystemTime,
};
use std::{borrow::Cow, future::Future, io::ErrorKind, pin::Pin, time::SystemTime};
use anyhow::{bail, ensure, Context};
use bytes::Bytes;
@@ -20,7 +18,9 @@ use tokio_util::{io::ReaderStream, sync::CancellationToken};
use tracing::*;
use utils::{crashsafe::path_with_suffix_extension, fs_ext::is_directory_empty};
use crate::{Download, DownloadError, Listing, ListingMode, RemotePath, TimeTravelError};
use crate::{
Download, DownloadError, DownloadStream, Listing, ListingMode, RemotePath, TimeTravelError,
};
use super::{RemoteStorage, StorageMetadata};
@@ -164,7 +164,6 @@ impl RemoteStorage for LocalFs {
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
) -> Result<Listing, DownloadError> {
let mut result = Listing::default();
@@ -181,9 +180,6 @@ impl RemoteStorage for LocalFs {
!path.is_dir()
})
.collect();
if let Some(max_keys) = max_keys {
result.keys.truncate(max_keys.get() as usize);
}
return Ok(result);
}
@@ -369,33 +365,27 @@ impl RemoteStorage for LocalFs {
format!("Failed to open source file {target_path:?} to use in the download")
})
.map_err(DownloadError::Other)?;
let len = source
.metadata()
.await
.context("query file length")
.map_err(DownloadError::Other)?
.len();
source
.seek(io::SeekFrom::Start(start_inclusive))
.await
.context("Failed to seek to the range start in a local storage file")
.map_err(DownloadError::Other)?;
let metadata = self
.read_storage_metadata(&target_path)
.await
.map_err(DownloadError::Other)?;
let source = source.take(end_exclusive.unwrap_or(len) - start_inclusive);
let source = ReaderStream::new(source);
let download_stream: DownloadStream = match end_exclusive {
Some(end_exclusive) => Box::pin(ReaderStream::new(
source.take(end_exclusive - start_inclusive),
)),
None => Box::pin(ReaderStream::new(source)),
};
Ok(Download {
metadata,
last_modified: None,
etag: None,
download_stream: Box::pin(source),
download_stream,
})
} else {
Err(DownloadError::NotFound)
@@ -524,8 +514,10 @@ mod fs_tests {
use futures_util::Stream;
use std::{collections::HashMap, io::Write};
async fn read_and_check_metadata(
async fn read_and_assert_remote_file_contents(
storage: &LocalFs,
#[allow(clippy::ptr_arg)]
// have to use &Utf8PathBuf due to `storage.local_path` parameter requirements
remote_storage_path: &RemotePath,
expected_metadata: Option<&StorageMetadata>,
) -> anyhow::Result<String> {
@@ -604,7 +596,7 @@ mod fs_tests {
let upload_name = "upload_1";
let upload_target = upload_dummy_file(&storage, upload_name, None).await?;
let contents = read_and_check_metadata(&storage, &upload_target, None).await?;
let contents = read_and_assert_remote_file_contents(&storage, &upload_target, None).await?;
assert_eq!(
dummy_contents(upload_name),
contents,
@@ -626,7 +618,7 @@ mod fs_tests {
let upload_target = upload_dummy_file(&storage, upload_name, None).await?;
let full_range_download_contents =
read_and_check_metadata(&storage, &upload_target, None).await?;
read_and_assert_remote_file_contents(&storage, &upload_target, None).await?;
assert_eq!(
dummy_contents(upload_name),
full_range_download_contents,
@@ -668,22 +660,6 @@ mod fs_tests {
"Second part bytes should be returned when requested"
);
let suffix_bytes = storage
.download_byte_range(&upload_target, 13, None)
.await?
.download_stream;
let suffix_bytes = aggregate(suffix_bytes).await?;
let suffix = std::str::from_utf8(&suffix_bytes)?;
assert_eq!(upload_name, suffix);
let all_bytes = storage
.download_byte_range(&upload_target, 0, None)
.await?
.download_stream;
let all_bytes = aggregate(all_bytes).await?;
let all_bytes = std::str::from_utf8(&all_bytes)?;
assert_eq!(dummy_contents("upload_1"), all_bytes);
Ok(())
}
@@ -760,7 +736,7 @@ mod fs_tests {
upload_dummy_file(&storage, upload_name, Some(metadata.clone())).await?;
let full_range_download_contents =
read_and_check_metadata(&storage, &upload_target, Some(&metadata)).await?;
read_and_assert_remote_file_contents(&storage, &upload_target, Some(&metadata)).await?;
assert_eq!(
dummy_contents(upload_name),
full_range_download_contents,
@@ -796,12 +772,12 @@ mod fs_tests {
let child = upload_dummy_file(&storage, "grandparent/parent/child", None).await?;
let uncle = upload_dummy_file(&storage, "grandparent/uncle", None).await?;
let listing = storage.list(None, ListingMode::NoDelimiter, None).await?;
let listing = storage.list(None, ListingMode::NoDelimiter).await?;
assert!(listing.prefixes.is_empty());
assert_eq!(listing.keys, [uncle.clone(), child.clone()].to_vec());
// Delimiter: should only go one deep
let listing = storage.list(None, ListingMode::WithDelimiter, None).await?;
let listing = storage.list(None, ListingMode::WithDelimiter).await?;
assert_eq!(
listing.prefixes,
@@ -814,7 +790,6 @@ mod fs_tests {
.list(
Some(&RemotePath::from_string("timelines/some_timeline/grandparent").unwrap()),
ListingMode::WithDelimiter,
None,
)
.await?;
assert_eq!(

View File

@@ -7,7 +7,6 @@
use std::{
borrow::Cow,
collections::HashMap,
num::NonZeroU32,
pin::Pin,
sync::Arc,
task::{Context, Poll},
@@ -46,9 +45,8 @@ use utils::backoff;
use super::StorageMetadata;
use crate::{
support::PermitCarrying, ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode,
RemotePath, RemoteStorage, S3Config, TimeTravelError, MAX_KEYS_PER_DELETE,
REMOTE_STORAGE_PREFIX_SEPARATOR,
ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, RemotePath, RemoteStorage,
S3Config, TimeTravelError, MAX_KEYS_PER_DELETE, REMOTE_STORAGE_PREFIX_SEPARATOR,
};
pub(super) mod metrics;
@@ -65,6 +63,7 @@ pub struct S3Bucket {
concurrency_limiter: ConcurrencyLimiter,
}
#[derive(Default)]
struct GetObjectRequest {
bucket: String,
key: String,
@@ -233,8 +232,24 @@ impl S3Bucket {
let started_at = ScopeGuard::into_inner(started_at);
let object_output = match get_object {
Ok(object_output) => object_output,
match get_object {
Ok(object_output) => {
let metadata = object_output.metadata().cloned().map(StorageMetadata);
let etag = object_output.e_tag.clone();
let last_modified = object_output.last_modified.and_then(|t| t.try_into().ok());
let body = object_output.body;
let body = ByteStreamAsStream::from(body);
let body = PermitCarrying::new(permit, body);
let body = TimedDownload::new(started_at, body);
Ok(Download {
metadata,
etag,
last_modified,
download_stream: Box::pin(body),
})
}
Err(SdkError::ServiceError(e)) if matches!(e.err(), GetObjectError::NoSuchKey(_)) => {
// Count this in the AttemptOutcome::Ok bucket, because 404 is not
// an error: we expect to sometimes fetch an object and find it missing,
@@ -244,7 +259,7 @@ impl S3Bucket {
AttemptOutcome::Ok,
started_at,
);
return Err(DownloadError::NotFound);
Err(DownloadError::NotFound)
}
Err(e) => {
metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
@@ -253,27 +268,11 @@ impl S3Bucket {
started_at,
);
return Err(DownloadError::Other(
Err(DownloadError::Other(
anyhow::Error::new(e).context("download s3 object"),
));
))
}
};
let metadata = object_output.metadata().cloned().map(StorageMetadata);
let etag = object_output.e_tag;
let last_modified = object_output.last_modified.and_then(|t| t.try_into().ok());
let body = object_output.body;
let body = ByteStreamAsStream::from(body);
let body = PermitCarrying::new(permit, body);
let body = TimedDownload::new(started_at, body);
Ok(Download {
metadata,
etag,
last_modified,
download_stream: Box::pin(body),
})
}
}
async fn delete_oids(
@@ -355,6 +354,33 @@ impl Stream for ByteStreamAsStream {
// sense and Stream::size_hint does not really
}
pin_project_lite::pin_project! {
/// A `Stream` adapter which carries a permit for the lifetime of the value.
struct PermitCarrying<S> {
permit: tokio::sync::OwnedSemaphorePermit,
#[pin]
inner: S,
}
}
impl<S> PermitCarrying<S> {
fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
Self { permit, inner }
}
}
impl<S: Stream<Item = std::io::Result<Bytes>>> Stream for PermitCarrying<S> {
type Item = <S as Stream>::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
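// Carrying the permit inside the stream keeps the concurrency-limiter slot held until
// the caller has finished consuming the body, not just until the response is returned.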
pin_project_lite::pin_project! {
/// Times and tracks the outcome of the request.
struct TimedDownload<S> {
@@ -409,11 +435,8 @@ impl RemoteStorage for S3Bucket {
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
) -> Result<Listing, DownloadError> {
let kind = RequestKind::List;
// s3 sdk wants i32
let mut max_keys = max_keys.map(|mk| mk.get() as i32);
let mut result = Listing::default();
// get the passed prefix or if it is not set use prefix_in_bucket value
@@ -437,20 +460,13 @@ impl RemoteStorage for S3Bucket {
let _guard = self.permit(kind).await;
let started_at = start_measuring_requests(kind);
// min of two Options, returning Some if one is a value and the other is
// None (None is smaller than anything, so plain Option::min doesn't work).
let request_max_keys = self
.max_keys_per_list_response
.into_iter()
.chain(max_keys.into_iter())
.min();
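// e.g. Some(1000) and Some(5) -> Some(5); None and Some(5) -> Some(5);
// None and None -> None (no limit is sent to S3).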
let mut request = self
.client
.list_objects_v2()
.bucket(self.bucket_name.clone())
.set_prefix(list_prefix.clone())
.set_continuation_token(continuation_token)
.set_max_keys(request_max_keys);
.set_max_keys(self.max_keys_per_list_response);
if let ListingMode::WithDelimiter = mode {
request = request.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
@@ -480,14 +496,6 @@ impl RemoteStorage for S3Bucket {
let object_path = object.key().expect("response does not contain a key");
let remote_path = self.s3_object_to_relative_path(object_path);
result.keys.push(remote_path);
if let Some(mut mk) = max_keys {
assert!(mk > 0);
mk -= 1;
if mk == 0 {
return Ok(result); // limit reached
}
max_keys = Some(mk);
}
}
result.prefixes.extend(

View File

@@ -4,7 +4,6 @@
use bytes::Bytes;
use futures::stream::Stream;
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Mutex;
use std::time::SystemTime;
use std::{collections::hash_map::Entry, sync::Arc};
@@ -61,7 +60,7 @@ impl UnreliableWrapper {
/// On the first attempts of this operation, return an error. After 'attempts_to_fail'
/// attempts, let the operation go ahead, and clear the counter.
///
fn attempt(&self, op: RemoteOp) -> anyhow::Result<u64> {
fn attempt(&self, op: RemoteOp) -> Result<u64, DownloadError> {
let mut attempts = self.attempts.lock().unwrap();
match attempts.entry(op) {
@@ -79,13 +78,13 @@ impl UnreliableWrapper {
} else {
let error =
anyhow::anyhow!("simulated failure of remote operation {:?}", e.key());
Err(error)
Err(DownloadError::Other(error))
}
}
Entry::Vacant(e) => {
let error = anyhow::anyhow!("simulated failure of remote operation {:?}", e.key());
e.insert(1);
Err(error)
Err(DownloadError::Other(error))
}
}
}
@@ -106,30 +105,22 @@ impl RemoteStorage for UnreliableWrapper {
&self,
prefix: Option<&RemotePath>,
) -> Result<Vec<RemotePath>, DownloadError> {
self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))
.map_err(DownloadError::Other)?;
self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))?;
self.inner.list_prefixes(prefix).await
}
async fn list_files(
&self,
folder: Option<&RemotePath>,
max_keys: Option<NonZeroU32>,
) -> Result<Vec<RemotePath>, DownloadError> {
self.attempt(RemoteOp::ListPrefixes(folder.cloned()))
.map_err(DownloadError::Other)?;
self.inner.list_files(folder, max_keys).await
async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
self.attempt(RemoteOp::ListPrefixes(folder.cloned()))?;
self.inner.list_files(folder).await
}
async fn list(
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
) -> Result<Listing, DownloadError> {
self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))
.map_err(DownloadError::Other)?;
self.inner.list(prefix, mode, max_keys).await
self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))?;
self.inner.list(prefix, mode).await
}
async fn upload(
@@ -146,8 +137,7 @@ impl RemoteStorage for UnreliableWrapper {
}
async fn download(&self, from: &RemotePath) -> Result<Download, DownloadError> {
self.attempt(RemoteOp::Download(from.clone()))
.map_err(DownloadError::Other)?;
self.attempt(RemoteOp::Download(from.clone()))?;
self.inner.download(from).await
}
@@ -160,8 +150,7 @@ impl RemoteStorage for UnreliableWrapper {
// Note: We treat any download_byte_range as an "attempt" of the same
// operation. We don't pay attention to the ranges. That's good enough
// for now.
self.attempt(RemoteOp::Download(from.clone()))
.map_err(DownloadError::Other)?;
self.attempt(RemoteOp::Download(from.clone()))?;
self.inner
.download_byte_range(from, start_inclusive, end_exclusive)
.await
@@ -204,7 +193,7 @@ impl RemoteStorage for UnreliableWrapper {
cancel: &CancellationToken,
) -> Result<(), TimeTravelError> {
self.attempt(RemoteOp::TimeTravelRecover(prefix.map(|p| p.to_owned())))
.map_err(TimeTravelError::Other)?;
.map_err(|e| TimeTravelError::Other(anyhow::Error::new(e)))?;
self.inner
.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await

View File

@@ -1,33 +0,0 @@
use std::{
pin::Pin,
task::{Context, Poll},
};
use futures_util::Stream;
pin_project_lite::pin_project! {
/// A `Stream` adapter which carries a permit for the lifetime of the value.
pub(crate) struct PermitCarrying<S> {
permit: tokio::sync::OwnedSemaphorePermit,
#[pin]
inner: S,
}
}
impl<S> PermitCarrying<S> {
pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
Self { permit, inner }
}
}
impl<S: Stream> Stream for PermitCarrying<S> {
type Item = <S as Stream>::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}

View File

@@ -1,8 +1,8 @@
use anyhow::Context;
use camino::Utf8Path;
use remote_storage::RemotePath;
use std::collections::HashSet;
use std::sync::Arc;
use std::{collections::HashSet, num::NonZeroU32};
use test_context::test_context;
use tracing::debug;
@@ -103,7 +103,7 @@ async fn list_files_works(ctx: &mut MaybeEnabledStorageWithSimpleTestBlobs) -> a
let base_prefix =
RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?;
let root_files = test_client
.list_files(None, None)
.list_files(None)
.await
.context("client list root files failure")?
.into_iter()
@@ -113,17 +113,8 @@ async fn list_files_works(ctx: &mut MaybeEnabledStorageWithSimpleTestBlobs) -> a
ctx.remote_blobs.clone(),
"remote storage list_files on root mismatches with the uploads."
);
// Test that max_keys limit works. In total there are about 21 files (see
// upload_simple_remote_data call in test_real_s3.rs).
let limited_root_files = test_client
.list_files(None, Some(NonZeroU32::new(2).unwrap()))
.await
.context("client list root files failure")?;
assert_eq!(limited_root_files.len(), 2);
let nested_remote_files = test_client
.list_files(Some(&base_prefix), None)
.list_files(Some(&base_prefix))
.await
.context("client list nested files failure")?
.into_iter()

View File

@@ -70,7 +70,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
}
async fn list_files(client: &Arc<GenericRemoteStorage>) -> anyhow::Result<HashSet<RemotePath>> {
Ok(retry(|| client.list_files(None, None))
Ok(retry(|| client.list_files(None))
.await
.context("list root files failure")?
.into_iter()

View File

@@ -27,11 +27,6 @@ impl Barrier {
b.wait().await
}
}
/// Return true if a call to wait() would complete immediately
pub fn is_ready(&self) -> bool {
futures::future::FutureExt::now_or_never(self.0.wait()).is_some()
}
}
impl PartialEq for Barrier {

View File

@@ -1,6 +1,6 @@
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex, MutexGuard,
Arc,
};
use tokio::sync::Semaphore;
@@ -12,7 +12,7 @@ use tokio::sync::Semaphore;
///
/// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
pub struct OnceCell<T> {
inner: Mutex<Inner<T>>,
inner: tokio::sync::RwLock<Inner<T>>,
initializers: AtomicUsize,
}
@@ -50,7 +50,7 @@ impl<T> OnceCell<T> {
let sem = Semaphore::new(1);
sem.close();
Self {
inner: Mutex::new(Inner {
inner: tokio::sync::RwLock::new(Inner {
init_semaphore: Arc::new(sem),
value: Some(value),
}),
@@ -61,18 +61,18 @@ impl<T> OnceCell<T> {
/// Returns a guard to an existing initialized value, or uniquely initializes the value before
/// returning the guard.
///
/// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
/// Initializing might wait on any existing [`GuardMut::take_and_deinit`] deinitialization.
///
/// Initialization is panic-safe and cancellation-safe.
pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
pub async fn get_mut_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardMut<'_, T>, E>
where
F: FnOnce(InitPermit) -> Fut,
Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
{
let sem = {
let guard = self.inner.lock().unwrap();
let guard = self.inner.write().await;
if guard.value.is_some() {
return Ok(Guard(guard));
return Ok(GuardMut(guard));
}
guard.init_semaphore.clone()
};
@@ -88,29 +88,72 @@ impl<T> OnceCell<T> {
let permit = InitPermit(permit);
let (value, _permit) = factory(permit).await?;
let guard = self.inner.lock().unwrap();
let guard = self.inner.write().await;
Ok(Self::set0(value, guard))
}
Err(_closed) => {
let guard = self.inner.lock().unwrap();
let guard = self.inner.write().await;
assert!(
guard.value.is_some(),
"semaphore got closed, must be initialized"
);
return Ok(Guard(guard));
return Ok(GuardMut(guard));
}
}
}
/// Assuming a permit is held after a previous call to [`Guard::take_and_deinit`], it can be used
/// Returns a guard to an existing initialized value, or uniquely initializes the value before
/// returning the guard.
///
/// Initialization is panic-safe and cancellation-safe.
pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardRef<'_, T>, E>
where
F: FnOnce(InitPermit) -> Fut,
Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
{
let sem = {
let guard = self.inner.read().await;
if guard.value.is_some() {
return Ok(GuardRef(guard));
}
guard.init_semaphore.clone()
};
let permit = {
// increment the count for the duration of queued
let _guard = CountWaitingInitializers::start(self);
sem.acquire_owned().await
};
match permit {
Ok(permit) => {
let permit = InitPermit(permit);
let (value, _permit) = factory(permit).await?;
let guard = self.inner.write().await;
Ok(Self::set0(value, guard).downgrade())
}
Err(_closed) => {
let guard = self.inner.read().await;
assert!(
guard.value.is_some(),
"semaphore got closed, must be initialized"
);
return Ok(GuardRef(guard));
}
}
}
/// Assuming a permit is held after a previous call to [`GuardMut::take_and_deinit`], it can be used
/// to complete initializing the inner value.
///
/// # Panics
///
/// If the inner has already been initialized.
pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
let guard = self.inner.lock().unwrap();
pub async fn set(&self, value: T, _permit: InitPermit) -> GuardMut<'_, T> {
let guard = self.inner.write().await;
// cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
// give more permits right now.
@@ -122,21 +165,31 @@ impl<T> OnceCell<T> {
Self::set0(value, guard)
}
fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
fn set0(value: T, mut guard: tokio::sync::RwLockWriteGuard<'_, Inner<T>>) -> GuardMut<'_, T> {
if guard.value.is_some() {
drop(guard);
unreachable!("we won permit, must not be initialized");
}
guard.value = Some(value);
guard.init_semaphore.close();
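// Closing the semaphore makes every later acquire_owned() fail, which the
// Err(_closed) branches above treat as "value already initialized".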
Guard(guard)
GuardMut(guard)
}
/// Returns a guard to an existing initialized value, if any.
pub fn get(&self) -> Option<Guard<'_, T>> {
let guard = self.inner.lock().unwrap();
pub async fn get_mut(&self) -> Option<GuardMut<'_, T>> {
let guard = self.inner.write().await;
if guard.value.is_some() {
Some(Guard(guard))
Some(GuardMut(guard))
} else {
None
}
}
/// Returns a guard to an existing initialized value, if any.
pub async fn get(&self) -> Option<GuardRef<'_, T>> {
let guard = self.inner.read().await;
if guard.value.is_some() {
Some(GuardRef(guard))
} else {
None
}
@@ -168,9 +221,9 @@ impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
/// Uninteresting guard object to allow short-lived access to inspect or clone the held,
/// initialized value.
#[derive(Debug)]
pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
pub struct GuardMut<'a, T>(tokio::sync::RwLockWriteGuard<'a, Inner<T>>);
impl<T> std::ops::Deref for Guard<'_, T> {
impl<T> std::ops::Deref for GuardMut<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
@@ -181,7 +234,7 @@ impl<T> std::ops::Deref for Guard<'_, T> {
}
}
impl<T> std::ops::DerefMut for Guard<'_, T> {
impl<T> std::ops::DerefMut for GuardMut<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0
.value
@@ -190,7 +243,7 @@ impl<T> std::ops::DerefMut for Guard<'_, T> {
}
}
impl<'a, T> Guard<'a, T> {
impl<'a, T> GuardMut<'a, T> {
/// Take the current value, and a new permit for its deinitialization.
///
/// The permit will be on a semaphore part of the new internal value, and any following
@@ -208,6 +261,24 @@ impl<'a, T> Guard<'a, T> {
.map(|v| (v, InitPermit(permit)))
.expect("guard is not created unless value has been initialized")
}
pub fn downgrade(self) -> GuardRef<'a, T> {
GuardRef(self.0.downgrade())
}
}
#[derive(Debug)]
pub struct GuardRef<'a, T>(tokio::sync::RwLockReadGuard<'a, Inner<T>>);
impl<T> std::ops::Deref for GuardRef<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.0
.value
.as_ref()
.expect("guard is not created unless value has been initialized")
}
}
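// A note on the locking scheme (based on the code above, not part of this change):
// get_mut_or_init hands out GuardMut behind the RwLock's write guard for exclusive
// access, while get_or_init initializes under the write lock and then downgrades to a
// GuardRef so that readers of an already-initialized value do not block each other.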
/// Type held by OnceCell (de)initializing task.
@@ -248,7 +319,7 @@ mod tests {
barrier.wait().await;
let won = {
let g = cell
.get_or_init(|permit| {
.get_mut_or_init(|permit| {
counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
async {
counters.future_polled.fetch_add(1, Ordering::Relaxed);
@@ -295,7 +366,11 @@ mod tests {
let cell = cell.clone();
let deinitialization_started = deinitialization_started.clone();
async move {
let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
let (answer, _permit) = cell
.get_mut()
.await
.expect("initialized to value")
.take_and_deinit();
assert_eq!(answer, initial);
deinitialization_started.wait().await;
@@ -306,7 +381,7 @@ mod tests {
deinitialization_started.wait().await;
let started_at = tokio::time::Instant::now();
cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
cell.get_mut_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
.await
.unwrap();
@@ -318,21 +393,21 @@ mod tests {
jh.await.unwrap();
assert_eq!(*cell.get().unwrap(), reinit);
assert_eq!(*cell.get_mut().await.unwrap(), reinit);
}
#[test]
fn reinit_with_deinit_permit() {
#[tokio::test]
async fn reinit_with_deinit_permit() {
let cell = Arc::new(OnceCell::new(42));
let (mol, permit) = cell.get().unwrap().take_and_deinit();
cell.set(5, permit);
assert_eq!(*cell.get().unwrap(), 5);
let (mol, permit) = cell.get_mut().await.unwrap().take_and_deinit();
cell.set(5, permit).await;
assert_eq!(*cell.get_mut().await.unwrap(), 5);
let (five, permit) = cell.get().unwrap().take_and_deinit();
let (five, permit) = cell.get_mut().await.unwrap().take_and_deinit();
assert_eq!(5, five);
cell.set(mol, permit);
assert_eq!(*cell.get().unwrap(), 42);
cell.set(mol, permit).await;
assert_eq!(*cell.get_mut().await.unwrap(), 42);
}
#[tokio::test]
@@ -340,13 +415,13 @@ mod tests {
let cell = OnceCell::default();
for _ in 0..10 {
cell.get_or_init(|_permit| async { Err("whatever error") })
cell.get_mut_or_init(|_permit| async { Err("whatever error") })
.await
.unwrap_err();
}
let g = cell
.get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
.get_mut_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
.await
.unwrap();
assert_eq!(*g, "finally success");
@@ -358,7 +433,7 @@ mod tests {
let barrier = tokio::sync::Barrier::new(2);
let initializer = cell.get_or_init(|permit| async {
let initializer = cell.get_mut_or_init(|permit| async {
barrier.wait().await;
futures::future::pending::<()>().await;
@@ -372,10 +447,10 @@ mod tests {
// now initializer is dropped
assert!(cell.get().is_none());
assert!(cell.get_mut().await.is_none());
let g = cell
.get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
.get_mut_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
.await
.unwrap();
assert_eq!(*g, "now initialized");

View File

@@ -453,12 +453,9 @@ mod tests {
event_mask: 0,
}),
expected_messages: vec![
// TODO: When updating Postgres versions, this test will break:
// the Postgres version embedded in the greeting message needs updating too.
//
// Greeting(ProposerGreeting { protocol_version: 2, pg_version: 160002, proposer_id: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], system_id: 0, timeline_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tenant_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tli: 1, wal_seg_size: 16777216 })
// Greeting(ProposerGreeting { protocol_version: 2, pg_version: 160001, proposer_id: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], system_id: 0, timeline_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tenant_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tli: 1, wal_seg_size: 16777216 })
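// The only byte that differs below is the low byte of pg_version: 160002 is
// 0x00027102 and 160001 is 0x00027101, serialized little-endian as 2,113,2,0 vs 1,113,2,0.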
vec![
103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 113, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 1, 113, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 76, 143, 54, 6, 60, 108, 110,
147, 188, 32, 214, 90, 130, 15, 61, 158, 76, 143, 54, 6, 60, 108, 110, 147,
188, 32, 214, 90, 130, 15, 61, 1, 0, 0, 0, 0, 0, 0, 1,

View File

@@ -56,18 +56,10 @@ pub enum ForceAwaitLogicalSize {
impl Client {
pub fn new(mgmt_api_endpoint: String, jwt: Option<&str>) -> Self {
Self::from_client(reqwest::Client::new(), mgmt_api_endpoint, jwt)
}
pub fn from_client(
client: reqwest::Client,
mgmt_api_endpoint: String,
jwt: Option<&str>,
) -> Self {
Self {
mgmt_api_endpoint,
authorization_header: jwt.map(|jwt| format!("Bearer {jwt}")),
client,
client: reqwest::Client::new(),
}
}
@@ -318,22 +310,6 @@ impl Client {
.map_err(Error::ReceiveBody)
}
pub async fn tenant_shard_split(
&self,
tenant_shard_id: TenantShardId,
req: TenantShardSplitRequest,
) -> Result<TenantShardSplitResponse> {
let uri = format!(
"{}/v1/tenant/{}/shard_split",
self.mgmt_api_endpoint, tenant_shard_id
);
self.request(Method::PUT, &uri, req)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn timeline_list(
&self,
tenant_shard_id: &TenantShardId,
@@ -363,16 +339,4 @@ impl Client {
.await
.map_err(Error::ReceiveBody)
}
pub async fn put_io_engine(
&self,
engine: &pageserver_api::models::virtual_file::IoEngineKind,
) -> Result<()> {
let uri = format!("{}/v1/io_engine", self.mgmt_api_endpoint);
self.request(Method::PUT, uri, engine)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
}

View File

@@ -142,7 +142,7 @@ pub(crate) async fn main(cmd: &AnalyzeLayerMapCmd) -> Result<()> {
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
// Initialize virtual_file (file descriptor cache) and page cache, which are needed to access the layers' persistent B-Trees.
pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
pageserver::virtual_file::init(10, virtual_file::IoEngineKind::StdFs);
pageserver::page_cache::init(100);
let mut total_delta_layers = 0usize;

View File

@@ -59,7 +59,7 @@ pub(crate) enum LayerCmd {
async fn read_delta_file(path: impl AsRef<Path>, ctx: &RequestContext) -> Result<()> {
let path = Utf8Path::from_path(path.as_ref()).expect("non-Unicode path");
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
virtual_file::init(10, virtual_file::IoEngineKind::StdFs);
page_cache::init(100);
let file = FileBlockReader::new(VirtualFile::open(path).await?);
let summary_blk = file.read_blk(0, ctx).await?;
@@ -187,7 +187,7 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> {
new_tenant_id,
new_timeline_id,
} => {
pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
pageserver::virtual_file::init(10, virtual_file::IoEngineKind::StdFs);
pageserver::page_cache::init(100);
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);

View File

@@ -123,7 +123,7 @@ fn read_pg_control_file(control_file_path: &Utf8Path) -> anyhow::Result<()> {
async fn print_layerfile(path: &Utf8Path) -> anyhow::Result<()> {
// Basic initialization of things that don't change after startup
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
virtual_file::init(10, virtual_file::IoEngineKind::StdFs);
page_cache::init(100);
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
dump_layerfile_from_path(path, true, &ctx).await

View File

@@ -51,10 +51,6 @@ pub(crate) struct Args {
/// It doesn't get invalidated if the keyspace changes under the hood, e.g., due to new ingested data or compaction.
#[clap(long)]
keyspace_cache: Option<Utf8PathBuf>,
/// Before starting the benchmark, live-reconfigure the pageserver to use the given
/// [`pageserver_api::models::virtual_file::IoEngineKind`].
#[clap(long)]
set_io_engine: Option<pageserver_api::models::virtual_file::IoEngineKind>,
targets: Option<Vec<TenantTimelineId>>,
}
@@ -113,10 +109,6 @@ async fn main_impl(
args.pageserver_jwt.as_deref(),
));
if let Some(engine_str) = &args.set_io_engine {
mgmt_api_client.put_io_engine(engine_str).await?;
}
// discover targets
let timelines: Vec<TenantTimelineId> = crate::util::cli::targets::discover(
&mgmt_api_client,

View File

@@ -272,12 +272,6 @@ fn start_pageserver(
);
set_build_info_metric(GIT_VERSION, BUILD_TAG);
set_launch_timestamp_metric(launch_ts);
#[cfg(target_os = "linux")]
metrics::register_internal(Box::new(metrics::more_process_metrics::Collector::new())).unwrap();
metrics::register_internal(Box::new(
pageserver::metrics::tokio_epoll_uring::Collector::new(),
))
.unwrap();
pageserver::preinitialize_metrics();
// If any failpoints were set from FAILPOINTS environment variable,

View File

@@ -623,7 +623,6 @@ impl std::fmt::Display for EvictionLayer {
}
}
#[derive(Default)]
pub(crate) struct DiskUsageEvictionInfo {
/// Timeline's largest layer (remote or resident)
pub max_layer_size: Option<u64>,
@@ -855,27 +854,19 @@ async fn collect_eviction_candidates(
let total = tenant_candidates.len();
let tenant_candidates =
tenant_candidates
.into_iter()
.enumerate()
.map(|(i, mut candidate)| {
// as we iterate this reverse sorted list, the most recently accessed layer will always
// be 1.0; this is for us to evict it last.
candidate.relative_last_activity =
eviction_order.relative_last_activity(total, i);
for (i, mut candidate) in tenant_candidates.into_iter().enumerate() {
// as we iterate this reverse sorted list, the most recently accessed layer will always
// be 1.0; this is for us to evict it last.
candidate.relative_last_activity = eviction_order.relative_last_activity(total, i);
let partition = if cumsum > min_resident_size as i128 {
MinResidentSizePartition::Above
} else {
MinResidentSizePartition::Below
};
cumsum += i128::from(candidate.layer.get_file_size());
(partition, candidate)
});
candidates.extend(tenant_candidates);
let partition = if cumsum > min_resident_size as i128 {
MinResidentSizePartition::Above
} else {
MinResidentSizePartition::Below
};
cumsum += i128::from(candidate.layer.get_file_size());
candidates.push((partition, candidate));
}
}
// Note: the same tenant ID might be hit twice, if it transitions from attached to
@@ -891,41 +882,21 @@ async fn collect_eviction_candidates(
);
for secondary_tenant in secondary_tenants {
// For secondary tenants we use the sum of on-disk layers and already-evicted layers; this
// prevents repeated disk-usage-based evictions from completely draining secondaries that
// are updated less often.
let (mut layer_info, total_layers) = secondary_tenant.get_layers_for_eviction();
debug_assert!(
total_layers >= layer_info.resident_layers.len(),
"total_layers ({total_layers}) must be at least the resident_layers.len() ({})",
layer_info.resident_layers.len()
);
let mut layer_info = secondary_tenant.get_layers_for_eviction();
layer_info
.resident_layers
.sort_unstable_by_key(|layer_info| std::cmp::Reverse(layer_info.last_activity_ts));
let tenant_candidates =
layer_info
.resident_layers
.into_iter()
.enumerate()
.map(|(i, mut candidate)| {
candidate.relative_last_activity =
eviction_order.relative_last_activity(total_layers, i);
(
// Secondary locations' layers are always considered above the min resident size,
// i.e. secondary locations are permitted to be trimmed to zero layers if all
// the layers have sufficiently old access times.
MinResidentSizePartition::Above,
candidate,
)
});
candidates.extend(tenant_candidates);
tokio::task::yield_now().await;
candidates.extend(layer_info.resident_layers.into_iter().map(|candidate| {
(
// Secondary locations' layers are always considered above the min resident size,
// i.e. secondary locations are permitted to be trimmed to zero layers if all
// the layers have sufficiently old access times.
MinResidentSizePartition::Above,
candidate,
)
}));
}
debug_assert!(MinResidentSizePartition::Above < MinResidentSizePartition::Below,

View File

@@ -19,14 +19,11 @@ use pageserver_api::models::ShardParameters;
use pageserver_api::models::TenantDetails;
use pageserver_api::models::TenantLocationConfigResponse;
use pageserver_api::models::TenantShardLocation;
use pageserver_api::models::TenantShardSplitRequest;
use pageserver_api::models::TenantShardSplitResponse;
use pageserver_api::models::TenantState;
use pageserver_api::models::{
DownloadRemoteLayersTaskSpawnRequest, LocationConfigMode, TenantAttachRequest,
TenantLoadRequest, TenantLocationConfigRequest,
};
use pageserver_api::shard::ShardCount;
use pageserver_api::shard::TenantShardId;
use remote_storage::GenericRemoteStorage;
use remote_storage::TimeTravelError;
@@ -878,7 +875,7 @@ async fn tenant_reset_handler(
let state = get_state(&request);
state
.tenant_manager
.reset_tenant(tenant_shard_id, drop_cache.unwrap_or(false), &ctx)
.reset_tenant(tenant_shard_id, drop_cache.unwrap_or(false), ctx)
.await
.map_err(ApiError::InternalServerError)?;
@@ -1107,25 +1104,6 @@ async fn tenant_size_handler(
)
}
async fn tenant_shard_split_handler(
mut request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let req: TenantShardSplitRequest = json_request(&mut request).await?;
let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
let state = get_state(&request);
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Warn);
let new_shards = state
.tenant_manager
.shard_split(tenant_shard_id, ShardCount(req.new_shard_count), &ctx)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, TenantShardSplitResponse { new_shards })
}
async fn layer_map_info_handler(
request: Request<Body>,
_cancel: CancellationToken,
@@ -1930,15 +1908,6 @@ async fn post_tracing_event_handler(
json_response(StatusCode::OK, ())
}
async fn put_io_engine_handler(
mut r: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let kind: crate::virtual_file::IoEngineKind = json_request(&mut r).await?;
crate::virtual_file::io_engine::set(kind);
json_response(StatusCode::OK, ())
}
/// Common functionality of all the HTTP API handlers.
///
/// - Adds a tracing span to each request (by `request_span`)
@@ -2085,9 +2054,6 @@ pub fn make_router(
.put("/v1/tenant/config", |r| {
api_handler(r, update_tenant_config_handler)
})
.put("/v1/tenant/:tenant_shard_id/shard_split", |r| {
api_handler(r, tenant_shard_split_handler)
})
.get("/v1/tenant/:tenant_shard_id/config", |r| {
api_handler(r, get_tenant_config_handler)
})
@@ -2199,6 +2165,5 @@ pub fn make_router(
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/keyspace",
|r| testing_api_handler("read out the keyspace", r, timeline_collect_keyspace),
)
.put("/v1/io_engine", |r| api_handler(r, put_io_engine_handler))
.any(handler_404))
}

View File

@@ -2400,72 +2400,6 @@ impl<F: Future<Output = Result<O, E>>, O, E> Future for MeasuredRemoteOp<F> {
}
}
pub mod tokio_epoll_uring {
use metrics::UIntGauge;
pub struct Collector {
descs: Vec<metrics::core::Desc>,
systems_created: UIntGauge,
systems_destroyed: UIntGauge,
}
const NMETRICS: usize = 2;
impl metrics::core::Collector for Collector {
fn desc(&self) -> Vec<&metrics::core::Desc> {
self.descs.iter().collect()
}
fn collect(&self) -> Vec<metrics::proto::MetricFamily> {
let mut mfs = Vec::with_capacity(NMETRICS);
let tokio_epoll_uring::metrics::Metrics {
systems_created,
systems_destroyed,
} = tokio_epoll_uring::metrics::global();
self.systems_created.set(systems_created);
mfs.extend(self.systems_created.collect());
self.systems_destroyed.set(systems_destroyed);
mfs.extend(self.systems_destroyed.collect());
mfs
}
}
impl Collector {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
let mut descs = Vec::new();
let systems_created = UIntGauge::new(
"pageserver_tokio_epoll_uring_systems_created",
"counter of tokio-epoll-uring systems that were created",
)
.unwrap();
descs.extend(
metrics::core::Collector::desc(&systems_created)
.into_iter()
.cloned(),
);
let systems_destroyed = UIntGauge::new(
"pageserver_tokio_epoll_uring_systems_destroyed",
"counter of tokio-epoll-uring systems that were destroyed",
)
.unwrap();
descs.extend(
metrics::core::Collector::desc(&systems_destroyed)
.into_iter()
.cloned(),
);
Self {
descs,
systems_created,
systems_destroyed,
}
}
}
}
pub fn preinitialize_metrics() {
// Python tests need these, and on some of them we do alerting.
//

View File

@@ -91,8 +91,8 @@ const ACTIVE_TENANT_TIMEOUT: Duration = Duration::from_millis(30000);
/// `tokio_tar` already read the first such block. Read the second all-zeros block,
/// and check that there is no more data after the EOF marker.
///
/// 'tar' command can also write extra blocks of zeros, up to a record
/// size, controlled by the --record-size argument. Ignore them too.
/// XXX: Currently, any trailing data after the EOF marker prints a warning.
/// Perhaps it should be a hard error?
async fn read_tar_eof(mut reader: (impl AsyncRead + Unpin)) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
let mut buf = [0u8; 512];
@@ -113,24 +113,17 @@ async fn read_tar_eof(mut reader: (impl AsyncRead + Unpin)) -> anyhow::Result<()
anyhow::bail!("invalid tar EOF marker");
}
// Drain any extra zero-blocks after the EOF marker
// Drain any data after the EOF marker
let mut trailing_bytes = 0;
let mut seen_nonzero_bytes = false;
loop {
let nbytes = reader.read(&mut buf).await?;
trailing_bytes += nbytes;
if !buf.iter().all(|&x| x == 0) {
seen_nonzero_bytes = true;
}
if nbytes == 0 {
break;
}
}
if seen_nonzero_bytes {
anyhow::bail!("unexpected non-zero bytes after the tar archive");
}
if trailing_bytes % 512 != 0 {
anyhow::bail!("unexpected number of zeros ({trailing_bytes}), not divisible by tar block size (512 bytes), after the tar archive");
if trailing_bytes > 0 {
warn!("ignored {trailing_bytes} unexpected bytes after the tar archive");
}
Ok(())
}

View File

@@ -576,8 +576,8 @@ pub fn shutdown_token() -> CancellationToken {
/// Has the current task been requested to shut down?
pub fn is_shutdown_requested() -> bool {
if let Ok(true_or_false) = SHUTDOWN_TOKEN.try_with(|t| t.is_cancelled()) {
true_or_false
if let Ok(cancel) = SHUTDOWN_TOKEN.try_with(|t| t.clone()) {
cancel.is_cancelled()
} else {
if !cfg!(test) {
warn!("is_shutdown_requested() called in an unexpected task or thread");

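For context, a standalone sketch of the task-local token pattern that is_shutdown_requested relies on, assuming tokio's task_local! macro and tokio_util's CancellationToken; this is illustrative, not the pageserver's actual task_mgr module.

use tokio_util::sync::CancellationToken;

tokio::task_local! {
    // Each spawned task gets its own shutdown token via `SHUTDOWN_TOKEN.scope(...)`.
    static SHUTDOWN_TOKEN: CancellationToken;
}

/// Returns true if the current task's token is cancelled; if the task was not
/// spawned with a token in scope, report "not requested" rather than panicking.
fn is_shutdown_requested() -> bool {
    SHUTDOWN_TOKEN
        .try_with(|t| t.is_cancelled())
        .unwrap_or(false)
}

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    token.cancel();
    let seen = SHUTDOWN_TOKEN
        .scope(token, async { is_shutdown_requested() })
        .await;
    assert!(seen);
}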
View File

@@ -53,7 +53,6 @@ use self::metadata::TimelineMetadata;
use self::mgr::GetActiveTenantError;
use self::mgr::GetTenantError;
use self::mgr::TenantsMap;
use self::remote_timeline_client::upload::upload_index_part;
use self::remote_timeline_client::RemoteTimelineClient;
use self::timeline::uninit::TimelineExclusionError;
use self::timeline::uninit::TimelineUninitMark;
@@ -1377,7 +1376,7 @@ impl Tenant {
async move {
debug!("starting index part download");
let index_part = client.download_index_file(&cancel_clone).await;
let index_part = client.download_index_file(cancel_clone).await;
debug!("finished index part download");
@@ -2398,67 +2397,6 @@ impl Tenant {
pub(crate) fn get_generation(&self) -> Generation {
self.generation
}
/// This function partially shuts down the tenant (it shuts down the Timelines) and is fallible,
/// and can leave the tenant in a bad state if it fails. The caller is responsible for
/// resetting this tenant to a valid state if we fail.
pub(crate) async fn split_prepare(
&self,
child_shards: &Vec<TenantShardId>,
) -> anyhow::Result<()> {
let timelines = self.timelines.lock().unwrap().clone();
for timeline in timelines.values() {
let Some(tl_client) = &timeline.remote_client else {
anyhow::bail!("Remote storage is mandatory");
};
let Some(remote_storage) = &self.remote_storage else {
anyhow::bail!("Remote storage is mandatory");
};
// We do not block timeline creation/deletion during splits inside the pageserver: it is up to higher levels
// to ensure that they do not start a split if currently in the process of doing these.
// Upload an index from the parent: this is partly to provide freshness for the
// child tenants that will copy it, and partly for general ease-of-debugging: there will
// always be a parent shard index in the same generation as we wrote the child shard index.
tl_client.schedule_index_upload_for_file_changes()?;
tl_client.wait_completion().await?;
// Shut down the timeline's remote client: this means that the indices we write
// for child shards will not be invalidated by the parent shard deleting layers.
tl_client.shutdown().await?;
// Download methods can still be used after shutdown, as they don't flow through the remote client's
// queue. In principle the RemoteTimelineClient could provide this without downloading it, but this
// operation is rare, so it's simpler to just download it (and robustly guarantees that the index
// we use here really is the remotely persistent one).
let result = tl_client
.download_index_file(&self.cancel)
.instrument(info_span!("download_index_file", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), timeline_id=%timeline.timeline_id))
.await?;
let index_part = match result {
MaybeDeletedIndexPart::Deleted(_) => {
anyhow::bail!("Timeline deletion happened concurrently with split")
}
MaybeDeletedIndexPart::IndexPart(p) => p,
};
for child_shard in child_shards {
upload_index_part(
remote_storage,
child_shard,
&timeline.timeline_id,
self.generation,
&index_part,
&self.cancel,
)
.await?;
}
}
Ok(())
}
}
/// Given a Vec of timelines and their ancestors (timeline_id, ancestor_id),
@@ -3794,10 +3732,6 @@ impl Tenant {
Ok(())
}
pub(crate) fn get_tenant_conf(&self) -> TenantConfOpt {
self.tenant_conf.read().unwrap().tenant_conf
}
}
fn remove_timeline_and_uninit_mark(

View File

@@ -2,7 +2,6 @@
//! page server.
use camino::{Utf8DirEntry, Utf8Path, Utf8PathBuf};
use itertools::Itertools;
use pageserver_api::key::Key;
use pageserver_api::models::ShardParameters;
use pageserver_api::shard::{ShardCount, ShardIdentity, ShardNumber, TenantShardId};
@@ -23,7 +22,7 @@ use tokio_util::sync::CancellationToken;
use tracing::*;
use remote_storage::GenericRemoteStorage;
use utils::{completion, crashsafe};
use utils::crashsafe;
use crate::config::PageServerConf;
use crate::context::{DownloadBehavior, RequestContext};
@@ -645,6 +644,8 @@ pub(crate) async fn shutdown_all_tenants() {
}
async fn shutdown_all_tenants0(tenants: &std::sync::RwLock<TenantsMap>) {
use utils::completion;
let mut join_set = JoinSet::new();
// Atomically, 1. create the shutdown tasks and 2. prevent creation of new tenants.
@@ -1199,7 +1200,7 @@ impl TenantManager {
&self,
tenant_shard_id: TenantShardId,
drop_cache: bool,
ctx: &RequestContext,
ctx: RequestContext,
) -> anyhow::Result<()> {
let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let Some(old_slot) = slot_guard.get_old_value() else {
@@ -1252,7 +1253,7 @@ impl TenantManager {
None,
self.tenants,
SpawnMode::Normal,
ctx,
&ctx,
)?;
slot_guard.upsert(TenantSlot::Attached(tenant))?;
@@ -1374,164 +1375,6 @@ impl TenantManager {
slot_guard.revert();
result
}
#[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), new_shard_count=%new_shard_count.0))]
pub(crate) async fn shard_split(
&self,
tenant_shard_id: TenantShardId,
new_shard_count: ShardCount,
ctx: &RequestContext,
) -> anyhow::Result<Vec<TenantShardId>> {
let tenant = get_tenant(tenant_shard_id, true)?;
// Plan: identify what the new child shards will be
let effective_old_shard_count = std::cmp::max(tenant_shard_id.shard_count.0, 1);
if new_shard_count <= ShardCount(effective_old_shard_count) {
anyhow::bail!("Requested shard count is not an increase");
}
let expansion_factor = new_shard_count.0 / effective_old_shard_count;
if !expansion_factor.is_power_of_two() {
anyhow::bail!("Requested split is not a power of two");
}
let parent_shard_identity = tenant.shard_identity;
let parent_tenant_conf = tenant.get_tenant_conf();
let parent_generation = tenant.generation;
let child_shards = tenant_shard_id.split(new_shard_count);
tracing::info!(
"Shard {} splits into: {}",
tenant_shard_id.to_index(),
child_shards
.iter()
.map(|id| format!("{}", id.to_index()))
.join(",")
);
// Phase 1: Write out child shards' remote index files, in the parent tenant's current generation
if let Err(e) = tenant.split_prepare(&child_shards).await {
// If [`Tenant::split_prepare`] fails, we must reload the tenant, because it might
// have been left in a partially-shut-down state.
tracing::warn!("Failed to prepare for split: {e}, reloading Tenant before returning");
self.reset_tenant(tenant_shard_id, false, ctx).await?;
return Err(e);
}
self.resources.deletion_queue_client.flush_advisory();
// Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant
drop(tenant);
let mut parent_slot_guard =
tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let parent = match parent_slot_guard.get_old_value() {
Some(TenantSlot::Attached(t)) => t,
Some(TenantSlot::Secondary(_)) => anyhow::bail!("Tenant location in secondary mode"),
Some(TenantSlot::InProgress(_)) => {
// tenant_map_acquire_slot never returns InProgress, if a slot was InProgress
// it would return an error.
unreachable!()
}
None => {
// We don't actually need the parent shard to still be attached to do our work, but it's
// a weird enough situation that the caller probably didn't want us to continue working
// if they had detached the tenant they requested the split on.
anyhow::bail!("Detached parent shard in the middle of split!")
}
};
// TODO: hardlink layers from the parent into the child shard directories so that they don't immediately re-download
// TODO: erase the dentries from the parent
// Take a snapshot of where the parent's WAL ingest had got to: we will wait for
// child shards to reach this point.
let mut target_lsns = HashMap::new();
for timeline in parent.timelines.lock().unwrap().clone().values() {
target_lsns.insert(timeline.timeline_id, timeline.get_last_record_lsn());
}
// TODO: we should have the parent shard stop its WAL ingest here, it's a waste of resources
// and could slow down the children trying to catch up.
// Phase 3: Spawn the child shards
for child_shard in &child_shards {
let mut child_shard_identity = parent_shard_identity;
child_shard_identity.count = child_shard.shard_count;
child_shard_identity.number = child_shard.shard_number;
let child_location_conf = LocationConf {
mode: LocationMode::Attached(AttachedLocationConfig {
generation: parent_generation,
attach_mode: AttachmentMode::Single,
}),
shard: child_shard_identity,
tenant_conf: parent_tenant_conf,
};
self.upsert_location(
*child_shard,
child_location_conf,
None,
SpawnMode::Normal,
ctx,
)
.await?;
}
// Phase 4: wait for child shards' WAL ingest to catch up to target LSN
for child_shard_id in &child_shards {
let child_shard = {
let locked = TENANTS.read().unwrap();
let peek_slot =
tenant_map_peek_slot(&locked, child_shard_id, TenantSlotPeekMode::Read)?;
peek_slot.and_then(|s| s.get_attached()).cloned()
};
if let Some(t) = child_shard {
let timelines = t.timelines.lock().unwrap().clone();
for timeline in timelines.values() {
let Some(target_lsn) = target_lsns.get(&timeline.timeline_id) else {
continue;
};
tracing::info!(
"Waiting for child shard {}/{} to reach target lsn {}...",
child_shard_id,
timeline.timeline_id,
target_lsn
);
if let Err(e) = timeline.wait_lsn(*target_lsn, ctx).await {
// Failure here might mean shutdown, in any case this part is an optimization
// and we shouldn't hold up the split operation.
tracing::warn!(
"Failed to wait for timeline {} to reach lsn {target_lsn}: {e}",
timeline.timeline_id
);
} else {
tracing::info!(
"Child shard {}/{} reached target lsn {}",
child_shard_id,
timeline.timeline_id,
target_lsn
);
}
}
}
}
// Phase 5: Shut down the parent shard.
let (_guard, progress) = completion::channel();
match parent.shutdown(progress, false).await {
Ok(()) => {}
Err(other) => {
other.wait().await;
}
}
parent_slot_guard.drop_old_value()?;
// Phase 6: Release the InProgress on the parent shard
drop(parent_slot_guard);
Ok(child_shards)
}
}
#[derive(Debug, thiserror::Error)]
@@ -2366,6 +2209,8 @@ async fn remove_tenant_from_memory<V, F>(
where
F: std::future::Future<Output = anyhow::Result<V>>,
{
use utils::completion;
let mut slot_guard =
tenant_map_acquire_slot_impl(&tenant_shard_id, tenants, TenantSlotAcquireMode::MustExist)?;

View File
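A small, self-contained worked example of the shard-count validation in shard_split above (illustrative only, using plain integers instead of ShardCount): splitting an unsharded tenant (effective count 1) into 4 shards gives an expansion factor of 4, a power of two, and is accepted; a requested count of 3 is rejected.

/// Returns Ok(expansion_factor) if `new_count` is a valid split target for `old_count`.
fn validate_split(old_count: u8, new_count: u8) -> Result<u8, String> {
    // An unsharded tenant has shard_count == 0; treat it as 1 for the ratio.
    let effective_old = std::cmp::max(old_count, 1);
    if new_count <= effective_old {
        return Err("Requested shard count is not an increase".into());
    }
    let expansion_factor = new_count / effective_old;
    if !expansion_factor.is_power_of_two() {
        return Err("Requested split is not a power of two".into());
    }
    Ok(expansion_factor)
}

fn main() {
    assert_eq!(validate_split(0, 4), Ok(4)); // unsharded -> 4 shards
    assert_eq!(validate_split(2, 8), Ok(4)); // 2 -> 8 shards
    assert!(validate_split(1, 3).is_err()); // 3x is not a power of two
}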

@@ -217,7 +217,6 @@ use crate::metrics::{
};
use crate::task_mgr::shutdown_token;
use crate::tenant::debug_assert_current_span_has_tenant_and_timeline_id;
use crate::tenant::remote_timeline_client::download::download_retry;
use crate::tenant::storage_layer::AsLayerDesc;
use crate::tenant::upload_queue::Delete;
use crate::tenant::TIMELINES_SEGMENT_NAME;
@@ -263,11 +262,6 @@ pub(crate) const INITDB_PRESERVED_PATH: &str = "initdb-preserved.tar.zst";
/// Default buffer size when interfacing with [`tokio::fs::File`].
pub(crate) const BUFFER_SIZE: usize = 32 * 1024;
/// This timeout is intended to deal with hangs in lower layers, e.g. stuck TCP flows. It is not
/// intended to be snappy enough for prompt shutdown, as we have a CancellationToken for that.
pub(crate) const UPLOAD_TIMEOUT: Duration = Duration::from_secs(120);
pub(crate) const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(120);
pub enum MaybeDeletedIndexPart {
IndexPart(IndexPart),
Deleted(IndexPart),
@@ -331,6 +325,11 @@ pub struct RemoteTimelineClient {
cancel: CancellationToken,
}
/// This timeout is intended to deal with hangs in lower layers, e.g. stuck TCP flows. It is not
/// intended to be snappy enough for prompt shutdown, as we have a CancellationToken for that.
const UPLOAD_TIMEOUT: Duration = Duration::from_secs(120);
const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(120);
/// Wrapper for timeout_cancellable that flattens result and converts TimeoutCancellableError to anyhow.
///
/// This is a convenience for the various upload functions. In future
@@ -507,7 +506,7 @@ impl RemoteTimelineClient {
/// Download index file
pub async fn download_index_file(
&self,
cancel: &CancellationToken,
cancel: CancellationToken,
) -> Result<MaybeDeletedIndexPart, DownloadError> {
let _unfinished_gauge_guard = self.metrics.call_begin(
&RemoteOpFileKind::Index,
@@ -1148,17 +1147,22 @@ impl RemoteTimelineClient {
let cancel = shutdown_token();
let remaining = download_retry(
let remaining = backoff::retry(
|| async {
self.storage_impl
.list_files(Some(&timeline_storage_path), None)
.list_files(Some(&timeline_storage_path))
.await
},
"list remaining files",
|_e| false,
FAILED_DOWNLOAD_WARN_THRESHOLD,
FAILED_REMOTE_OP_RETRIES,
"list_prefixes",
&cancel,
)
.await
.context("list files remaining files")?;
.ok_or_else(|| anyhow::anyhow!("Cancelled!"))
.and_then(|x| x)
.context("list prefixes")?;
// We will delete the current index_part object last, since it acts as a deletion
// marker via its deleted_at attribute
@@ -1347,7 +1351,6 @@ impl RemoteTimelineClient {
/// queue.
///
async fn perform_upload_task(self: &Arc<Self>, task: Arc<UploadTask>) {
let cancel = shutdown_token();
// Loop to retry until it completes.
loop {
// If we're requested to shut down, close up shop and exit.
@@ -1359,7 +1362,7 @@ impl RemoteTimelineClient {
// the Future, but we're not 100% sure if the remote storage library
// is cancellation safe, so we don't dare to do that. Hopefully, the
// upload finishes or times out soon enough.
if cancel.is_cancelled() {
if task_mgr::is_shutdown_requested() {
info!("upload task cancelled by shutdown request");
match self.stop() {
Ok(()) => {}
@@ -1470,7 +1473,7 @@ impl RemoteTimelineClient {
retries,
DEFAULT_BASE_BACKOFF_SECONDS,
DEFAULT_MAX_BACKOFF_SECONDS,
&cancel,
&shutdown_token(),
)
.await;
}
@@ -1987,7 +1990,7 @@ mod tests {
// Download back the index.json, and check that the list of files is correct
let initial_index_part = match client
.download_index_file(&CancellationToken::new())
.download_index_file(CancellationToken::new())
.await
.unwrap()
{
@@ -2081,7 +2084,7 @@ mod tests {
// Download back the index.json, and check that the list of files is correct
let index_part = match client
.download_index_file(&CancellationToken::new())
.download_index_file(CancellationToken::new())
.await
.unwrap()
{
@@ -2283,7 +2286,7 @@ mod tests {
let client = test_state.build_client(get_generation);
let download_r = client
.download_index_file(&CancellationToken::new())
.download_index_file(CancellationToken::new())
.await
.expect("download should always succeed");
assert!(matches!(download_r, MaybeDeletedIndexPart::IndexPart(_)));

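The retry calls above all follow the same retry-with-backoff-and-cancellation shape. Below is a standalone sketch of that shape written directly against tokio and tokio-util; the repository's own backoff helper is not reproduced here, and the constants and logging are illustrative.

use std::future::Future;
use std::time::Duration;
use tokio_util::sync::CancellationToken;

/// Retry `op` until it succeeds, the error is considered permanent, the attempt
/// budget is exhausted, or `cancel` fires. Returns None on cancellation.
async fn retry<T, E, F, Fut>(
    mut op: F,
    is_permanent: impl Fn(&E) -> bool,
    max_attempts: u32,
    description: &str,
    cancel: &CancellationToken,
) -> Option<Result<T, E>>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: std::fmt::Display,
{
    let mut attempt = 0u32;
    loop {
        match op().await {
            Ok(v) => return Some(Ok(v)),
            Err(e) if is_permanent(&e) || attempt + 1 >= max_attempts => return Some(Err(e)),
            Err(e) => {
                eprintln!("{description} failed (attempt {attempt}): {e}, retrying");
            }
        }
        attempt += 1;
        // Exponential backoff, capped; give up early if cancellation is requested.
        let delay = Duration::from_millis(100 * (1u64 << attempt.min(6)));
        tokio::select! {
            _ = cancel.cancelled() => return None,
            _ = tokio::time::sleep(delay) => {}
        }
    }
}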
View File

@@ -216,15 +216,16 @@ pub async fn list_remote_timelines(
anyhow::bail!("storage-sync-list-remote-timelines");
});
let cancel_inner = cancel.clone();
let listing = download_retry_forever(
|| {
download_cancellable(
&cancel,
storage.list(Some(&remote_path), ListingMode::WithDelimiter, None),
&cancel_inner,
storage.list(Some(&remote_path), ListingMode::WithDelimiter),
)
},
&format!("list timelines for {tenant_shard_id}"),
&cancel,
cancel,
)
.await?;
@@ -257,18 +258,19 @@ async fn do_download_index_part(
tenant_shard_id: &TenantShardId,
timeline_id: &TimelineId,
index_generation: Generation,
cancel: &CancellationToken,
cancel: CancellationToken,
) -> Result<IndexPart, DownloadError> {
use futures::stream::StreamExt;
let remote_path = remote_index_path(tenant_shard_id, timeline_id, index_generation);
let cancel_inner = cancel.clone();
let index_part_bytes = download_retry_forever(
|| async {
// Cancellation: it is safe to cancel this future because we're just downloading into
// a memory buffer, not touching local disk.
let index_part_download =
download_cancellable(cancel, storage.download(&remote_path)).await?;
download_cancellable(&cancel_inner, storage.download(&remote_path)).await?;
let mut index_part_bytes = Vec::new();
let mut stream = std::pin::pin!(index_part_download.download_stream);
@@ -286,7 +288,7 @@ async fn do_download_index_part(
.await?;
let index_part: IndexPart = serde_json::from_slice(&index_part_bytes)
.with_context(|| format!("deserialize index part file at {remote_path:?}"))
.with_context(|| format!("download index part file at {remote_path:?}"))
.map_err(DownloadError::Other)?;
Ok(index_part)
@@ -303,7 +305,7 @@ pub(super) async fn download_index_part(
tenant_shard_id: &TenantShardId,
timeline_id: &TimelineId,
my_generation: Generation,
cancel: &CancellationToken,
cancel: CancellationToken,
) -> Result<IndexPart, DownloadError> {
debug_assert_current_span_has_tenant_and_timeline_id();
@@ -323,8 +325,14 @@ pub(super) async fn download_index_part(
// index in our generation.
//
// This is an optimization to avoid doing the listing for the general case below.
let res =
do_download_index_part(storage, tenant_shard_id, timeline_id, my_generation, cancel).await;
let res = do_download_index_part(
storage,
tenant_shard_id,
timeline_id,
my_generation,
cancel.clone(),
)
.await;
match res {
Ok(index_part) => {
tracing::debug!(
@@ -349,7 +357,7 @@ pub(super) async fn download_index_part(
tenant_shard_id,
timeline_id,
my_generation.previous(),
cancel,
cancel.clone(),
)
.await;
match res {
@@ -371,13 +379,18 @@ pub(super) async fn download_index_part(
// objects, and select the highest one with a generation <= my_generation. Constructing the prefix is equivalent
// to constructing a full index path with no generation, because the generation is a suffix.
let index_prefix = remote_index_path(tenant_shard_id, timeline_id, Generation::none());
let indices = download_retry(
|| async { storage.list_files(Some(&index_prefix), None).await },
"list index_part files",
cancel,
let indices = backoff::retry(
|| async { storage.list_files(Some(&index_prefix)).await },
|_| false,
FAILED_DOWNLOAD_WARN_THRESHOLD,
FAILED_REMOTE_OP_RETRIES,
"listing index_part files",
&cancel,
)
.await?;
.await
.ok_or_else(|| anyhow::anyhow!("Cancelled"))
.and_then(|x| x)
.map_err(DownloadError::Other)?;
// General case logic for which index to use: the latest index whose generation
// is <= our own. See "Finding the remote indices for timelines" in docs/rfcs/025-generation-numbers.md
@@ -434,6 +447,8 @@ pub(crate) async fn download_initdb_tar_zst(
"{INITDB_PATH}.download-{timeline_id}.{TEMP_FILE_SUFFIX}"
));
let cancel_inner = cancel.clone();
let file = download_retry(
|| async {
let file = OpenOptions::new()
@@ -446,11 +461,13 @@ pub(crate) async fn download_initdb_tar_zst(
.with_context(|| format!("tempfile creation {temp_path}"))
.map_err(DownloadError::Other)?;
let download = match download_cancellable(cancel, storage.download(&remote_path)).await
let download = match download_cancellable(&cancel_inner, storage.download(&remote_path))
.await
{
Ok(dl) => dl,
Err(DownloadError::NotFound) => {
download_cancellable(cancel, storage.download(&remote_preserved_path)).await?
download_cancellable(&cancel_inner, storage.download(&remote_preserved_path))
.await?
}
Err(other) => Err(other)?,
};
@@ -499,7 +516,7 @@ pub(crate) async fn download_initdb_tar_zst(
/// with backoff.
///
/// (See similar logic for uploads in `perform_upload_task`)
pub(super) async fn download_retry<T, O, F>(
async fn download_retry<T, O, F>(
op: O,
description: &str,
cancel: &CancellationToken,
@@ -510,7 +527,7 @@ where
{
backoff::retry(
op,
DownloadError::is_permanent,
|e| matches!(e, DownloadError::BadInput(_) | DownloadError::NotFound),
FAILED_DOWNLOAD_WARN_THRESHOLD,
FAILED_REMOTE_OP_RETRIES,
description,
@@ -524,7 +541,7 @@ where
async fn download_retry_forever<T, O, F>(
op: O,
description: &str,
cancel: &CancellationToken,
cancel: CancellationToken,
) -> Result<T, DownloadError>
where
O: FnMut() -> F,
@@ -532,11 +549,11 @@ where
{
backoff::retry(
op,
DownloadError::is_permanent,
|e| matches!(e, DownloadError::BadInput(_) | DownloadError::NotFound),
FAILED_DOWNLOAD_WARN_THRESHOLD,
u32::MAX,
description,
cancel,
&cancel,
)
.await
.ok_or_else(|| DownloadError::Cancelled)

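For the general-case listing above, the selection rule (per the generation-numbers RFC referenced in the code) is: among the listed index files, use the one with the highest generation that is less than or equal to our own. A self-contained sketch with a simplified Generation type and made-up labels:

/// Simplified stand-in for the pageserver's Generation type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Generation(u32);

/// Pick the index with the highest generation <= `my_generation`, if any.
fn select_index(
    listed: &[(String, Generation)],
    my_generation: Generation,
) -> Option<&(String, Generation)> {
    listed
        .iter()
        .filter(|(_, g)| *g <= my_generation)
        .max_by_key(|(_, g)| *g)
}

fn main() {
    let listed = vec![
        ("idx-gen2".to_string(), Generation(2)),
        ("idx-gen5".to_string(), Generation(5)),
        ("idx-gen7".to_string(), Generation(7)),
    ];
    // With my_generation = 6, generation 5 is the newest index we may use.
    assert_eq!(
        select_index(&listed, Generation(6)).map(|(_, g)| *g),
        Some(Generation(5))
    );
}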
View File

@@ -27,7 +27,7 @@ use super::index::LayerFileMetadata;
use tracing::info;
/// Serializes and uploads the given index part data to the remote storage.
pub(crate) async fn upload_index_part<'a>(
pub(super) async fn upload_index_part<'a>(
storage: &'a GenericRemoteStorage,
tenant_shard_id: &TenantShardId,
timeline_id: &TimelineId,

View File

@@ -160,7 +160,7 @@ impl SecondaryTenant {
&self.tenant_shard_id
}
pub(crate) fn get_layers_for_eviction(self: &Arc<Self>) -> (DiskUsageEvictionInfo, usize) {
pub(crate) fn get_layers_for_eviction(self: &Arc<Self>) -> DiskUsageEvictionInfo {
self.detail.lock().unwrap().get_layers_for_eviction(self)
}

View File

@@ -146,15 +146,14 @@ impl SecondaryDetail {
}
}
/// Additionally returns the total number of layers, used for more stable relative access time
/// based eviction.
pub(super) fn get_layers_for_eviction(
&self,
parent: &Arc<SecondaryTenant>,
) -> (DiskUsageEvictionInfo, usize) {
let mut result = DiskUsageEvictionInfo::default();
let mut total_layers = 0;
) -> DiskUsageEvictionInfo {
let mut result = DiskUsageEvictionInfo {
max_layer_size: None,
resident_layers: Vec::new(),
};
for (timeline_id, timeline_detail) in &self.timelines {
result
.resident_layers
@@ -170,10 +169,6 @@ impl SecondaryDetail {
relative_last_activity: finite_f32::FiniteF32::ZERO,
}
}));
// total might be missing currently downloading layers, but as a lower-than-actual
// value it is a good enough approximation.
total_layers += timeline_detail.on_disk_layers.len() + timeline_detail.evicted_at.len();
}
result.max_layer_size = result
.resident_layers
@@ -188,7 +183,7 @@ impl SecondaryDetail {
result.resident_layers.len()
);
(result, total_layers)
result
}
}
@@ -317,7 +312,9 @@ impl JobGenerator<PendingDownload, RunningDownload, CompleteDownload, DownloadCo
.tenant_manager
.get_secondary_tenant_shard(*tenant_shard_id);
let Some(tenant) = tenant else {
return Err(anyhow::anyhow!("Not found or not in Secondary mode"));
{
return Err(anyhow::anyhow!("Not found or not in Secondary mode"));
}
};
Ok(PendingDownload {
@@ -392,9 +389,9 @@ impl JobGenerator<PendingDownload, RunningDownload, CompleteDownload, DownloadCo
}
CompleteDownload {
secondary_state,
completed_at: Instant::now(),
}
secondary_state,
completed_at: Instant::now(),
}
}.instrument(info_span!(parent: None, "secondary_download", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()))))
}
}
@@ -533,7 +530,7 @@ impl<'a> TenantDownloader<'a> {
.map_err(UpdateError::from)?;
let mut heatmap_bytes = Vec::new();
let mut body = tokio_util::io::StreamReader::new(download.download_stream);
let _size = tokio::io::copy_buf(&mut body, &mut heatmap_bytes).await?;
let _size = tokio::io::copy(&mut body, &mut heatmap_bytes).await?;
Ok(heatmap_bytes)
},
|e| matches!(e, UpdateError::NoData | UpdateError::Cancelled),

View File

@@ -300,8 +300,8 @@ impl Layer {
})
}
pub(crate) fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
self.0.info(reset)
pub(crate) async fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
self.0.info(reset).await
}
pub(crate) fn access_stats(&self) -> &LayerAccessStats {
@@ -612,10 +612,10 @@ impl LayerInner {
let mut rx = self.status.subscribe();
let strong = {
match self.inner.get() {
match self.inner.get_mut().await {
Some(mut either) => {
self.wanted_evicted.store(true, Ordering::Relaxed);
either.downgrade()
ResidentOrWantedEvicted::downgrade(&mut either)
}
None => return Err(EvictionError::NotFound),
}
@@ -641,7 +641,7 @@ impl LayerInner {
// use however late (compared to the initial expressing of wanted) as the
// "outcome" now
LAYER_IMPL_METRICS.inc_broadcast_lagged();
match self.inner.get() {
match self.inner.get_mut().await {
Some(_) => Err(EvictionError::Downloaded),
None => Ok(()),
}
@@ -759,7 +759,7 @@ impl LayerInner {
// use the already held initialization permit because it is impossible to hit the
// below paths anymore essentially limiting the max loop iterations to 2.
let (value, init_permit) = download(init_permit).await?;
let mut guard = self.inner.set(value, init_permit);
let mut guard = self.inner.set(value, init_permit).await;
let (strong, _upgraded) = guard
.get_and_upgrade()
.expect("init creates strong reference, we held the init permit");
@@ -767,7 +767,7 @@ impl LayerInner {
}
let (weak, permit) = {
let mut locked = self.inner.get_or_init(download).await?;
let mut locked = self.inner.get_mut_or_init(download).await?;
if let Some((strong, upgraded)) = locked.get_and_upgrade() {
if upgraded {
@@ -989,12 +989,12 @@ impl LayerInner {
}
}
fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
async fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
let layer_file_name = self.desc.filename().file_name();
// this is not accurate: we could have the file locally but there was a cancellation
// and now we are not in sync, or we are currently downloading it.
let remote = self.inner.get().is_none();
let remote = self.inner.get_mut().await.is_none();
let access_stats = self.access_stats.as_api_model(reset);
@@ -1053,7 +1053,7 @@ impl LayerInner {
LAYER_IMPL_METRICS.inc_eviction_cancelled(EvictionCancelled::LayerGone);
return;
};
match this.evict_blocking(version) {
match tokio::runtime::Handle::current().block_on(this.evict_blocking(version)) {
Ok(()) => LAYER_IMPL_METRICS.inc_completed_evictions(),
Err(reason) => LAYER_IMPL_METRICS.inc_eviction_cancelled(reason),
}
@@ -1061,7 +1061,7 @@ impl LayerInner {
}
}
fn evict_blocking(&self, only_version: usize) -> Result<(), EvictionCancelled> {
async fn evict_blocking(&self, only_version: usize) -> Result<(), EvictionCancelled> {
// deleted or detached timeline, don't do anything.
let Some(timeline) = self.timeline.upgrade() else {
return Err(EvictionCancelled::TimelineGone);
@@ -1070,7 +1070,7 @@ impl LayerInner {
// to avoid starting a new download while we evict, keep holding on to the
// permit.
let _permit = {
let maybe_downloaded = self.inner.get();
let maybe_downloaded = self.inner.get_mut().await;
let (_weak, permit) = match maybe_downloaded {
Some(mut guard) => {

View File
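The Handle::block_on call above suggests the eviction path runs on a blocking thread while still needing to await async steps. A standalone sketch of that general pattern (illustrative; not the layer code itself):

#[tokio::main]
async fn main() {
    // Capture a handle to the runtime and hand it to the blocking thread.
    let handle = tokio::runtime::Handle::current();
    let result = tokio::task::spawn_blocking(move || {
        // On a blocking-pool thread it is fine to block on a future;
        // doing this on an async worker thread would panic.
        handle.block_on(expensive_async_step())
    })
    .await
    .expect("blocking task panicked");
    assert_eq!(result, 42);
}

async fn expensive_async_step() -> u64 {
    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
    42
}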

@@ -1268,7 +1268,7 @@ impl Timeline {
let mut historic_layers = Vec::new();
for historic_layer in layer_map.iter_historic_layers() {
let historic_layer = guard.get_from_desc(&historic_layer);
historic_layers.push(historic_layer.info(reset));
historic_layers.push(historic_layer.info(reset).await);
}
LayerMapInfo {

View File

@@ -343,23 +343,6 @@ pub(super) async fn handle_walreceiver_connection(
modification.commit(&ctx).await?;
uncommitted_records = 0;
filtered_records = 0;
//
// We should check checkpoint distance after appending each ingest_batch_size bytes because otherwise
// layer size can become much larger than `checkpoint_distance`.
// It can happen because the wal-sender sends WAL in 125 kB chunks, and some WAL records can cause a large
// amount of data to be written to key-value storage. So performing this check only after processing
// all WAL records in the chunk can cause huge L0 layer files.
//
timeline
.check_checkpoint_distance()
.await
.with_context(|| {
format!(
"Failed to check checkpoint distance for timeline {}",
timeline.timeline_id
)
})?;
}
}

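A self-contained sketch of the batching idea described in the removed comment, with made-up names: count ingested bytes and run the checkpoint-distance check once per ingest_batch_size bytes rather than once per WAL chunk.

struct Ingest {
    bytes_since_check: u64,
    ingest_batch_size: u64,
}

impl Ingest {
    fn ingest_record(&mut self, record_len: u64, mut check_checkpoint_distance: impl FnMut()) {
        self.bytes_since_check += record_len;
        if self.bytes_since_check >= self.ingest_batch_size {
            check_checkpoint_distance();
            self.bytes_since_check = 0;
        }
    }
}

fn main() {
    let mut ingest = Ingest { bytes_since_check: 0, ingest_batch_size: 1000 };
    let mut checks = 0;
    for _ in 0..10 {
        // Ten 200-byte records with a 1000-byte batch size: the check fires every 5 records.
        ingest.ingest_record(200, || checks += 1);
    }
    assert_eq!(checks, 2);
}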
View File

@@ -28,10 +28,9 @@ use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tokio::time::Instant;
use utils::fs_ext;
pub use pageserver_api::models::virtual_file as api;
pub(crate) mod io_engine;
mod io_engine;
mod open_options;
pub(crate) use io_engine::IoEngineKind;
pub use io_engine::IoEngineKind;
pub(crate) use open_options::*;
///

View File

@@ -7,100 +7,67 @@
//!
//! Then use [`get`] and [`super::OpenOptions`].
pub(crate) use super::api::IoEngineKind;
#[derive(Clone, Copy)]
#[repr(u8)]
pub(crate) enum IoEngine {
NotSet,
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Hash,
strum_macros::EnumString,
strum_macros::Display,
serde_with::DeserializeFromStr,
serde_with::SerializeDisplay,
Debug,
)]
#[strum(serialize_all = "kebab-case")]
pub enum IoEngineKind {
StdFs,
#[cfg(target_os = "linux")]
TokioEpollUring,
}
impl From<IoEngineKind> for IoEngine {
fn from(value: IoEngineKind) -> Self {
match value {
IoEngineKind::StdFs => IoEngine::StdFs,
#[cfg(target_os = "linux")]
IoEngineKind::TokioEpollUring => IoEngine::TokioEpollUring,
}
}
}
impl TryFrom<u8> for IoEngine {
type Error = u8;
fn try_from(value: u8) -> Result<Self, Self::Error> {
Ok(match value {
v if v == (IoEngine::NotSet as u8) => IoEngine::NotSet,
v if v == (IoEngine::StdFs as u8) => IoEngine::StdFs,
#[cfg(target_os = "linux")]
v if v == (IoEngine::TokioEpollUring as u8) => IoEngine::TokioEpollUring,
x => return Err(x),
})
}
}
static IO_ENGINE: AtomicU8 = AtomicU8::new(IoEngine::NotSet as u8);
pub(crate) fn set(engine_kind: IoEngineKind) {
let engine: IoEngine = engine_kind.into();
IO_ENGINE.store(engine as u8, std::sync::atomic::Ordering::Relaxed);
#[cfg(not(test))]
{
let metric = &crate::metrics::virtual_file_io_engine::KIND;
metric.reset();
metric
.with_label_values(&[&format!("{engine_kind}")])
.set(1);
}
}
static IO_ENGINE: once_cell::sync::OnceCell<IoEngineKind> = once_cell::sync::OnceCell::new();
#[cfg(not(test))]
pub(super) fn init(engine_kind: IoEngineKind) {
set(engine_kind);
}
pub(super) fn get() -> IoEngine {
let cur = IoEngine::try_from(IO_ENGINE.load(Ordering::Relaxed)).unwrap();
if cfg!(test) {
let env_var_name = "NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE";
match cur {
IoEngine::NotSet => {
let kind = match std::env::var(env_var_name) {
Ok(v) => match v.parse::<IoEngineKind>() {
Ok(engine_kind) => engine_kind,
Err(e) => {
panic!("invalid VirtualFile io engine for env var {env_var_name}: {e:#}: {v:?}")
}
},
Err(std::env::VarError::NotPresent) => {
crate::config::defaults::DEFAULT_VIRTUAL_FILE_IO_ENGINE
.parse()
.unwrap()
}
Err(std::env::VarError::NotUnicode(_)) => {
panic!("env var {env_var_name} is not unicode");
}
};
self::set(kind);
self::get()
}
x => x,
}
} else {
cur
pub(super) fn init(engine: IoEngineKind) {
if IO_ENGINE.set(engine).is_err() {
panic!("called twice");
}
crate::metrics::virtual_file_io_engine::KIND
.with_label_values(&[&format!("{engine}")])
.set(1);
}
use std::{
os::unix::prelude::FileExt,
sync::atomic::{AtomicU8, Ordering},
};
pub(super) fn get() -> &'static IoEngineKind {
#[cfg(test)]
{
let env_var_name = "NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE";
IO_ENGINE.get_or_init(|| match std::env::var(env_var_name) {
Ok(v) => match v.parse::<IoEngineKind>() {
Ok(engine_kind) => engine_kind,
Err(e) => {
panic!("invalid VirtualFile io engine for env var {env_var_name}: {e:#}: {v:?}")
}
},
Err(std::env::VarError::NotPresent) => {
crate::config::defaults::DEFAULT_VIRTUAL_FILE_IO_ENGINE
.parse()
.unwrap()
}
Err(std::env::VarError::NotUnicode(_)) => {
panic!("env var {env_var_name} is not unicode");
}
})
}
#[cfg(not(test))]
IO_ENGINE.get().unwrap()
}
use std::os::unix::prelude::FileExt;
use super::FileGuard;
impl IoEngine {
impl IoEngineKind {
pub(super) async fn read_at<B>(
&self,
file_guard: FileGuard,
@@ -111,8 +78,7 @@ impl IoEngine {
B: tokio_epoll_uring::BoundedBufMut + Send,
{
match self {
IoEngine::NotSet => panic!("not initialized"),
IoEngine::StdFs => {
IoEngineKind::StdFs => {
// SAFETY: `dst` only lives at most as long as this match arm, during which buf remains valid memory.
let dst = unsafe {
std::slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total())
@@ -130,7 +96,7 @@ impl IoEngine {
((file_guard, buf), res)
}
#[cfg(target_os = "linux")]
IoEngine::TokioEpollUring => {
IoEngineKind::TokioEpollUring => {
let system = tokio_epoll_uring::thread_local_system().await;
let (resources, res) = system.read(file_guard, offset, buf).await;
(

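A standalone sketch of the engine-selection shape used above: a process-global AtomicU8 with a NotSet sentinel, set once at startup and read on every I/O call. The enum values and functions here are simplified stand-ins, not the pageserver's io_engine module.

use std::sync::atomic::{AtomicU8, Ordering};

// A process-global, atomically stored engine selection with a NotSet sentinel.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
enum IoEngine {
    NotSet = 0,
    StdFs = 1,
    TokioEpollUring = 2,
}

static IO_ENGINE: AtomicU8 = AtomicU8::new(IoEngine::NotSet as u8);

fn set(engine: IoEngine) {
    IO_ENGINE.store(engine as u8, Ordering::Relaxed);
}

fn get() -> IoEngine {
    match IO_ENGINE.load(Ordering::Relaxed) {
        1 => IoEngine::StdFs,
        2 => IoEngine::TokioEpollUring,
        _ => IoEngine::NotSet,
    }
}

fn main() {
    assert_eq!(get(), IoEngine::NotSet);
    set(IoEngine::StdFs);
    assert_eq!(get(), IoEngine::StdFs);
}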
View File

@@ -1,6 +1,6 @@
//! Enum-dispatch to the `OpenOptions` type of the respective [`super::IoEngineKind`];
use super::io_engine::IoEngine;
use super::IoEngineKind;
use std::{os::fd::OwnedFd, path::Path};
#[derive(Debug, Clone)]
@@ -13,10 +13,9 @@ pub enum OpenOptions {
impl Default for OpenOptions {
fn default() -> Self {
match super::io_engine::get() {
IoEngine::NotSet => panic!("io engine not set"),
IoEngine::StdFs => Self::StdFs(std::fs::OpenOptions::new()),
IoEngineKind::StdFs => Self::StdFs(std::fs::OpenOptions::new()),
#[cfg(target_os = "linux")]
IoEngine::TokioEpollUring => {
IoEngineKind::TokioEpollUring => {
Self::TokioEpollUring(tokio_epoll_uring::ops::open_at::OpenOptions::new())
}
}

View File

@@ -314,9 +314,6 @@ lfc_change_limit_hook(int newval, void *extra)
lfc_ctl->used -= 1;
}
lfc_ctl->limit = new_size;
if (new_size == 0) {
lfc_ctl->generation += 1;
}
neon_log(DEBUG1, "set local file cache limit to %d", new_size);
LWLockRelease(lfc_lock);

View File

@@ -11,23 +11,16 @@
#include "postgres.h"
#include "fmgr.h"
#include "miscadmin.h"
#include "access/xact.h"
#include "access/xlog.h"
#include "storage/buf_internals.h"
#include "storage/bufmgr.h"
#include "catalog/pg_type.h"
#include "postmaster/bgworker.h"
#include "postmaster/interrupt.h"
#include "replication/slot.h"
#include "replication/walsender.h"
#include "storage/procsignal.h"
#include "tcop/tcopprot.h"
#include "funcapi.h"
#include "access/htup_details.h"
#include "utils/pg_lsn.h"
#include "utils/guc.h"
#include "utils/wait_event.h"
#include "neon.h"
#include "walproposer.h"
@@ -37,130 +30,6 @@
PG_MODULE_MAGIC;
void _PG_init(void);
static int logical_replication_max_time_lag = 3600;
static void
InitLogicalReplicationMonitor(void)
{
BackgroundWorker bgw;
DefineCustomIntVariable(
"neon.logical_replication_max_time_lag",
"Threshold for dropping unused logical replication slots",
NULL,
&logical_replication_max_time_lag,
3600, 0, INT_MAX,
PGC_SIGHUP,
GUC_UNIT_S,
NULL, NULL, NULL);
memset(&bgw, 0, sizeof(bgw));
bgw.bgw_flags = BGWORKER_SHMEM_ACCESS;
bgw.bgw_start_time = BgWorkerStart_RecoveryFinished;
snprintf(bgw.bgw_library_name, BGW_MAXLEN, "neon");
snprintf(bgw.bgw_function_name, BGW_MAXLEN, "LogicalSlotsMonitorMain");
snprintf(bgw.bgw_name, BGW_MAXLEN, "Logical replication monitor");
snprintf(bgw.bgw_type, BGW_MAXLEN, "Logical replication monitor");
bgw.bgw_restart_time = 5;
bgw.bgw_notify_pid = 0;
bgw.bgw_main_arg = (Datum) 0;
RegisterBackgroundWorker(&bgw);
}
typedef struct
{
NameData name;
bool dropped;
XLogRecPtr confirmed_flush_lsn;
TimestampTz last_updated;
} SlotStatus;
/*
 * Unused logical replication slots pin WAL and prevent deletion of snapshots.
*/
PGDLLEXPORT void
LogicalSlotsMonitorMain(Datum main_arg)
{
SlotStatus* slots;
TimestampTz now, last_checked;
/* Establish signal handlers. */
pqsignal(SIGUSR1, procsignal_sigusr1_handler);
pqsignal(SIGHUP, SignalHandlerForConfigReload);
pqsignal(SIGTERM, die);
BackgroundWorkerUnblockSignals();
slots = (SlotStatus*)calloc(max_replication_slots, sizeof(SlotStatus));
last_checked = GetCurrentTimestamp();
for (;;)
{
(void) WaitLatch(MyLatch,
WL_LATCH_SET | WL_EXIT_ON_PM_DEATH | WL_TIMEOUT,
logical_replication_max_time_lag*1000/2,
PG_WAIT_EXTENSION);
ResetLatch(MyLatch);
CHECK_FOR_INTERRUPTS();
now = GetCurrentTimestamp();
if (now - last_checked > logical_replication_max_time_lag*USECS_PER_SEC)
{
int n_active_slots = 0;
last_checked = now;
LWLockAcquire(ReplicationSlotControlLock, LW_SHARED);
for (int i = 0; i < max_replication_slots; i++)
{
ReplicationSlot *s = &ReplicationSlotCtl->replication_slots[i];
/* Consider only logical replication slots */
if (!s->in_use || !SlotIsLogical(s))
continue;
if (s->active_pid != 0)
{
n_active_slots += 1;
continue;
}
/* Check if there was some activity with the slot since last check */
if (s->data.confirmed_flush != slots[i].confirmed_flush_lsn)
{
slots[i].confirmed_flush_lsn = s->data.confirmed_flush;
slots[i].last_updated = now;
}
else if (now - slots[i].last_updated > logical_replication_max_time_lag*USECS_PER_SEC)
{
slots[i].name = s->data.name;
slots[i].dropped = true;
}
}
LWLockRelease(ReplicationSlotControlLock);
/*
* If there are no active subscriptions, then no new snapshots are generated
 * and so there is no need to force slot deletion.
*/
if (n_active_slots != 0)
{
for (int i = 0; i < max_replication_slots; i++)
{
if (slots[i].dropped)
{
elog(LOG, "Dropping logical replication slot because it was not updated for more than %ld seconds",
(now - slots[i].last_updated)/USECS_PER_SEC);
ReplicationSlotDrop(slots[i].name.data, true);
slots[i].dropped = false;
}
}
}
}
}
}
void
_PG_init(void)
{
@@ -175,8 +44,6 @@ _PG_init(void)
pg_init_libpagestore();
pg_init_walproposer();
InitLogicalReplicationMonitor();
InitControlPlaneConnector();
pg_init_extension_server();

View File

@@ -19,7 +19,6 @@ chrono.workspace = true
clap.workspace = true
consumption_metrics.workspace = true
dashmap.workspace = true
env_logger.workspace = true
futures.workspace = true
git-version.workspace = true
hashbrown.workspace = true
@@ -60,8 +59,6 @@ scopeguard.workspace = true
serde.workspace = true
serde_json.workspace = true
sha2.workspace = true
smol_str.workspace = true
smallvec.workspace = true
socket2.workspace = true
sync_wrapper.workspace = true
task-local-extensions.workspace = true
@@ -78,7 +75,6 @@ tracing-subscriber.workspace = true
tracing-utils.workspace = true
tracing.workspace = true
url.workspace = true
urlencoding.workspace = true
utils.workspace = true
uuid.workspace = true
webpki-roots.workspace = true
@@ -87,6 +83,7 @@ native-tls.workspace = true
postgres-native-tls.workspace = true
postgres-protocol.workspace = true
redis.workspace = true
smol_str.workspace = true
workspace_hack.workspace = true

View File

@@ -5,8 +5,7 @@ pub use backend::BackendType;
mod credentials;
pub use credentials::{
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
ComputeUserInfoParseError, IpPattern,
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, IpPattern,
};
mod password_hack;
@@ -15,12 +14,8 @@ use password_hack::PasswordHackPayload;
mod flow;
pub use flow::*;
use tokio::time::error::Elapsed;
use crate::{
console,
error::{ReportableError, UserFacingError},
};
use crate::{console, error::UserFacingError};
use std::io;
use thiserror::Error;
@@ -36,6 +31,9 @@ pub enum AuthErrorImpl {
#[error(transparent)]
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
#[error(transparent)]
WakeCompute(#[from] console::errors::WakeComputeError),
/// SASL protocol errors (includes [SCRAM](crate::scram)).
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
@@ -69,9 +67,6 @@ pub enum AuthErrorImpl {
#[error("Too many connections to this endpoint. Please try again later.")]
TooManyConnections,
#[error("Authentication timed out")]
UserTimeout(Elapsed),
}
#[derive(Debug, Error)]
@@ -98,10 +93,6 @@ impl AuthError {
pub fn is_auth_failed(&self) -> bool {
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
}
pub fn user_timeout(elapsed: Elapsed) -> Self {
AuthErrorImpl::UserTimeout(elapsed).into()
}
}
impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
@@ -116,6 +107,7 @@ impl UserFacingError for AuthError {
match self.0.as_ref() {
Link(e) => e.to_string_client(),
GetAuthInfo(e) => e.to_string_client(),
WakeCompute(e) => e.to_string_client(),
Sasl(e) => e.to_string_client(),
AuthFailed(_) => self.to_string(),
BadAuthMethod(_) => self.to_string(),
@@ -124,26 +116,6 @@ impl UserFacingError for AuthError {
Io(_) => "Internal error".to_string(),
IpAddressNotAllowed => self.to_string(),
TooManyConnections => self.to_string(),
UserTimeout(_) => self.to_string(),
}
}
}
impl ReportableError for AuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
use AuthErrorImpl::*;
match self.0.as_ref() {
Link(e) => e.get_error_kind(),
GetAuthInfo(e) => e.get_error_kind(),
Sasl(e) => e.get_error_kind(),
AuthFailed(_) => crate::error::ErrorKind::User,
BadAuthMethod(_) => crate::error::ErrorKind::User,
MalformedPassword(_) => crate::error::ErrorKind::User,
MissingEndpointName => crate::error::ErrorKind::User,
Io(_) => crate::error::ErrorKind::ClientDisconnect,
IpAddressNotAllowed => crate::error::ErrorKind::User,
TooManyConnections => crate::error::ErrorKind::RateLimit,
UserTimeout(_) => crate::error::ErrorKind::User,
}
}
}

View File

@@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange;
use crate::cache::Cached;
use crate::console::errors::GetAuthInfoError;
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
use crate::console::{AuthSecret, NodeInfo};
use crate::console::AuthSecret;
use crate::context::RequestMonitoring;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::wake_compute::wake_compute;
use crate::proxy::NeonOptions;
use crate::stream::Stream;
use crate::{
@@ -26,6 +26,7 @@ use crate::{
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use futures::TryFutureExt;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -55,11 +56,11 @@ impl<T> std::ops::Deref for MaybeOwned<'_, T> {
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum BackendType<'a, T, D> {
pub enum BackendType<'a, T> {
/// Cloud API (V2).
Console(MaybeOwned<'a, ConsoleBackend>, T),
/// Authentication via a web browser.
Link(MaybeOwned<'a, url::ApiUrl>, D),
Link(MaybeOwned<'a, url::ApiUrl>),
}
pub trait TestBackend: Send + Sync + 'static {
@@ -67,10 +68,9 @@ pub trait TestBackend: Send + Sync + 'static {
fn get_allowed_ips_and_secret(
&self,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
}
impl std::fmt::Display for BackendType<'_, (), ()> {
impl std::fmt::Display for BackendType<'_, ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use BackendType::*;
match self {
@@ -85,50 +85,51 @@ impl std::fmt::Display for BackendType<'_, (), ()> {
#[cfg(test)]
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
}
}
}
impl<T, D> BackendType<'_, T, D> {
impl<T> BackendType<'_, T> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
pub fn as_ref(&self) -> BackendType<'_, &T> {
use BackendType::*;
match self {
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
Link(c) => Link(MaybeOwned::Borrowed(c)),
}
}
}
impl<'a, T, D> BackendType<'a, T, D> {
impl<'a, T> BackendType<'a, T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> {
use BackendType::*;
match self {
Console(c, x) => Console(c, f(x)),
Link(c, x) => Link(c, x),
}
}
}
impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
use BackendType::*;
match self {
Console(c, x) => x.map(|x| Console(c, x)),
Link(c, x) => Ok(Link(c, x)),
Link(c) => Link(c),
}
}
}
pub struct ComputeCredentials {
impl<'a, T, E> BackendType<'a, Result<T, E>> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub fn transpose(self) -> Result<BackendType<'a, T>, E> {
use BackendType::*;
match self {
Console(c, x) => x.map(|x| Console(c, x)),
Link(c) => Ok(Link(c)),
}
}
}
pub struct ComputeCredentials<T> {
pub info: ComputeUserInfo,
pub keys: ComputeCredentialKeys,
pub keys: T,
}
#[derive(Debug, Clone)]
@@ -151,6 +152,7 @@ impl ComputeUserInfo {
}
pub enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
}
@@ -185,21 +187,19 @@ async fn auth_quirks(
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
// If there's no project so far, that entails that the client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info, unauthenticated_password) = match user_info.try_into() {
Err(info) => {
let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
.await?;
ctx.set_endpoint_id(res.info.endpoint.clone());
tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));
let password = match res.keys {
ComputeCredentialKeys::Password(p) => p,
_ => unreachable!("password hack should return a password"),
};
(res.info, Some(password))
(res.info, Some(res.keys))
}
Ok(info) => (info, None),
};
@@ -253,7 +253,7 @@ async fn authenticate_with_secret(
unauthenticated_password: Option<Vec<u8>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
if let Some(password) = unauthenticated_password {
let auth_outcome = validate_password_and_exchange(&password, secret)?;
let keys = match auth_outcome {
@@ -275,22 +275,21 @@ async fn authenticate_with_secret(
// Perform cleartext auth if we're allowed to do that.
// Currently, we use it for websocket connections (latency).
if allow_cleartext {
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
return hacks::authenticate_cleartext(ctx, info, client, secret).await;
return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await;
}
// Finally, proceed with the main auth flow (SCRAM-based).
classic::authenticate(ctx, info, client, config, secret).await
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
}
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<EndpointId> {
use BackendType::*;
match self {
Console(_, user_info) => user_info.endpoint_id.clone(),
Link(_, _) => Some("link".into()),
Link(_) => Some("link".into()),
}
}
@@ -300,7 +299,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
match self {
Console(_, user_info) => &user_info.user,
Link(_, _) => "link",
Link(_) => "link",
}
}
@@ -312,7 +311,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
use BackendType::*;
let res = match self {
@@ -323,17 +322,33 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
"performing authentication using the console"
);
let credentials =
let compute_credentials =
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
BackendType::Console(api, credentials)
let mut num_retries = 0;
let mut node =
wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?;
ctx.set_project(node.aux.clone());
match compute_credentials.keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => node.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
};
(node, BackendType::Console(api, compute_credentials.info))
}
// NOTE: this auth backend doesn't use client credentials.
Link(url, _) => {
Link(url) => {
info!("performing link authentication");
let info = link::authenticate(ctx, &url, client).await?;
let node_info = link::authenticate(ctx, &url, client).await?;
BackendType::Link(url, info)
(
CachedNodeInfo::new_uncached(node_info),
BackendType::Link(url),
)
}
};
@@ -342,18 +357,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
}
}
impl BackendType<'_, ComputeUserInfo, &()> {
pub async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
Link(_, _) => Ok(Cached::new_uncached(None)),
}
}
impl BackendType<'_, ComputeUserInfo> {
pub async fn get_allowed_ips_and_secret(
&self,
ctx: &mut RequestMonitoring,
@@ -361,51 +365,21 @@ impl BackendType<'_, ComputeUserInfo, &()> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
async fn wake_compute(
/// When applicable, wake the compute node, gaining its connection info in the process.
/// The link auth flow doesn't support this, so we return [`None`] in that case.
pub async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, info) => Ok(Cached::new_uncached(info.clone())),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
BackendType::Link(_, _) => None,
}
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
BackendType::Link(_, _) => None,
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
Link(_) => Ok(None),
}
}
}

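A minimal, self-contained analogue of the transpose pattern above: when one variant of the backend enum carries a fallible payload, transpose flattens it into a Result so the error can be propagated with ?. The types below are illustrative, not proxy's actual BackendType.

// A two-variant enum where only one variant carries a (possibly fallible) payload.
enum Backend<T> {
    Console(T),
    Link,
}

impl<T, E> Backend<Result<T, E>> {
    /// Very similar to Option::transpose: move the Result to the outside.
    fn transpose(self) -> Result<Backend<T>, E> {
        match self {
            Backend::Console(x) => x.map(Backend::Console),
            Backend::Link => Ok(Backend::Link),
        }
    }
}

fn main() -> Result<(), String> {
    let parsed: Backend<Result<u32, String>> =
        Backend::Console("42".parse::<u32>().map_err(|e| e.to_string()));
    let backend: Backend<u32> = parsed.transpose()?; // a parse error would surface here
    match backend {
        Backend::Console(n) => assert_eq!(n, 42),
        Backend::Link => unreachable!(),
    }
    Ok(())
}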
View File

@@ -4,7 +4,7 @@ use crate::{
compute,
config::AuthenticationConfig,
console::AuthSecret,
context::RequestMonitoring,
metrics::LatencyTimer,
sasl,
stream::{PqStream, Stream},
};
@@ -12,12 +12,12 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
pub(super) async fn authenticate(
ctx: &mut RequestMonitoring,
creds: ComputeUserInfo,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
@@ -27,11 +27,13 @@ pub(super) async fn authenticate(
}
AuthSecret::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret, &mut *ctx);
let scram = auth::Scram(&secret);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
async {
// pause the timer while we communicate with the client
let _paused = latency_timer.pause();
flow.begin(scram).await.map_err(|error| {
warn!(?error, "error sending scram acknowledgement");
@@ -43,9 +45,9 @@ pub(super) async fn authenticate(
}
)
.await
.map_err(|e| {
.map_err(|error| {
warn!("error processing scram messages error = authentication timed out, execution time exeeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::AuthError::user_timeout(e)
auth::io::Error::new(auth::io::ErrorKind::TimedOut, error)
})??;
let client_key = match auth_outcome {
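
For context on the timeout handling above, a self-contained sketch of bounding an auth exchange with `tokio::time::timeout` and converting the elapsed error into an `io::Error`, roughly as the new code does with `config.scram_protocol_timeout`; the sleep stands in for the real SCRAM round-trips:

```rust
use std::time::Duration;

#[tokio::main]
async fn main() {
    let scram_protocol_timeout = Duration::from_secs(15);

    let outcome = tokio::time::timeout(scram_protocol_timeout, async {
        // Stand-in for the SCRAM message exchange with the client.
        tokio::time::sleep(Duration::from_millis(10)).await;
        Ok::<_, std::io::Error>("authenticated")
    })
    .await
    // The timeout's Elapsed error is mapped into an io::Error, as in the diff.
    .map_err(|elapsed| std::io::Error::new(std::io::ErrorKind::TimedOut, elapsed));

    match outcome {
        Ok(Ok(who)) => println!("{who}"),
        Ok(Err(e)) => eprintln!("auth failed: {e}"),
        Err(e) => eprintln!("authentication timed out: {e}"),
    }
}
```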

View File

@@ -4,7 +4,7 @@ use super::{
use crate::{
auth::{self, AuthFlow},
console::AuthSecret,
context::RequestMonitoring,
metrics::LatencyTimer,
sasl,
stream::{self, Stream},
};
@@ -16,16 +16,15 @@ use tracing::{info, warn};
/// These properties are beneficial for serverless JS workers, so we
/// use this mechanism for websocket connections.
pub async fn authenticate_cleartext(
ctx: &mut RequestMonitoring,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
latency_timer: &mut LatencyTimer,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
warn!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause();
let _paused = latency_timer.pause();
let auth_outcome = AuthFlow::new(client)
.begin(auth::CleartextPassword(secret))
@@ -48,15 +47,14 @@ pub async fn authenticate_cleartext(
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
/// and passwords are not yet validated (we don't know how to validate them!)
pub async fn password_hack_no_authentication(
ctx: &mut RequestMonitoring,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
) -> auth::Result<ComputeCredentials> {
latency_timer: &mut LatencyTimer,
) -> auth::Result<ComputeCredentials<Vec<u8>>> {
warn!("project not specified, resorting to the password hack auth flow");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause();
let _paused = latency_timer.pause();
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
@@ -73,6 +71,6 @@ pub async fn password_hack_no_authentication(
options: info.options,
endpoint: payload.endpoint,
},
keys: ComputeCredentialKeys::Password(payload.password),
keys: payload.password,
})
}
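
Both hunks above pause the latency timer while the proxy waits on the client. A minimal sketch of that RAII pause-guard pattern; names are illustrative, and the crate's real `LatencyTimer` also tracks protocol labels and cache outcomes:

```rust
use std::time::{Duration, Instant};

struct LatencyTimer {
    start: Option<Instant>,
    accumulated: Duration,
}

struct Paused<'a>(&'a mut LatencyTimer);

impl LatencyTimer {
    fn new() -> Self {
        Self { start: Some(Instant::now()), accumulated: Duration::ZERO }
    }

    // Stop accumulating while we wait on the client; resume when the guard drops.
    fn pause(&mut self) -> Paused<'_> {
        if let Some(start) = self.start.take() {
            self.accumulated += start.elapsed();
        }
        Paused(self)
    }
}

impl Drop for Paused<'_> {
    fn drop(&mut self) {
        self.0.start = Some(Instant::now());
    }
}

fn main() {
    let mut timer = LatencyTimer::new();
    {
        let _paused = timer.pause(); // the client round-trip would happen here
    } // the timer resumes here
    println!("accumulated (excluding the paused span): {:?}", timer.accumulated);
}
```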

View File

@@ -2,7 +2,7 @@ use crate::{
auth, compute,
console::{self, provider::NodeInfo},
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
error::UserFacingError,
stream::PqStream,
waiters,
};
@@ -14,6 +14,10 @@ use tracing::{info, info_span};
#[derive(Debug, Error)]
pub enum LinkAuthError {
/// Authentication error reported by the console.
#[error("Authentication failed: {0}")]
AuthFailed(String),
#[error(transparent)]
WaiterRegister(#[from] waiters::RegisterError),
@@ -26,16 +30,10 @@ pub enum LinkAuthError {
impl UserFacingError for LinkAuthError {
fn to_string_client(&self) -> String {
"Internal error".to_string()
}
}
impl ReportableError for LinkAuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
use LinkAuthError::*;
match self {
LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service,
LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service,
LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
AuthFailed(_) => self.to_string(),
_ => "Internal error".to_string(),
}
}
}
@@ -61,8 +59,6 @@ pub(super) async fn authenticate(
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
ctx.set_auth_method(crate::context::AuthMethod::Web);
// registering waiter can fail if we get unlucky with rng.
// just try again.
let (psql_session_id, waiter) = loop {

View File

@@ -1,12 +1,8 @@
//! User credentials used in authentication.
use crate::{
auth::password_hack::parse_endpoint_param,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI,
proxy::NeonOptions,
serverless::SERVERLESS_DRIVER_SNI,
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI,
EndpointId, RoleName,
};
use itertools::Itertools;
@@ -43,12 +39,6 @@ pub enum ComputeUserInfoParseError {
impl UserFacingError for ComputeUserInfoParseError {}
impl ReportableError for ComputeUserInfoParseError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -99,9 +89,6 @@ impl ComputeUserInfoMaybeEndpoint {
// record the values if we have them
ctx.set_application(params.get("application_name").map(SmolStr::from));
ctx.set_user(user.clone());
if let Some(dbname) = params.get("database") {
ctx.set_dbname(dbname.into());
}
// Project name might be passed via PG's command-line options.
let endpoint_option = params

View File

@@ -4,11 +4,9 @@ use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
use crate::{
config::TlsServerEndPoint,
console::AuthSecret,
context::RequestMonitoring,
sasl, scram,
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
@@ -25,7 +23,7 @@ pub trait AuthMethod {
pub struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
pub struct Scram<'a>(pub &'a scram::ServerSecret);
impl AuthMethod for Scram<'_> {
#[inline(always)]
@@ -140,11 +138,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret, ctx) = self.state;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause();
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
@@ -155,15 +148,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
return Err(super::AuthError::bad_auth_method(sasl.method));
}
match sasl.method {
SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => {
ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
}
_ => {}
}
info!("client chooses {}", sasl.method);
let secret = self.state.0;
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(
secret,
@@ -180,7 +167,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
}
pub(crate) fn validate_password_and_exchange(
pub(super) fn validate_password_and_exchange(
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
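
The `AuthFlow` changes above lean on a typestate design: the flow starts in `Begin` and is parameterized by the chosen method, so only the matching `authenticate` is callable. A simplified sketch with stand-in types and no I/O:

```rust
struct Begin;
struct Scram<'a>(&'a str); // stand-in for &scram::ServerSecret

struct AuthFlow<State> {
    state: State,
}

impl AuthFlow<Begin> {
    fn new() -> Self {
        AuthFlow { state: Begin }
    }

    // Choosing a method transitions the flow into that method's state.
    fn begin<M>(self, method: M) -> AuthFlow<M> {
        AuthFlow { state: method }
    }
}

impl<'a> AuthFlow<Scram<'a>> {
    fn authenticate(self) -> Result<&'static str, &'static str> {
        let Scram(_secret) = self.state;
        // The real implementation drives the SASL exchange here.
        Ok("scram exchange completed")
    }
}

fn main() {
    let outcome = AuthFlow::new().begin(Scram("server-secret")).authenticate();
    println!("{outcome:?}");
}
```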

View File

@@ -240,9 +240,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
?unexpected,
"unexpected startup packet, rejecting connection"
);
stream
.throw_error_str(ERR_INSECURE_CONNECTION, proxy::error::ErrorKind::User)
.await?
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?
}
}
}
@@ -274,10 +272,5 @@ async fn handle_client(
let client = tokio::net::TcpStream::connect(destination).await?;
let metrics_aux: MetricsAuxInfo = Default::default();
// doesn't yet matter as pg-sni-router doesn't report analytics logs
ctx.set_success();
ctx.log();
proxy::proxy::passthrough::proxy_pass(tls_stream, client, metrics_aux).await
proxy::proxy::passthrough::proxy_pass(ctx, tls_stream, client, metrics_aux).await
}

View File

@@ -1,8 +1,6 @@
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
@@ -14,7 +12,6 @@ use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::RateLimiterConfig;
use proxy::redis::notifications;
use proxy::redis::publisher::RedisPublisherClient;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;
@@ -25,7 +22,6 @@ use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::info;
@@ -92,9 +88,6 @@ struct ProxyCliArgs {
/// path to directory with TLS certificates for client postgres connections
#[clap(long)]
certs_dir: Option<String>,
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
@@ -133,9 +126,6 @@ struct ProxyCliArgs {
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
endpoint_rps_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
#[clap(long, default_value_t = 100)]
initial_limit: usize,
@@ -175,10 +165,6 @@ struct SqlOverHttpArgs {
#[clap(long, default_value_t = 20)]
sql_over_http_pool_max_conns_per_endpoint: usize,
/// How many connections to pool for each endpoint. Excess connections are discarded
#[clap(long, default_value_t = 20000)]
sql_over_http_pool_max_total_conns: usize,
/// How long pooled connections should remain idle for before closing
#[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
sql_over_http_idle_timeout: tokio::time::Duration,
@@ -232,19 +218,6 @@ async fn main() -> anyhow::Result<()> {
let cancellation_token = CancellationToken::new();
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
let cancel_map = CancelMap::default();
let redis_publisher = match &args.redis_notifications {
Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
url,
args.region.clone(),
&config.redis_rps_limit,
)?))),
None => None,
};
let cancellation_handler = Arc::new(CancellationHandler::new(
cancel_map.clone(),
redis_publisher,
));
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
@@ -254,7 +227,6 @@ async fn main() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));
// TODO: rename the argument to something like serverless.
@@ -269,7 +241,6 @@ async fn main() -> anyhow::Result<()> {
serverless_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));
}
@@ -293,12 +264,7 @@ async fn main() -> anyhow::Result<()> {
let cache = api.caches.project_info.clone();
if let Some(url) = args.redis_notifications {
info!("Starting redis notifications listener ({url})");
maintenance_tasks.spawn(notifications::task_main(
url.to_owned(),
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
@@ -410,7 +376,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
}
AuthBackend::Link => {
let url = args.uri.parse()?;
auth::BackendType::Link(MaybeOwned::Owned(url), ())
auth::BackendType::Link(MaybeOwned::Owned(url))
}
};
let http_config = HttpConfig {
@@ -421,7 +387,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
pool_shards: args.sql_over_http.sql_over_http_pool_shards,
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
},
};
let authentication_config = AuthenticationConfig {
@@ -430,8 +395,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
let mut redis_rps_limit = args.redis_rps_limit.clone();
RateBucketInfo::validate(&mut redis_rps_limit)?;
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
@@ -443,8 +406,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
require_client_ip: args.require_client_ip,
disable_ip_check_for_http: args.disable_ip_check_for_http,
endpoint_rps_limit,
redis_rps_limit,
handshake_timeout: args.handshake_timeout,
// TODO: add this argument
region: args.region.clone(),
}));

View File

@@ -1,86 +1,25 @@
use async_trait::async_trait;
use anyhow::Context;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use uuid::Uuid;
use crate::{
error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
redis::publisher::RedisPublisherClient,
};
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
/// Enables serving `CancelRequest`s.
///
/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler {
map: CancelMap,
redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
}
#[derive(Default)]
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
#[derive(Debug, Error)]
pub enum CancelError {
#[error("{0}")]
IO(#[from] std::io::Error),
#[error("{0}")]
Postgres(#[from] tokio_postgres::Error),
}
impl ReportableError for CancelError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
CancelError::IO(_) => crate::error::ErrorKind::Compute,
CancelError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
}
}
}
impl CancellationHandler {
pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
Self { map, redis_client }
}
impl CancelMap {
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
) -> Result<(), CancelError> {
let from = "from_client";
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
if let Some(redis_client) = &self.redis_client {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
info!("publishing cancellation key to Redis");
match redis_client.lock().await.try_publish(key, session_id).await {
Ok(()) => {
info!("cancellation key successfuly published to Redis");
}
Err(e) => {
tracing::error!("failed to publish a message: {e}");
return Err(CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
)));
}
}
}
return Ok(());
};
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
let cancel_closure = self
.0
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("query cancellation key not found: {key}"))?;
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
}
@@ -97,7 +36,7 @@ impl CancellationHandler {
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.map.entry(key) {
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => continue,
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
@@ -109,46 +48,18 @@ impl CancellationHandler {
info!("registered new query cancellation key {key}");
Session {
key,
cancellation_handler: self,
cancel_map: self,
}
}
#[cfg(test)]
fn contains(&self, session: &Session) -> bool {
self.map.contains_key(&session.key)
self.0.contains_key(&session.key)
}
#[cfg(test)]
fn is_empty(&self) -> bool {
self.map.is_empty()
}
}
#[async_trait]
pub trait NotificationsCancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
}
#[async_trait]
impl NotificationsCancellationHandler for CancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
let from = "from_redis";
let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
match cancel_closure {
Some(cancel_closure) => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
cancel_closure.try_cancel_query().await
}
None => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
tracing::warn!("query cancellation key not found: {key}");
Ok(())
}
}
self.0.is_empty()
}
}
@@ -170,7 +81,7 @@ impl CancelClosure {
}
/// Cancels the query running on user's compute node.
async fn try_cancel_query(self) -> Result<(), CancelError> {
pub async fn try_cancel_query(self) -> anyhow::Result<()> {
let socket = TcpStream::connect(self.socket_addr).await?;
self.cancel_token.cancel_query_raw(socket, NoTls).await?;
@@ -183,7 +94,7 @@ pub struct Session {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancellation_handler: Arc<CancellationHandler>,
cancel_map: Arc<CancelMap>,
}
impl Session {
@@ -191,9 +102,7 @@ impl Session {
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancellation_handler
.map
.insert(self.key, Some(cancel_closure));
self.cancel_map.0.insert(self.key, Some(cancel_closure));
self.key
}
@@ -201,7 +110,7 @@ impl Session {
impl Drop for Session {
fn drop(&mut self) {
self.cancellation_handler.map.remove(&self.key);
self.cancel_map.0.remove(&self.key);
info!("dropped query cancellation key {}", &self.key);
}
}
@@ -212,16 +121,13 @@ mod tests {
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler {
map: CancelMap::default(),
redis_client: None,
});
let cancel_map: Arc<CancelMap> = Default::default();
let session = cancellation_handler.clone().get_session();
assert!(cancellation_handler.contains(&session));
let session = cancel_map.clone().get_session();
assert!(cancel_map.contains(&session));
drop(session);
// Check that the session has been dropped.
assert!(cancellation_handler.is_empty());
assert!(cancel_map.is_empty());
Ok(())
}
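
The hunks above fold `CancellationHandler` back into a plain `CancelMap`. A compact sketch of the underlying pattern, registering a random key without clobbering an existing entry and removing it when the `Session` drops, assuming the `dashmap` and `rand` crates and simplifying keys to `u64`:

```rust
use dashmap::DashMap;
use std::sync::Arc;

#[derive(Default)]
struct CancelMap(DashMap<u64, Option<String>>);

struct Session {
    key: u64,
    cancel_map: Arc<CancelMap>,
}

impl CancelMap {
    fn get_session(self: Arc<Self>) -> Session {
        let key = loop {
            let key = rand::random::<u64>();
            // Random collisions are unlikely but possible; never overwrite an entry.
            match self.0.entry(key) {
                dashmap::mapref::entry::Entry::Occupied(_) => continue,
                dashmap::mapref::entry::Entry::Vacant(e) => {
                    e.insert(None);
                    break key;
                }
            }
        };
        Session { key, cancel_map: self }
    }
}

impl Drop for Session {
    fn drop(&mut self) {
        self.cancel_map.0.remove(&self.key);
    }
}

fn main() {
    let map: Arc<CancelMap> = Default::default();
    let session = map.clone().get_session();
    assert!(map.0.contains_key(&session.key));
    drop(session);
    assert!(map.0.is_empty());
}
```

The `Drop` impl is what guarantees the map shrinks back even if the client connection ends abruptly, which is exactly what `check_session_drop` above asserts.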

View File

@@ -1,10 +1,6 @@
use crate::{
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::{errors::WakeComputeError, messages::MetricsAuxInfo},
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_DB_CONNECTIONS_GAUGE,
auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
proxy::neon_option,
};
use futures::{FutureExt, TryFutureExt};
@@ -62,20 +58,6 @@ impl UserFacingError for ConnectionError {
}
}
impl ReportableError for ConnectionError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
}
}
}
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
@@ -93,7 +75,7 @@ impl ConnCfg {
}
/// Reuse password or auth keys from the other config.
pub fn reuse_password(&mut self, other: Self) {
pub fn reuse_password(&mut self, other: &Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
@@ -253,8 +235,6 @@ pub struct PostgresConnection {
pub params: std::collections::HashMap<String, String>,
/// Query cancellation token.
pub cancel_closure: CancelClosure,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
_guage: IntCounterPairGuard,
}
@@ -265,7 +245,6 @@ impl ConnCfg {
&self,
ctx: &mut RequestMonitoring,
allow_self_signed_compute: bool,
aux: MetricsAuxInfo,
timeout: Duration,
) -> Result<PostgresConnection, ConnectionError> {
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
@@ -300,7 +279,6 @@ impl ConnCfg {
stream,
params,
cancel_closure,
aux,
_guage: NUM_DB_CONNECTIONS_GAUGE
.with_label_values(&[ctx.protocol])
.guard(),

View File

@@ -13,7 +13,7 @@ use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::BackendType<'static, (), ()>,
pub auth_backend: auth::BackendType<'static, ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,
@@ -21,9 +21,7 @@ pub struct ProxyConfig {
pub require_client_ip: bool,
pub disable_ip_check_for_http: bool,
pub endpoint_rps_limit: Vec<RateBucketInfo>,
pub redis_rps_limit: Vec<RateBucketInfo>,
pub region: String,
pub handshake_timeout: Duration,
}
#[derive(Debug)]

View File

@@ -4,10 +4,7 @@ pub mod neon;
use super::messages::MetricsAuxInfo;
use crate::{
auth::{
backend::{ComputeCredentialKeys, ComputeUserInfo},
IpPattern,
},
auth::{backend::ComputeUserInfo, IpPattern},
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
compute,
config::{CacheOptions, ProjectInfoCacheOptions},
@@ -23,7 +20,7 @@ use tracing::info;
pub mod errors {
use crate::{
error::{io_error, ReportableError, UserFacingError},
error::{io_error, UserFacingError},
http,
proxy::retry::ShouldRetry,
};
@@ -84,15 +81,6 @@ pub mod errors {
}
}
impl ReportableError for ApiError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
impl ShouldRetry for ApiError {
fn could_retry(&self) -> bool {
match self {
@@ -162,16 +150,6 @@ pub mod errors {
}
}
}
impl ReportableError for GetAuthInfoError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
#[derive(Debug, Error)]
pub enum WakeComputeError {
#[error("Console responded with a malformed compute address: {0}")]
@@ -216,16 +194,6 @@ pub mod errors {
}
}
}
impl ReportableError for WakeComputeError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
WakeComputeError::ApiError(e) => e.get_error_kind(),
WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
}
}
}
}
/// Auth secret which is managed by the cloud.
@@ -264,34 +232,6 @@ pub struct NodeInfo {
pub allow_self_signed_compute: bool,
}
impl NodeInfo {
pub async fn connect(
&self,
ctx: &mut RequestMonitoring,
timeout: Duration,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config
.connect(
ctx,
self.allow_self_signed_compute,
self.aux.clone(),
timeout,
)
.await
}
pub fn reuse_settings(&mut self, other: Self) {
self.allow_self_signed_compute = other.allow_self_signed_compute;
self.config.reuse_password(other.config);
}
pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
match keys {
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
};
}
}
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;

View File

@@ -176,7 +176,9 @@ impl super::Api for Api {
_ctx: &mut RequestMonitoring,
_user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
self.do_wake_compute().map_ok(Cached::new_uncached).await
self.do_wake_compute()
.map_ok(CachedNodeInfo::new_uncached)
.await
}
}

View File

@@ -188,7 +188,6 @@ impl super::Api for Api {
ep,
Arc::new(auth_info.allowed_ips),
);
ctx.set_project_id(project_id);
}
// When we just got a secret, we don't need to invalidate it.
Ok(Cached::new_uncached(auth_info.secret))
@@ -222,7 +221,6 @@ impl super::Api for Api {
self.caches
.project_info
.insert_allowed_ips(&project_id, ep, allowed_ips.clone());
ctx.set_project_id(project_id);
}
Ok((
Cached::new_uncached(allowed_ips),

View File

@@ -8,10 +8,8 @@ use tokio::sync::mpsc;
use uuid::Uuid;
use crate::{
console::messages::MetricsAuxInfo,
error::ErrorKind,
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
BranchId, DbName, EndpointId, ProjectId, RoleName,
console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer, BranchId,
EndpointId, ProjectId, RoleName,
};
pub mod parquet;
@@ -34,11 +32,9 @@ pub struct RequestMonitoring {
project: Option<ProjectId>,
branch: Option<BranchId>,
endpoint_id: Option<EndpointId>,
dbname: Option<DbName>,
user: Option<RoleName>,
application: Option<SmolStr>,
error_kind: Option<ErrorKind>,
pub(crate) auth_method: Option<AuthMethod>,
success: bool,
// extra
@@ -47,15 +43,6 @@ pub struct RequestMonitoring {
pub latency_timer: LatencyTimer,
}
#[derive(Clone, Debug)]
pub enum AuthMethod {
// aka link aka passwordless
Web,
ScramSha256,
ScramSha256Plus,
Cleartext,
}
impl RequestMonitoring {
pub fn new(
session_id: Uuid,
@@ -73,11 +60,9 @@ impl RequestMonitoring {
project: None,
branch: None,
endpoint_id: None,
dbname: None,
user: None,
application: None,
error_kind: None,
auth_method: None,
success: false,
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
@@ -104,10 +89,6 @@ impl RequestMonitoring {
self.project = Some(x.project_id);
}
pub fn set_project_id(&mut self, project_id: ProjectId) {
self.project = Some(project_id);
}
pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
crate::metrics::CONNECTING_ENDPOINTS
.with_label_values(&[self.protocol])
@@ -119,30 +100,10 @@ impl RequestMonitoring {
self.application = app.or_else(|| self.application.clone());
}
pub fn set_dbname(&mut self, dbname: DbName) {
self.dbname = Some(dbname);
}
pub fn set_user(&mut self, user: RoleName) {
self.user = Some(user);
}
pub fn set_auth_method(&mut self, auth_method: AuthMethod) {
self.auth_method = Some(auth_method);
}
pub fn set_error_kind(&mut self, kind: ErrorKind) {
ERROR_BY_KIND
.with_label_values(&[kind.to_metric_label()])
.inc();
if let Some(ep) = &self.endpoint_id {
ENDPOINT_ERRORS_BY_KIND
.with_label_values(&[kind.to_metric_label()])
.measure(ep);
}
self.error_kind = Some(kind);
}
pub fn set_success(&mut self) {
self.success = true;
}

View File

@@ -84,10 +84,8 @@ struct RequestData {
username: Option<String>,
application_name: Option<String>,
endpoint_id: Option<String>,
database: Option<String>,
project: Option<String>,
branch: Option<String>,
auth_method: Option<&'static str>,
error: Option<&'static str>,
/// Success is counted if we form a HTTP response with sql rows inside
/// Or if we make it to proxy_pass
@@ -106,18 +104,11 @@ impl From<RequestMonitoring> for RequestData {
username: value.user.as_deref().map(String::from),
application_name: value.application.as_deref().map(String::from),
endpoint_id: value.endpoint_id.as_deref().map(String::from),
database: value.dbname.as_deref().map(String::from),
project: value.project.as_deref().map(String::from),
branch: value.branch.as_deref().map(String::from),
auth_method: value.auth_method.as_ref().map(|x| match x {
super::AuthMethod::Web => "web",
super::AuthMethod::ScramSha256 => "scram_sha_256",
super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus",
super::AuthMethod::Cleartext => "cleartext",
}),
protocol: value.protocol,
region: value.region,
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
error: value.error_kind.as_ref().map(|e| e.to_str()),
success: value.success,
duration_us: SystemTime::from(value.first_packet)
.elapsed()
@@ -440,10 +431,8 @@ mod tests {
application_name: Some("test".to_owned()),
username: Some(hex::encode(rng.gen::<[u8; 4]>())),
endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
database: Some(hex::encode(rng.gen::<[u8; 16]>())),
project: Some(hex::encode(rng.gen::<[u8; 16]>())),
branch: Some(hex::encode(rng.gen::<[u8; 16]>())),
auth_method: None,
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
region: "us-east-1",
error: None,
@@ -516,15 +505,15 @@ mod tests {
assert_eq!(
file_stats,
[
(1313727, 3, 6000),
(1313720, 3, 6000),
(1313780, 3, 6000),
(1313737, 3, 6000),
(1313867, 3, 6000),
(1313709, 3, 6000),
(1313501, 3, 6000),
(1313737, 3, 6000),
(438118, 1, 2000)
(1087635, 3, 6000),
(1087288, 3, 6000),
(1087444, 3, 6000),
(1087572, 3, 6000),
(1087468, 3, 6000),
(1087500, 3, 6000),
(1087533, 3, 6000),
(1087566, 3, 6000),
(362671, 1, 2000)
],
);
@@ -554,11 +543,11 @@ mod tests {
assert_eq!(
file_stats,
[
(1219459, 5, 10000),
(1225609, 5, 10000),
(1227403, 5, 10000),
(1226765, 5, 10000),
(1218043, 5, 10000)
(1028637, 5, 10000),
(1031969, 5, 10000),
(1019900, 5, 10000),
(1020365, 5, 10000),
(1025010, 5, 10000)
],
);
@@ -590,11 +579,11 @@ mod tests {
assert_eq!(
file_stats,
[
(1205106, 5, 10000),
(1204837, 5, 10000),
(1205130, 5, 10000),
(1205118, 5, 10000),
(1205373, 5, 10000)
(1210770, 6, 12000),
(1211036, 6, 12000),
(1210990, 6, 12000),
(1210861, 6, 12000),
(202073, 1, 2000)
],
);
@@ -619,15 +608,15 @@ mod tests {
assert_eq!(
file_stats,
[
(1313727, 3, 6000),
(1313720, 3, 6000),
(1313780, 3, 6000),
(1313737, 3, 6000),
(1313867, 3, 6000),
(1313709, 3, 6000),
(1313501, 3, 6000),
(1313737, 3, 6000),
(438118, 1, 2000)
(1087635, 3, 6000),
(1087288, 3, 6000),
(1087444, 3, 6000),
(1087572, 3, 6000),
(1087468, 3, 6000),
(1087500, 3, 6000),
(1087533, 3, 6000),
(1087566, 3, 6000),
(362671, 1, 2000)
],
);
@@ -664,7 +653,7 @@ mod tests {
// files are smaller than the size threshold, but they took too long to fill so were flushed early
assert_eq!(
file_stats,
[(658383, 2, 3001), (658097, 2, 3000), (657893, 2, 2999)],
[(545264, 2, 3001), (545025, 2, 3000), (544857, 2, 2999)],
);
tmpdir.close().unwrap();

View File

@@ -17,7 +17,7 @@ pub fn log_error<E: fmt::Display>(e: E) -> E {
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
/// is way too convenient and tends to proliferate all across the codebase,
/// ultimately leading to accidental leaks of sensitive data.
pub trait UserFacingError: ReportableError {
pub trait UserFacingError: fmt::Display {
/// Format the error for client, stripping all sensitive info.
///
/// Although this might be a no-op for many types, it's highly
@@ -29,13 +29,13 @@ pub trait UserFacingError: ReportableError {
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[derive(Clone)]
pub enum ErrorKind {
/// Wrong password, unknown endpoint, protocol violation, etc...
User,
/// Network error between user and proxy. Not necessarily user error
ClientDisconnect,
Disconnect,
/// Proxy self-imposed rate limits
RateLimit,
@@ -46,9 +46,6 @@ pub enum ErrorKind {
/// Error communicating with control plane
ControlPlane,
/// Postgres error
Postgres,
/// Error communicating with compute
Compute,
}
@@ -57,46 +54,11 @@ impl ErrorKind {
pub fn to_str(&self) -> &'static str {
match self {
ErrorKind::User => "request failed due to user error",
ErrorKind::ClientDisconnect => "client disconnected",
ErrorKind::Disconnect => "client disconnected",
ErrorKind::RateLimit => "request cancelled due to rate limit",
ErrorKind::Service => "internal service error",
ErrorKind::ControlPlane => "non-retryable control plane error",
ErrorKind::Postgres => "postgres error",
ErrorKind::Compute => {
"non-retryable compute connection error (or exhausted retry capacity)"
}
}
}
pub fn to_metric_label(&self) -> &'static str {
match self {
ErrorKind::User => "user",
ErrorKind::ClientDisconnect => "clientdisconnect",
ErrorKind::RateLimit => "ratelimit",
ErrorKind::Service => "service",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "postgres",
ErrorKind::Compute => "compute",
}
}
}
pub trait ReportableError: fmt::Display + Send + 'static {
fn get_error_kind(&self) -> ErrorKind;
}
impl ReportableError for tokio::time::error::Elapsed {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::RateLimit
}
}
impl ReportableError for tokio_postgres::error::Error {
fn get_error_kind(&self) -> ErrorKind {
if self.as_db_error().is_some() {
ErrorKind::Postgres
} else {
ErrorKind::Compute
ErrorKind::Compute => "non-retryable compute error (or exhausted retry capacity)",
}
}
}
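
The hunk above is the core of the error-classification change: every error maps to a coarse `ErrorKind` that drives both the user-facing message and a metrics label. A self-contained sketch of that shape; the variants and trait follow the diff, while the `Timeout` error type is a made-up example:

```rust
use std::fmt;

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum ErrorKind { User, ClientDisconnect, RateLimit, Service, ControlPlane, Postgres, Compute }

impl ErrorKind {
    fn to_metric_label(self) -> &'static str {
        match self {
            ErrorKind::User => "user",
            ErrorKind::ClientDisconnect => "clientdisconnect",
            ErrorKind::RateLimit => "ratelimit",
            ErrorKind::Service => "service",
            ErrorKind::ControlPlane => "controlplane",
            ErrorKind::Postgres => "postgres",
            ErrorKind::Compute => "compute",
        }
    }
}

trait ReportableError: fmt::Display {
    fn get_error_kind(&self) -> ErrorKind;
}

#[derive(Debug)]
struct Timeout; // made-up example error

impl fmt::Display for Timeout {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "deadline elapsed")
    }
}

impl ReportableError for Timeout {
    fn get_error_kind(&self) -> ErrorKind {
        ErrorKind::RateLimit
    }
}

fn main() {
    let e = Timeout;
    println!("{} -> label {}", e, e.get_error_kind().to_metric_label());
}
```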

View File

@@ -1,10 +1,8 @@
use ::metrics::{
exponential_buckets, register_histogram, register_histogram_vec, register_hll_vec,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge,
register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec,
IntCounterVec, IntGauge, IntGaugeVec,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge_vec, Histogram,
HistogramVec, HyperLogLogVec, IntCounterPairVec, IntCounterVec, IntGaugeVec,
};
use metrics::{register_int_counter_pair, IntCounterPair};
use once_cell::sync::Lazy;
use tokio::time;
@@ -114,53 +112,6 @@ pub static ALLOWED_IPS_NUMBER: Lazy<Histogram> = Lazy::new(|| {
.unwrap()
});
pub static HTTP_CONTENT_LENGTH: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_http_conn_content_length_bytes",
"Time it took for proxy to establish a connection to the compute endpoint",
// largest bucket = 3^16 * 0.05ms = 2.15s
exponential_buckets(8.0, 2.0, 20).unwrap()
)
.unwrap()
});
pub static GC_LATENCY: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_http_pool_reclaimation_lag_seconds",
"Time it takes to reclaim unused connection pools",
// 1us -> 65ms
exponential_buckets(1e-6, 2.0, 16).unwrap(),
)
.unwrap()
});
pub static ENDPOINT_POOLS: Lazy<IntCounterPair> = Lazy::new(|| {
register_int_counter_pair!(
"proxy_http_pool_endpoints_registered_total",
"Number of endpoints we have registered pools for",
"proxy_http_pool_endpoints_unregistered_total",
"Number of endpoints we have unregistered pools for",
)
.unwrap()
});
pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy<IntGauge> = Lazy::new(|| {
register_int_gauge!(
"proxy_http_pool_opened_connections",
"Number of opened connections to a database.",
)
.unwrap()
});
pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_cancellation_requests_total",
"Number of cancellation requests (per found/not_found).",
&["source", "kind"],
)
.unwrap()
});
#[derive(Clone)]
pub struct LatencyTimer {
// time since the stopwatch was started
@@ -209,9 +160,8 @@ impl LatencyTimer {
pub fn success(&mut self) {
// stop the stopwatch and record the time that we have accumulated
if let Some(start) = self.start.take() {
self.accumulated += start.elapsed();
}
let start = self.start.take().expect("latency timer should be started");
self.accumulated += start.elapsed();
// success
self.outcome = "success";
@@ -284,22 +234,3 @@ pub static CONNECTING_ENDPOINTS: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
)
.unwrap()
});
pub static ERROR_BY_KIND: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_errors_total",
"Number of errors by a given classification",
&["type"],
)
.unwrap()
});
pub static ENDPOINT_ERRORS_BY_KIND: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
register_hll_vec!(
32,
"proxy_endpoints_affected_by_errors",
"Number of endpoints affected by errors of a given classification",
&["type"],
)
.unwrap()
});
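
The removed metrics above all follow the same lazily-registered pattern. A sketch of an equivalent registration using the public `prometheus` and `once_cell` crates; the repo uses its own `metrics` wrapper with a similar macro API, so treat this as an approximation rather than the crate's actual code:

```rust
use once_cell::sync::Lazy;
use prometheus::{exponential_buckets, register_histogram, Histogram};

// Registered on first use, mirroring the Lazy::new(|| register_...) pattern above.
static GC_LATENCY: Lazy<Histogram> = Lazy::new(|| {
    register_histogram!(
        "proxy_http_pool_reclaimation_lag_seconds",
        "Time it takes to reclaim unused connection pools",
        // 1us -> ~65ms across 16 exponential buckets
        exponential_buckets(1e-6, 2.0, 16).unwrap()
    )
    .unwrap()
});

fn main() {
    GC_LATENCY.observe(0.002); // record a 2ms reclamation pass
    println!("samples recorded: {}", GC_LATENCY.get_sample_count());
}
```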

View File

@@ -2,7 +2,6 @@
mod tests;
pub mod connect_compute;
mod copy_bidirectional;
pub mod handshake;
pub mod passthrough;
pub mod retry;
@@ -10,14 +9,13 @@ pub mod wake_compute;
use crate::{
auth,
cancellation::{self, CancellationHandler},
cancellation::{self, CancelMap},
compute,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
error::ReportableError,
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
protocol2::WithClientIp,
proxy::handshake::{handshake, HandshakeData},
proxy::{handshake::handshake, passthrough::proxy_pass},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey,
@@ -30,17 +28,14 @@ use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use self::{
connect_compute::{connect_to_compute, TcpMechanism},
passthrough::ProxyPassthrough,
};
use self::connect_compute::{connect_to_compute, TcpMechanism};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";
pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
@@ -62,7 +57,6 @@ pub async fn task_main(
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_handler: Arc<CancellationHandler>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -73,6 +67,7 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancel_map = Arc::new(CancelMap::default());
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -80,7 +75,7 @@ pub async fn task_main(
let (socket, peer_addr) = accept_result?;
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancel_map = Arc::clone(&cancel_map);
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let session_span = info_span!(
@@ -103,41 +98,22 @@ pub async fn task_main(
bail!("missing required client IP");
}
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
socket
.inner
.set_nodelay(true)
.context("failed to set socket option")?;
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
let res = handle_client(
handle_client(
config,
&mut ctx,
cancellation_handler,
cancel_map,
socket,
ClientMode::Tcp,
endpoint_rate_limiter,
)
.await;
match res {
Err(e) => {
// todo: log and push to ctx the error kind
ctx.set_error_kind(e.get_error_kind());
ctx.log();
Err(e.into())
}
Ok(None) => {
ctx.set_success();
ctx.log();
Ok(())
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log();
p.proxy_pass().await
}
}
.await
}
.unwrap_or_else(move |e| {
// Acknowledge that the task has finished with an error.
@@ -163,14 +139,14 @@ pub enum ClientMode {
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub fn allow_cleartext(&self) -> bool {
fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
match self {
ClientMode::Tcp => config.allow_self_signed_compute,
ClientMode::Websockets { .. } => false,
@@ -193,45 +169,14 @@ impl ClientMode {
}
}
#[derive(Debug, Error)]
// almost all errors should be reported to the user, but there are a few cases where we cannot
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
// we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] handshake::HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
cancellation_handler: Arc<CancellationHandler>,
cancel_map: Arc<CancelMap>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
) -> anyhow::Result<()> {
info!(
protocol = ctx.protocol,
"handling interactive connection from client"
@@ -248,17 +193,11 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let tls = config.tls_config.as_ref();
let pause = ctx.latency_timer.pause();
let do_handshake = handshake(stream, mode.handshake_tls(tls));
let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id)
.await
.map(|()| None)?)
}
};
let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
drop(pause);
let hostname = mode.hostname(stream.get_ref());
@@ -282,12 +221,12 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await?;
.await;
}
}
let user = user_info.get_user().to_owned();
let user_info = match user_info
let (mut node_info, user_info) = match user_info
.authenticate(
ctx,
&mut stream,
@@ -302,20 +241,23 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return stream.throw_error(e).instrument(params_span).await?;
return stream.throw_error(e).instrument(params_span).await;
}
};
node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);
let aux = node_info.aux.clone();
let mut node = connect_to_compute(
ctx,
&TcpMechanism { params: &params },
node_info,
&user_info,
mode.allow_self_signed_compute(config),
)
.or_else(|e| stream.throw_error(e))
.await?;
let session = cancellation_handler.get_session();
let session = cancel_map.get_session();
prepare_client_connection(&node, &session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
@@ -325,14 +267,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
compute: node,
req: _request_gauge,
conn: _client_gauge,
cancel: session,
}))
proxy_pass(ctx, stream, node.stream, aux).await
}
/// Finish client connection initialization: confirm auth success, send params, etc.
@@ -341,7 +276,7 @@ async fn prepare_client_connection(
node: &compute::PostgresConnection,
session: &cancellation::Session,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<(), std::io::Error> {
) -> anyhow::Result<()> {
// Register compute's query cancellation token and produce a new, unique one.
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
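
`task_main` above races the accept loop against a `CancellationToken` via `run_until_cancelled`. Only the signature appears in the hunk, so the body below is an assumed minimal implementation using `tokio::select!`:

```rust
use std::time::Duration;
use tokio_util::sync::CancellationToken;

async fn run_until_cancelled<F: std::future::Future>(
    f: F,
    cancellation_token: &CancellationToken,
) -> Option<F::Output> {
    tokio::select! {
        res = f => Some(res),
        _ = cancellation_token.cancelled() => None,
    }
}

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    token.cancel();

    // The pending sleep never wins: the already-cancelled token short-circuits it.
    let out = run_until_cancelled(tokio::time::sleep(Duration::from_secs(60)), &token).await;
    assert!(out.is_none());
    println!("accept loop would stop here");
}
```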

View File

@@ -1,9 +1,8 @@
use crate::{
auth::backend::ComputeCredentialKeys,
auth,
compute::{self, PostgresConnection},
console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo},
console::{self, errors::WakeComputeError},
context::RequestMonitoring,
error::ReportableError,
metrics::NUM_CONNECTION_FAILURES,
proxy::{
retry::{retry_after, ShouldRetry},
@@ -21,7 +20,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
let is_cached = node_info.cached();
if is_cached {
warn!("invalidating stalled compute node info cache entry");
@@ -32,13 +31,28 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
};
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
node_info.invalidate()
node_info.invalidate().config
}
/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", fields(pid = tracing::field::Empty), skip_all)]
async fn connect_to_compute_once(
ctx: &mut RequestMonitoring,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, compute::ConnectionError> {
let allow_self_signed_compute = node_info.allow_self_signed_compute;
node_info
.config
.connect(ctx, allow_self_signed_compute, timeout)
.await
}
#[async_trait]
pub trait ConnectMechanism {
type Connection;
type ConnectError: ReportableError;
type ConnectError;
type Error: From<Self::ConnectError>;
async fn connect_once(
&self,
@@ -50,16 +64,6 @@ pub trait ConnectMechanism {
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
#[async_trait]
pub trait ComputeConnectBackend {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
}
pub struct TcpMechanism<'a> {
/// KV-dictionary with PostgreSQL connection params.
pub params: &'a StartupMessageParams,
@@ -71,14 +75,13 @@ impl ConnectMechanism for TcpMechanism<'_> {
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
async fn connect_once(
&self,
ctx: &mut RequestMonitoring,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
node_info.connect(ctx, timeout).await
connect_to_compute_once(ctx, node_info, timeout).await
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
@@ -89,23 +92,16 @@ impl ConnectMechanism for TcpMechanism<'_> {
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
pub async fn connect_to_compute<M: ConnectMechanism>(
ctx: &mut RequestMonitoring,
mechanism: &M,
user_info: &B,
allow_self_signed_compute: bool,
mut node_info: console::CachedNodeInfo,
user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: ShouldRetry + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
let mut num_retries = 0;
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
if let Some(keys) = user_info.get_keys() {
node_info.set_keys(keys);
}
node_info.allow_self_signed_compute = allow_self_signed_compute;
// let mut node_info = credentials.get_node_info(ctx, user_info).await?;
mechanism.update_connect_config(&mut node_info.config);
// try once
@@ -122,30 +118,28 @@ where
error!(error = ?err, "could not connect to compute node");
let node_info = if !node_info.cached() {
// If we just received this from cplane and didn't get it from the cache, we shouldn't retry.
// No need to retrieve a new node_info; just return the old one.
if !err.should_retry(num_retries) {
return Err(err.into());
}
node_info
} else {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
info!("compute node's state has likely changed; requesting a wake-up");
ctx.latency_timer.cache_miss();
let old_node_info = invalidate_cache(node_info);
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
node_info.reuse_settings(old_node_info);
let mut num_retries = 1;
mechanism.update_connect_config(&mut node_info.config);
node_info
match user_info {
auth::BackendType::Console(api, info) => {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
info!("compute node's state has likely changed; requesting a wake-up");
ctx.latency_timer.cache_miss();
let config = invalidate_cache(node_info);
node_info = wake_compute(&mut num_retries, ctx, api, info).await?;
node_info.config.reuse_password(&config);
mechanism.update_connect_config(&mut node_info.config);
}
// nothing to do?
auth::BackendType::Link(_) => {}
};
// now that we have a new node, try connect to it repeatedly.
// this can error for a few reasons, for instance:
// * DNS connection settings haven't quite propagated yet
info!("wake_compute success. attempting to connect");
num_retries = 1;
loop {
match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
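
The retry logic above is easier to follow with the monitoring plumbing stripped away: connect once using possibly-stale cached info, and on a retryable failure invalidate the cache, wake the compute again, and retry with backoff. A synchronous, self-contained approximation in which every name is illustrative:

```rust
use std::thread::sleep;
use std::time::Duration;

#[derive(Debug)]
struct ConnectError {
    retryable: bool,
}

// Stand-in for the real TCP/TLS connect to the compute node.
fn connect_once(addr: &str) -> Result<(), ConnectError> {
    if addr == "stale-cached-addr" {
        Err(ConnectError { retryable: true })
    } else {
        Ok(())
    }
}

// Stand-in for asking the control plane to (re)wake the compute node.
fn wake_compute() -> String {
    "fresh-addr".to_string()
}

fn main() {
    let mut addr = "stale-cached-addr".to_string(); // pretend this came from the cache
    let mut num_retries = 0u32;

    loop {
        match connect_once(&addr) {
            Ok(()) => {
                println!("connected to {addr} after {num_retries} retries");
                break;
            }
            Err(e) if e.retryable && num_retries < 3 => {
                num_retries += 1;
                // The cached address is likely stale: drop it and ask for a fresh one.
                addr = wake_compute();
                sleep(Duration::from_millis(100 * u64::from(num_retries)));
            }
            Err(e) => {
                eprintln!("giving up: {e:?}");
                break;
            }
        }
    }
}
```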

View File

@@ -1,256 +0,0 @@
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use std::future::poll_fn;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
#[derive(Debug)]
enum TransferState {
Running(CopyBuffer),
ShuttingDown(u64),
Done(u64),
}
fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
r: &mut A,
w: &mut B,
) -> Poll<io::Result<u64>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut r = Pin::new(r);
let mut w = Pin::new(w);
loop {
match state {
TransferState::Running(buf) => {
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
ready!(w.as_mut().poll_shutdown(cx))?;
*state = TransferState::Done(*count);
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
}
}
}
pub(super) async fn copy_bidirectional<A, B>(
a: &mut A,
b: &mut B,
) -> Result<(u64, u64), std::io::Error>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut a_to_b = TransferState::Running(CopyBuffer::new());
let mut b_to_a = TransferState::Running(CopyBuffer::new());
poll_fn(|cx| {
let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
// Early termination checks
if let TransferState::Done(_) = a_to_b {
if let TransferState::Running(buf) = &b_to_a {
// Initiate shutdown
b_to_a = TransferState::ShuttingDown(buf.amt);
b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
}
}
if let TransferState::Done(_) = b_to_a {
if let TransferState::Running(buf) = &a_to_b {
// Initiate shutdown
a_to_b = TransferState::ShuttingDown(buf.amt);
a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
}
}
// It is not a problem if ready! returns early ... (comment remains the same)
let a_to_b = ready!(a_to_b_result);
let b_to_a = ready!(b_to_a_result);
Poll::Ready(Ok((a_to_b, b_to_a)))
})
.await
}
#[derive(Debug)]
pub(super) struct CopyBuffer {
read_done: bool,
need_flush: bool,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
impl CopyBuffer {
pub(super) fn new() -> Self {
Self {
read_done: false,
need_flush: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
}
}
fn poll_fill_buf<R>(
&mut self,
cx: &mut Context<'_>,
reader: Pin<&mut R>,
) -> Poll<io::Result<()>>
where
R: AsyncRead + ?Sized,
{
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
buf.set_filled(me.cap);
let res = reader.poll_read(cx, &mut buf);
if let Poll::Ready(Ok(())) = res {
let filled_len = buf.filled().len();
me.read_done = me.cap == filled_len;
me.cap = filled_len;
}
res
}
fn poll_write_buf<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<usize>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
let me = &mut *self;
match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
Poll::Pending => {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
}
Poll::Pending
}
res => res,
}
}
pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<u64>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
self.pos = 0;
self.cap = 0;
match self.poll_fill_buf(cx, reader.as_mut()) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => {
// Try flushing when the reader has no progress to avoid deadlock
// when the reader depends on buffered writer.
if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))?;
self.need_flush = false;
}
return Poll::Pending;
}
}
}
// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
} else {
self.pos += i;
self.amt += i as u64;
self.need_flush = true;
}
}
// If pos larger than cap, this loop will never stop.
// In particular, user's wrong poll_write implementation returning
// incorrect written length may lead to thread blocking.
debug_assert!(
self.pos <= self.cap,
"writer returned length larger than input slice"
);
// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::AsyncWriteExt;
#[tokio::test]
async fn test_early_termination_a_to_d() {
let (mut a_mock, mut b_mock) = tokio::io::duplex(8); // Create a mock duplex stream
let (mut c_mock, mut d_mock) = tokio::io::duplex(32); // Create a mock duplex stream
// Simulate 'a' finishing while there's still data for 'b'
a_mock.write_all(b"hello").await.unwrap();
a_mock.shutdown().await.unwrap();
d_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
// Assert correct transferred amounts
let (a_to_d_count, d_to_a_count) = result;
assert_eq!(a_to_d_count, 5); // 'hello' was transferred
assert!(d_to_a_count <= 8); // response only partially transferred or not at all
}
#[tokio::test]
async fn test_early_termination_d_to_a() {
let (mut a_mock, mut b_mock) = tokio::io::duplex(32); // Create a mock duplex stream
let (mut c_mock, mut d_mock) = tokio::io::duplex(8); // Create a mock duplex stream
        // Simulate 'd' finishing while 'a' still has data to send
d_mock.write_all(b"hello").await.unwrap();
d_mock.shutdown().await.unwrap();
a_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
// Assert correct transferred amounts
let (a_to_d_count, d_to_a_count) = result;
assert_eq!(d_to_a_count, 5); // 'hello' was transferred
assert!(a_to_d_count <= 8); // response only partially transferred or not at all
}
}
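// Illustrative sketch (not part of the diff): wiring `copy_bidirectional` between two
// real sockets, mirroring what the duplex-based tests above exercise. The TCP types
// and logging are hypothetical; only the call shape is taken from the tests.
async fn proxy_pass(mut client: tokio::net::TcpStream, mut compute: tokio::net::TcpStream) {
    // Runs until either peer finishes; an error on either side simply ends the session here.
    if let Ok((client_to_compute, compute_to_client)) =
        copy_bidirectional(&mut client, &mut compute).await
    {
        eprintln!("copied {client_to_compute} B client->compute, {compute_to_client} B compute->client");
    }
}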

View File

@@ -1,60 +1,15 @@
use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams};
use thiserror::Error;
use anyhow::{bail, Context};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use crate::{
cancellation::CancelMap,
config::TlsConfig,
error::ReportableError,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
stream::{PqStream, Stream},
};
#[derive(Error, Debug)]
pub enum HandshakeError {
#[error("data is sent before server replied with EncryptionResponse")]
EarlyData,
#[error("protocol violation")]
ProtocolViolation,
#[error("missing certificate")]
MissingCertificate,
#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),
#[error("{0}")]
Io(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for HandshakeError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
HandshakeError::EarlyData => crate::error::ErrorKind::User,
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
// This error should not happen, but will if we have no default certificate and
// the client sends no SNI extension.
// If they provide SNI then we can be sure there is a certificate that matches.
HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
},
HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
HandshakeError::ReportedError(e) => e.get_error_kind(),
}
}
}
pub enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
Cancel(CancelKeyData),
}
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
@@ -63,7 +18,8 @@ pub enum HandshakeData<S> {
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
) -> Result<HandshakeData<S>, HandshakeError> {
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
@@ -93,14 +49,14 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
return Err(HandshakeError::EarlyData);
bail!("data is sent before server replied with EncryptionResponse");
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.ok_or(HandshakeError::MissingCertificate)?;
.context("missing certificate")?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
@@ -108,7 +64,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
});
}
}
_ => return Err(HandshakeError::ProtocolViolation),
_ => bail!(ERR_PROTO_VIOLATION),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
@@ -117,23 +73,23 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => return Err(HandshakeError::ProtocolViolation),
_ => bail!(ERR_PROTO_VIOLATION),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
return stream
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
.await?;
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}
info!(session_type = "normal", "successful handshake");
break Ok(HandshakeData::Startup(stream, params));
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
info!(session_type = "cancellation", "successful handshake");
break Ok(HandshakeData::Cancel(cancel_key_data));
break Ok(None);
}
}
}
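// Illustrative sketch (not part of the diff): how a caller might consume the
// Option-returning `handshake` shown in this hunk. `serve_connection` and its
// surrounding setup are hypothetical.
async fn serve_connection<S>(
    stream: S,
    tls: Option<&TlsConfig>,
    cancel_map: &CancelMap,
) -> anyhow::Result<()>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    match handshake(stream, tls, cancel_map).await? {
        // Normal startup: continue with authentication and proxying on the upgraded stream.
        Some((_stream, _params)) => Ok(()),
        // A CancelRequest was already dispatched inside `handshake`; nothing more to do.
        None => Ok(()),
    }
}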

Some files were not shown because too many files have changed in this diff.