mirror of https://github.com/neondatabase/neon.git
synced 2026-02-06 04:00:37 +00:00
Compare commits
6 Commits
add_audit_ ... detect-new

| Author | SHA1 | Date |
|---|---|---|
| | eba6f85909 | |
| | 52888e06a0 | |
| | 2b51b7cbb1 | |
| | 5895abf495 | |
| | 05559539f1 | |
| | 8f2808864b | |
@@ -43,7 +43,7 @@ runs:
        BUCKET: neon-github-public-dev
      run: |
        if [ -n "${PR_NUMBER}" ]; then
          BRANCH_OR_PR=pr-${PR_NUMBER}
          BRANCH_OR_PR=pr-${PR_NUMBER}${REPORT_EXT-}
        elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || \
             [ "${GITHUB_REF_NAME}" = "release-proxy" ] || [ "${GITHUB_REF_NAME}" = "release-compute" ]; then
          # Shortcut for special branches

@@ -23,7 +23,7 @@ runs:
        REPORT_DIR: ${{ inputs.report-dir }}
      run: |
        if [ -n "${PR_NUMBER}" ]; then
          BRANCH_OR_PR=pr-${PR_NUMBER}
          BRANCH_OR_PR=pr-${PR_NUMBER}${REPORT_EXT-}
        elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || \
             [ "${GITHUB_REF_NAME}" = "release-proxy" ] || [ "${GITHUB_REF_NAME}" = "release-compute" ]; then
          # Shortcut for special branches
14 .github/actions/run-python-test-set/action.yml vendored
@@ -12,6 +12,10 @@ inputs:
    description: 'Arbitrary parameters to pytest. For example "-s" to prevent capturing stdout/stderr'
    required: false
    default: ''
  extended_testing:
    description: 'Set to true if the test results should be stored and processed separately'
    required: false
    default: 'false'
  needs_postgres_source:
    description: 'Set to true if the test suite requires postgres source checked out'
    required: false
@@ -135,13 +139,15 @@ runs:
        PERF_REPORT_DIR="$(realpath test_runner/perf-report-local)"
        echo "PERF_REPORT_DIR=${PERF_REPORT_DIR}" >> ${GITHUB_ENV}
        rm -rf $PERF_REPORT_DIR

        TEST_SELECTION="test_runner/${{ inputs.test_selection }}"
        EXTRA_PARAMS="${{ inputs.extra_params }}"
        TEST_SELECTION="${{ inputs.test_selection }}"
        if [ -z "$TEST_SELECTION" ]; then
          echo "test_selection must be set"
          exit 1
        fi
        if [[ $TEST_SELECTION != test_runner/* ]]; then
          TEST_SELECTION="test_runner/$TEST_SELECTION"
        fi
        EXTRA_PARAMS="${{ inputs.extra_params }}"
        if [[ "${{ inputs.run_in_parallel }}" == "true" ]]; then
          # -n sets the number of parallel processes that pytest-xdist will run
          EXTRA_PARAMS="-n12 $EXTRA_PARAMS"
@@ -244,3 +250,5 @@ runs:
      report-dir: /tmp/test_output/allure/results
      unique-key: ${{ inputs.build_type }}-${{ inputs.pg_version }}-${{ runner.arch }}
      aws-oidc-role-arn: ${{ inputs.aws-oidc-role-arn }}
    env:
      REPORT_EXT: ${{ inputs.extended_testing == 'true' && '-ext' || '' }}
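The interplay between `extended_testing`, `REPORT_EXT` and the `BRANCH_OR_PR` prefix in the hunks above is easier to see with concrete values. The sketch below mirrors that shell logic in Python; the PR number and the variable values are illustrative only, not part of the change.

# Illustrative values; in CI these come from the composite action's env block.
pr_number = "1234"
extended_testing = True

# REPORT_EXT is '-ext' only when extended_testing is 'true' (see the action.yml hunk above).
report_ext = "-ext" if extended_testing else ""

# Same naming rule as `BRANCH_OR_PR=pr-${PR_NUMBER}${REPORT_EXT-}` in the shell steps:
branch_or_pr = f"pr-{pr_number}{report_ext}"
print(branch_or_pr)  # -> "pr-1234-ext" for extended runs, "pr-1234" otherwise

This keeps reports from the separate extended-test runs under their own prefix, matching the `extended_testing` input description ("stored and processed separately").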
143 .github/scripts/detect-updated-pytests.py vendored Executable file
@@ -0,0 +1,143 @@
import os
import re
import shutil
import subprocess
import sys

commit_sha = os.getenv("COMMIT_SHA")
base_sha = os.getenv("BASE_SHA")

cmd = ["git", "merge-base", base_sha, commit_sha]
print(f"Running: {' '.join(cmd)}...")
result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if result.returncode != 0 or not (baseline := result.stdout.strip()):
    print("Baseline commit for PR is not found, detection skipped.")
    sys.exit(0)
print(f"Baseline commit: {baseline}")

cmd = ["git", "diff", "--name-only", f"{baseline}..{commit_sha}", "test_runner/regress/"]
print(f"Running: {' '.join(cmd)}...")
result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if result.returncode != 0:
    print(f"Git diff returned code {result.returncode}\n{result.stdout}\nDetection skipped.")
    sys.exit(0)


def collect_tests(test_file_name):
    cmd = ["./scripts/pytest", "--collect-only", "-q", test_file_name]
    print(f"Running: {' '.join(cmd)}...")
    result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if result.returncode != 0:
        print(
            f"pytest --collect-only returned code {result.returncode}\n{result.stdout}\nDetection skipped."
        )
        sys.exit(0)

    tests = []
    for test_item in result.stdout.split("\n"):
        if not test_item.startswith(test_file_name):
            break
        test_name = re.sub(r"(.*::)([^\[]+)(\[.*)", r"\2", test_item)
        if test_name not in tests:
            tests.append(test_name)
    return tests


all_new_tests = []
all_updated_tests = []
temp_test_file = "test_runner/regress/__temp__.py"
temp_file = None
for test_file in result.stdout.split("\n"):
    if not test_file:
        continue
    print(f"Test file modified: {test_file}.")

    # Get and compare two lists of items collected by pytest to detect new tests in the PR
    if temp_file:
        temp_file.close()
    temp_file = open(temp_test_file, "w")
    cmd = ["git", "show", f"{baseline}:{test_file}"]
    print(f"Running: {' '.join(cmd)}...")
    result = subprocess.run(cmd, text=True, stdout=temp_file)
    if result.returncode != 0:
        tests0 = []
    else:
        tests0 = collect_tests(temp_test_file)

    tests1 = collect_tests(test_file)

    new_tests = set(tests1).difference(tests0)
    for test_name in new_tests:
        all_new_tests.append(f"{test_file}::{test_name}")

    # Detect pre-existing test functions updated in the PR
    cmd = ["git", "diff", f"{baseline}..{commit_sha}", test_file]
    print(f"Running: {' '.join(cmd)}...")
    result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if result.returncode != 0:
        print(f"Git diff returned code {result.returncode}\n{result.stdout}\nDetection skipped.")
        sys.exit(0)
    updated_funcs = []
    for diff_line in result.stdout.split("\n"):
        print(diff_line)
        # TODO: detect functions with added/modified parameters
        if not diff_line.startswith("@@"):
            continue

        # Extract names of functions with updated content relying on hunk header
        m = re.match(r"^(@@[0-9, +-]+@@ def )([^(]+)(.*)", diff_line)
        if not m:
            continue
        func_name = m.group(2)
        print(func_name)  ##

        # Ignore functions not collected by pytest
        if func_name not in tests1:
            continue
        if func_name not in updated_funcs:
            updated_funcs.append(func_name)

    for func_name in updated_funcs:
        print(f"Function modified: {func_name}.")
        # Extract changes within the function

        cmd = ["git", "log", f"{baseline}..{commit_sha}", "-L", f":{func_name}:{test_file}"]
        print(f"Running: {' '.join(cmd)}...")
        result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        if result.returncode != 0:
            continue

        patch_contents = result.stdout

        # Revert changes to get the file with only this function updated
        # (applying the patch might fail if it contains a change for the next function declaraion)
        shutil.copy(test_file, temp_test_file)

        cmd = ["patch", "-R", "-p1", "--no-backup-if-mismatch", "-r", "/dev/null", temp_test_file]
        print(f"Running: {' '.join(cmd)}...")
        result = subprocess.run(
            cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, input=patch_contents
        )
        print(f"result: {result.returncode}; {result.stdout}")
        if result.returncode != 0:
            continue

        # Ignore whitespace-only changes
        cmd = ["diff", "-w", test_file, temp_test_file]
        print(f"Running: {' '.join(cmd)}...")
        result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        if result.returncode == 0:
            continue
        all_updated_tests.append(f"{test_file}::{func_name}")

if temp_file:
    temp_file.close()
if os.path.exists(temp_test_file):
    os.remove(temp_test_file)

if github_output := os.getenv("GITHUB_OUTPUT"):
    with open(github_output, "a") as f:
        if all_new_tests or all_updated_tests:
            f.write("tests=")
            f.write(" ".join(all_new_tests + all_updated_tests))
            f.write("\n")
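The two regular expressions in the script above carry most of the detection logic and are easier to read against concrete inputs. The sample strings below are illustrative only (they are not taken from this PR); the hunk-header form relies on git's default behaviour of showing the enclosing `def ...` line as hunk context for Python files.

import re

# A `pytest --collect-only -q` line for a parametrized test; the substitution keeps
# only the bare function name (capture group 2):
item = "test_runner/regress/test_pg_regress.py::test_pg_regress[release-pg17]"
print(re.sub(r"(.*::)([^\[]+)(\[.*)", r"\2", item))  # -> test_pg_regress

# A `git diff` hunk header whose context is the enclosing test function; group 2 is
# the function name that the script then checks against the collected tests:
hunk = "@@ -42,7 +42,9 @@ def test_pg_regress(neon_env_builder):"
m = re.match(r"^(@@[0-9, +-]+@@ def )([^(]+)(.*)", hunk)
print(m.group(2) if m else None)  # -> test_pg_regress

When anything is found, the script appends a single `tests=<space-separated selections>` line to `$GITHUB_OUTPUT`, which the workflow further below reads back as `steps.detect_tests_to_test.outputs.tests`.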
@@ -266,7 +266,7 @@ jobs:
          role-duration-seconds: 18000 # 5 hours

      - name: Run rust tests
        if: ${{ inputs.sanitizers != 'enabled' }}
        if: ${{ inputs.sanitizers != 'enabled' && inputs.test-selection == '' }}
        env:
          NEXTEST_RETRIES: 3
        run: |
@@ -386,7 +386,8 @@ jobs:
        timeout-minutes: ${{ inputs.sanitizers != 'enabled' && 75 || 180 }}
        with:
          build_type: ${{ inputs.build-type }}
          test_selection: regress
          test_selection: ${{ inputs.test-selection != '' && inputs.test-selection || 'regress' }}
          extended_testing: ${{ inputs.test-selection != '' && 'true' || 'false' }}
          needs_postgres_source: true
          run_with_real_s3: true
          real_s3_bucket: neon-github-ci-tests
@@ -399,9 +400,7 @@ jobs:
          # Attempt to stop tests gracefully to generate test reports
          # until they are forcibly stopped by the stricter `timeout-minutes` limit.
          extra_params: --session-timeout=${{ inputs.sanitizers != 'enabled' && 3000 || 10200 }} --count=${{ inputs.test-run-count }}
            ${{ inputs.test-selection != '' && format('-k "{0}"', inputs.test-selection) || '' }}
        env:
          TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
          CHECK_ONDISK_DATA_COMPATIBILITY: nonempty
          BUILD_TAG: ${{ inputs.build-tag }}
          PAGESERVER_VIRTUAL_FILE_IO_ENGINE: tokio-epoll-uring
6 .github/workflows/build_and_test.yml vendored
@@ -199,6 +199,12 @@ jobs:
      build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
    secrets: inherit

  build-and-test-new-tests:
    needs: [ meta, build-build-tools-image ]
    if: github.event_name == 'pull_request'
    uses: ./.github/workflows/build_and_test_tests.yml
    secrets: inherit

  build-and-test-locally:
    needs: [ meta, build-build-tools-image ]
    # We do need to run this in `.*-rc-pr` because of hotfixes.
@@ -1,10 +1,10 @@
name: Build and Run Selected Test
name: Build and Run Selected Tests

on:
  workflow_dispatch:
    inputs:
      test-selection:
        description: 'Specification of selected test(s), as accepted by pytest -k'
        description: 'Specification of selected test(s), e. g.: test_runner/regress/test_pg_regress.py::test_pg_regress'
        required: true
        type: string
      run-count:
@@ -26,6 +26,8 @@ on:
        default: '[{"pg_version":"v17"}]'
        required: true
        type: string
  workflow_call:
  pull_request: # TODO: remove before merge

defaults:
  run:
@@ -42,26 +44,71 @@ jobs:
      github-event-name: ${{ github.event_name }}
      github-event-json: ${{ toJSON(github.event) }}

  build-and-test-locally:
    needs: [ meta ]

  choose-test-parameters:
    runs-on: [ self-hosted, small ]
    container:
      image: ghcr.io/neondatabase/build-tools:pinned-bookworm
      credentials:
        username: ${{ github.actor }}
        password: ${{ secrets.GITHUB_TOKEN }}
      options: --init

    outputs:
      tests: ${{ inputs.test-selection != '' && inputs.test-selection || steps.detect_tests_to_test.outputs.tests }}
      archs: ${{ inputs.test-selection != '' && inputs.archs || '["x64", "arm64"]' }}
      build-types: ${{ inputs.test-selection != '' && inputs.build-types || '["release"]' }}
      pg-versions: ${{ inputs.test-selection != '' && inputs.pg-versions || '[{"pg_version":"v14"}, {"pg_version":"v17"}]' }}
      run-count: ${{ inputs.test-selection != '' && inputs.run-count || 5 }}
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        if: inputs.test-selection == ''
        with:
          submodules: false
          clean: false
          fetch-depth: 1000

      - name: Cache poetry deps
        if: inputs.test-selection == ''
        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry/virtualenvs
          key: v2-${{ runner.os }}-${{ runner.arch }}-python-deps-bookworm-${{ hashFiles('poetry.lock') }}

      - name: Install Python deps
        if: inputs.test-selection == ''
        shell: bash -euxo pipefail {0}
        run: ./scripts/pysync

      - name: Detect new and updated tests
        id: detect_tests_to_test
        if: github.event.pull_request.head.sha && inputs.test-selection == ''
        env:
          COMMIT_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
          BASE_SHA: ${{ github.event.pull_request.base.sha || github.sha }}
        run: python3 .github/scripts/detect-updated-pytests.py
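For debugging the "Detect new and updated tests" step outside of CI, something like the following should work from the repository root. This is an assumed local invocation, not part of the change: the refs and the output path are placeholders, and `GITHUB_OUTPUT` is optional since the script only writes the `tests=` line when that variable is set.

import os
import subprocess

# Placeholder refs standing in for the PR head/base SHAs that the workflow exports.
env = dict(os.environ, COMMIT_SHA="HEAD", BASE_SHA="origin/main", GITHUB_OUTPUT="/tmp/detected-tests.txt")
subprocess.run(["python3", ".github/scripts/detect-updated-pytests.py"], env=env, check=True)

# If new or updated tests were found, the output file now contains a line like:
#   tests=test_runner/regress/test_flaky.py::test_flaky1 ...
print(open("/tmp/detected-tests.txt").read())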
  build-and-test-tests:
    needs: [ meta, choose-test-parameters ]
    if: needs.choose-test-parameters.outputs.tests != ''
    strategy:
      fail-fast: false
      matrix:
        arch: ${{ fromJson(inputs.archs) }}
        build-type: ${{ fromJson(inputs.build-types) }}
        arch: ${{ fromJson(needs.choose-test-parameters.outputs.archs) }}
        build-type: ${{ fromJson(needs.choose-test-parameters.outputs.build-types) }}
    uses: ./.github/workflows/_build-and-test-locally.yml
    with:
      arch: ${{ matrix.arch }}
      build-tools-image: ghcr.io/neondatabase/build-tools:pinned-bookworm
      build-tag: ${{ needs.meta.outputs.build-tag }}
      build-type: ${{ matrix.build-type }}
      test-cfg: ${{ inputs.pg-versions }}
      test-selection: ${{ inputs.test-selection }}
      test-run-count: ${{ fromJson(inputs.run-count) }}
      test-cfg: ${{ needs.choose-test-parameters.outputs.pg-versions }}
      test-selection: ${{ needs.choose-test-parameters.outputs.tests }}
      test-run-count: ${{ fromJson(needs.choose-test-parameters.outputs.run-count) }}
    secrets: inherit

  create-test-report:
    needs: [ build-and-test-locally ]
    needs: [ build-and-test-tests ]
    if: ${{ !cancelled() }}
    permissions:
      id-token: write # aws-actions/configure-aws-credentials
@@ -96,6 +143,7 @@ jobs:
          aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
        env:
          REGRESS_TEST_RESULT_CONNSTR_NEW: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_DEV }}
          REPORT_EXT: '-ext'

      - uses: actions/github-script@v7
        if: ${{ !cancelled() }}
@@ -168,35 +168,6 @@ pub fn write_postgres_conf(
        writeln!(file, "# Managed by compute_ctl: end")?;
    }

    // Always add pgaudit to shared_preload_libraries.
    //
    // This is needed to handle the downgrade scenario.
    // pgaudit extension creates event triggers that require library to be loaded.
    // so, once extension was installed it must always be present in shared_preload_libraries.
    let mut extra_shared_preload_libraries = String::new();

    let libs = {
        // We don't distribute pgaudit in the testing image,
        // and don't pass shared_preload_libraries via spec,
        // so disable this logic there.
        #[cfg(feature = "testing")]
        {
            String::new()
        }
        #[cfg(not(feature = "testing"))]
        {
            spec.cluster
                .settings
                .find("shared_preload_libraries")
                .expect("shared_preload_libraries setting is missing in the spec")
        }
    };

    #[cfg(not(feature = "testing"))]
    if !libs.contains("pgaudit") {
        extra_shared_preload_libraries.push_str(",pgaudit");
    };

    // If base audit logging is enabled, configure it.
    // In this setup, the audit log will be written to the standard postgresql log.
    //
@@ -206,22 +177,29 @@ pub fn write_postgres_conf(
    // This way we always override the settings from the spec
    // and don't allow the user or the control plane admin to change them.
    match spec.audit_log_level {
        ComputeAudit::Disabled => {
            // this is the default, but let's be explicit
            writeln!(file, "pgaudit.log='none'")?;
        }
        ComputeAudit::Disabled => {}
        ComputeAudit::Log | ComputeAudit::Base => {
            writeln!(file, "# Managed by compute_ctl base audit settings: start")?;
            writeln!(file, "pgaudit.log='ddl,role'")?;
            // Disable logging of catalog queries to reduce the noise
            writeln!(file, "pgaudit.log_catalog=off")?;

            writeln!(
                file,
                "shared_preload_libraries='{}{}'",
                libs, extra_shared_preload_libraries
            )?;

            if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
                let mut extra_shared_preload_libraries = String::new();
                if !libs.contains("pgaudit") {
                    extra_shared_preload_libraries.push_str(",pgaudit");
                }
                writeln!(
                    file,
                    "shared_preload_libraries='{}{}'",
                    libs, extra_shared_preload_libraries
                )?;
            } else {
                // Typically, this should be unreacheable,
                // because we always set at least some shared_preload_libraries in the spec
                // but let's handle it explicitly anyway.
                writeln!(file, "shared_preload_libraries='neon,pgaudit'")?;
            }
            writeln!(file, "# Managed by compute_ctl base audit settings: end")?;
        }
        ComputeAudit::Hipaa | ComputeAudit::Extended | ComputeAudit::Full => {
@@ -250,15 +228,28 @@ pub fn write_postgres_conf(
            // The caller who sets the flag is responsible for ensuring that the necessary
            // shared_preload_libraries are present in the compute image,
            // otherwise the compute start will fail.
            if !libs.contains("pgauditlogtofile") {
                extra_shared_preload_libraries.push_str(",pgauditlogtofile");
            if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
                let mut extra_shared_preload_libraries = String::new();
                if !libs.contains("pgaudit") {
                    extra_shared_preload_libraries.push_str(",pgaudit");
                }
                if !libs.contains("pgauditlogtofile") {
                    extra_shared_preload_libraries.push_str(",pgauditlogtofile");
                }
                writeln!(
                    file,
                    "shared_preload_libraries='{}{}'",
                    libs, extra_shared_preload_libraries
                )?;
            } else {
                // Typically, this should be unreacheable,
                // because we always set at least some shared_preload_libraries in the spec
                // but let's handle it explicitly anyway.
                writeln!(
                    file,
                    "shared_preload_libraries='neon,pgaudit,pgauditlogtofile'"
                )?;
            }
            writeln!(
                file,
                "shared_preload_libraries='{}{}'",
                libs, extra_shared_preload_libraries
            )?;

            writeln!(
                file,
                "# Managed by compute_ctl compliance audit settings: end"
@@ -14,6 +14,8 @@
use std::fs::File;
use std::io::{Error, ErrorKind};
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
#[cfg(target_os = "linux")]
use std::os::unix::fs::OpenOptionsExt;
use std::sync::LazyLock;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering};

@@ -97,7 +99,7 @@ impl VirtualFile {

    pub async fn open_with_options_v2<P: AsRef<Utf8Path>>(
        path: P,
        #[cfg_attr(not(target_os = "linux"), allow(unused_mut))] mut open_options: OpenOptions,
        open_options: &OpenOptions,
        ctx: &RequestContext,
    ) -> Result<Self, std::io::Error> {
        let mode = get_io_mode();
@@ -110,16 +112,21 @@ impl VirtualFile {
            #[cfg(target_os = "linux")]
            (IoMode::DirectRw, _) => true,
        };
        if set_o_direct {
        let open_options = open_options.clone();
        let open_options = if set_o_direct {
            #[cfg(target_os = "linux")]
            {
                open_options = open_options.custom_flags(nix::libc::O_DIRECT);
                let mut open_options = open_options;
                open_options.custom_flags(nix::libc::O_DIRECT);
                open_options
            }
            #[cfg(not(target_os = "linux"))]
            unreachable!(
                "O_DIRECT is not supported on this platform, IoMode's that result in set_o_direct=true shouldn't even be defined"
            );
        }
        } else {
            open_options
        };
        let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
        Ok(VirtualFile { inner, _mode: mode })
    }
@@ -523,7 +530,7 @@ impl VirtualFileInner {
        path: P,
        ctx: &RequestContext,
    ) -> Result<VirtualFileInner, std::io::Error> {
        Self::open_with_options(path.as_ref(), OpenOptions::new().read(true), ctx).await
        Self::open_with_options(path.as_ref(), OpenOptions::new().read(true).clone(), ctx).await
    }

    /// Open a file with given options.
@@ -551,11 +558,10 @@ impl VirtualFileInner {
        // It would perhaps be nicer to check just for the read and write flags
        // explicitly, but OpenOptions doesn't contain any functions to read flags,
        // only to set them.
        let reopen_options = open_options
            .clone()
            .create(false)
            .create_new(false)
            .truncate(false);
        let mut reopen_options = open_options.clone();
        reopen_options.create(false);
        reopen_options.create_new(false);
        reopen_options.truncate(false);

        let vfile = VirtualFileInner {
            handle: RwLock::new(handle),
@@ -1301,7 +1307,7 @@ mod tests {
            opts: OpenOptions,
            ctx: &RequestContext,
        ) -> Result<MaybeVirtualFile, anyhow::Error> {
            let vf = VirtualFile::open_with_options_v2(&path, opts, ctx).await?;
            let vf = VirtualFile::open_with_options_v2(&path, &opts, ctx).await?;
            Ok(MaybeVirtualFile::VirtualFile(vf))
        }
    }
@@ -1368,7 +1374,7 @@ mod tests {
        let _ = file_a.read_string_at(0, 1, &ctx).await.unwrap_err();

        // Close the file and re-open for reading
        let mut file_a = A::open(path_a, OpenOptions::new().read(true), &ctx).await?;
        let mut file_a = A::open(path_a, OpenOptions::new().read(true).to_owned(), &ctx).await?;

        // cannot write to a file opened in read-only mode
        let _ = file_a
@@ -1387,7 +1393,8 @@ mod tests {
                .read(true)
                .write(true)
                .create(true)
                .truncate(true),
                .truncate(true)
                .to_owned(),
            &ctx,
        )
        .await?;
@@ -1405,7 +1412,12 @@ mod tests {

        let mut vfiles = Vec::new();
        for _ in 0..100 {
            let mut vfile = A::open(path_b.clone(), OpenOptions::new().read(true), &ctx).await?;
            let mut vfile = A::open(
                path_b.clone(),
                OpenOptions::new().read(true).to_owned(),
                &ctx,
            )
            .await?;
            assert_eq!("FOOBAR", vfile.read_string_at(0, 6, &ctx).await?);
            vfiles.push(vfile);
        }
@@ -1454,7 +1466,7 @@ mod tests {
        for _ in 0..VIRTUAL_FILES {
            let f = VirtualFileInner::open_with_options(
                &test_file_path,
                OpenOptions::new().read(true),
                OpenOptions::new().read(true).clone(),
                &ctx,
            )
            .await?;
@@ -1,7 +1,6 @@
//! Enum-dispatch to the `OpenOptions` type of the respective [`super::IoEngineKind`];

use std::os::fd::OwnedFd;
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;

use super::io_engine::IoEngine;
@@ -44,7 +43,7 @@ impl OpenOptions {
        self.write
    }

    pub fn read(mut self, read: bool) -> Self {
    pub fn read(&mut self, read: bool) -> &mut OpenOptions {
        match &mut self.inner {
            Inner::StdFs(x) => {
                let _ = x.read(read);
@@ -57,7 +56,7 @@ impl OpenOptions {
        self
    }

    pub fn write(mut self, write: bool) -> Self {
    pub fn write(&mut self, write: bool) -> &mut OpenOptions {
        self.write = write;
        match &mut self.inner {
            Inner::StdFs(x) => {
@@ -71,7 +70,7 @@ impl OpenOptions {
        self
    }

    pub fn create(mut self, create: bool) -> Self {
    pub fn create(&mut self, create: bool) -> &mut OpenOptions {
        match &mut self.inner {
            Inner::StdFs(x) => {
                let _ = x.create(create);
@@ -84,7 +83,7 @@ impl OpenOptions {
        self
    }

    pub fn create_new(mut self, create_new: bool) -> Self {
    pub fn create_new(&mut self, create_new: bool) -> &mut OpenOptions {
        match &mut self.inner {
            Inner::StdFs(x) => {
                let _ = x.create_new(create_new);
@@ -97,7 +96,7 @@ impl OpenOptions {
        self
    }

    pub fn truncate(mut self, truncate: bool) -> Self {
    pub fn truncate(&mut self, truncate: bool) -> &mut OpenOptions {
        match &mut self.inner {
            Inner::StdFs(x) => {
                let _ = x.truncate(truncate);
@@ -125,8 +124,10 @@ impl OpenOptions {
            }
        }
    }
}

    pub fn mode(mut self, mode: u32) -> Self {
impl std::os::unix::prelude::OpenOptionsExt for OpenOptions {
    fn mode(&mut self, mode: u32) -> &mut OpenOptions {
        match &mut self.inner {
            Inner::StdFs(x) => {
                let _ = x.mode(mode);
@@ -139,7 +140,7 @@ impl OpenOptions {
        self
    }

    pub fn custom_flags(mut self, flags: i32) -> Self {
    fn custom_flags(&mut self, flags: i32) -> &mut OpenOptions {
        match &mut self.inner {
            Inner::StdFs(x) => {
                let _ = x.custom_flags(flags);
@@ -32,6 +32,12 @@ pub(crate) enum ComputeUserInfoParseError {
        option: EndpointId,
    },

    #[error(
        "Common name inferred from SNI ('{}') is not known",
        .cn,
    )]
    UnknownCommonName { cn: String },

    #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
    MalformedProjectName(EndpointId),
}
@@ -60,15 +66,22 @@ impl ComputeUserInfoMaybeEndpoint {
    }
}

pub(crate) fn endpoint_sni(sni: &str, common_names: &HashSet<String>) -> Option<EndpointId> {
    let (subdomain, common_name) = sni.split_once('.')?;
pub(crate) fn endpoint_sni(
    sni: &str,
    common_names: &HashSet<String>,
) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
    let Some((subdomain, common_name)) = sni.split_once('.') else {
        return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
    };
    if !common_names.contains(common_name) {
        return None;
        return Err(ComputeUserInfoParseError::UnknownCommonName {
            cn: common_name.into(),
        });
    }
    if subdomain == SERVERLESS_DRIVER_SNI {
        return None;
        return Ok(None);
    }
    Some(EndpointId::from(subdomain))
    Ok(Some(EndpointId::from(subdomain)))
}

impl ComputeUserInfoMaybeEndpoint {
@@ -100,8 +113,15 @@ impl ComputeUserInfoMaybeEndpoint {
            })
            .map(|name| name.into());

        let endpoint_from_domain =
            sni.and_then(|sni_str| common_names.and_then(|cn| endpoint_sni(sni_str, cn)));
        let endpoint_from_domain = if let Some(sni_str) = sni {
            if let Some(cn) = common_names {
                endpoint_sni(sni_str, cn)?
            } else {
                None
            }
        } else {
            None
        };

        let endpoint = match (endpoint_option, endpoint_from_domain) {
            // Invariant: if we have both project name variants, they should match.
@@ -404,34 +424,21 @@ mod tests {
    }

    #[test]
    fn parse_unknown_sni() {
    fn parse_inconsistent_sni() {
        let options = StartupMessageParams::new([("user", "john_doe")]);

        let sni = Some("project.localhost");
        let common_names = Some(["example.com".into()].into());

        let ctx = RequestContext::test();
        let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .unwrap();

        assert!(info.endpoint_id.is_none());
    }

    #[test]
    fn parse_unknown_sni_with_options() {
        let options = StartupMessageParams::new([
            ("user", "john_doe"),
            ("options", "endpoint=foo-bar-baz-1234"),
        ]);

        let sni = Some("project.localhost");
        let common_names = Some(["example.com".into()].into());

        let ctx = RequestContext::test();
        let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .unwrap();

        assert_eq!(info.endpoint_id.as_deref(), Some("foo-bar-baz-1234"));
        let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
            .expect_err("should fail");
        match err {
            UnknownCommonName { cn } => {
                assert_eq!(cn, "localhost");
            }
            _ => panic!("bad error: {err:?}"),
        }
    }

    #[test]
@@ -24,6 +24,9 @@ pub(crate) enum HandshakeError {
    #[error("protocol violation")]
    ProtocolViolation,

    #[error("missing certificate")]
    MissingCertificate,

    #[error("{0}")]
    StreamUpgradeError(#[from] StreamUpgradeError),

@@ -39,6 +42,10 @@ impl ReportableError for HandshakeError {
        match self {
            HandshakeError::EarlyData => crate::error::ErrorKind::User,
            HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
            // This error should not happen, but will if we have no default certificate and
            // the client sends no SNI extension.
            // If they provide SNI then we can be sure there is a certificate that matches.
            HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
            HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
                StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
                StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
@@ -139,7 +146,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
                    // try parse endpoint
                    let ep = conn_info
                        .server_name()
                        .and_then(|sni| endpoint_sni(sni, &tls.common_names));
                        .and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
                    if let Some(ep) = ep {
                        ctx.set_endpoint_id(ep);
                    }
@@ -154,8 +161,10 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
                        }
                    }

                    let (_, tls_server_end_point) =
                        tls.cert_resolver.resolve(conn_info.server_name());
                    let (_, tls_server_end_point) = tls
                        .cert_resolver
                        .resolve(conn_info.server_name())
                        .ok_or(HandshakeError::MissingCertificate)?;

                    stream = PqStream {
                        framed: Framed {
@@ -98,7 +98,8 @@ fn generate_tls_config<'a>(
        .with_no_client_auth()
        .with_single_cert(vec![cert.clone()], key.clone_key())?;

    let cert_resolver = CertResolver::new(key, vec![cert])?;
    let mut cert_resolver = CertResolver::new();
    cert_resolver.add_cert(key, vec![cert], true)?;

    let common_names = cert_resolver.get_common_names();
@@ -199,7 +199,8 @@ fn get_conn_info(
    let endpoint = match connection_url.host() {
        Some(url::Host::Domain(hostname)) => {
            if let Some(tls) = tls {
                endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
                endpoint_sni(hostname, &tls.common_names)?
                    .ok_or(ConnInfoError::MalformedEndpoint)?
            } else {
                hostname
                    .split_once('.')
@@ -5,7 +5,6 @@ use anyhow::{Context, bail};
use itertools::Itertools;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::sign::CertifiedKey;
use x509_cert::der::{Reader, SliceReader};

use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
@@ -26,8 +25,10 @@ pub fn configure_tls(
    certs_dir: Option<&String>,
    allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
    let mut cert_resolver = CertResolver::new();

    // add default certificate
    let mut cert_resolver = CertResolver::parse_new(key_path, cert_path)?;
    cert_resolver.add_cert_path(key_path, cert_path, true)?;

    // add extra certificates
    if let Some(certs_dir) = certs_dir {
@@ -39,8 +40,11 @@ pub fn configure_tls(
            let key_path = path.join("tls.key");
            let cert_path = path.join("tls.crt");
            if key_path.exists() && cert_path.exists() {
                cert_resolver
                    .add_cert_path(&key_path.to_string_lossy(), &cert_path.to_string_lossy())?;
                cert_resolver.add_cert_path(
                    &key_path.to_string_lossy(),
                    &cert_path.to_string_lossy(),
                    false,
                )?;
            }
        }
    }
@@ -79,42 +83,92 @@ pub fn configure_tls(
    })
}

#[derive(Debug)]
#[derive(Default, Debug)]
pub struct CertResolver {
    certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
    default: (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint),
    default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}

impl CertResolver {
    fn parse_new(key_path: &str, cert_path: &str) -> anyhow::Result<Self> {
        let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
        Self::new(priv_key, cert_chain)
    pub fn new() -> Self {
        Self::default()
    }

    pub fn new(
        priv_key: PrivateKeyDer<'static>,
        cert_chain: Vec<CertificateDer<'static>>,
    ) -> anyhow::Result<Self> {
        let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
    fn add_cert_path(
        &mut self,
        key_path: &str,
        cert_path: &str,
        is_default: bool,
    ) -> anyhow::Result<()> {
        let priv_key = {
            let key_bytes = std::fs::read(key_path)
                .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
            rustls_pemfile::private_key(&mut &key_bytes[..])
                .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
                .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
        };

        let mut certs = HashMap::new();
        let default = (cert.clone(), tls_server_end_point);
        certs.insert(common_name, (cert, tls_server_end_point));
        Ok(Self { certs, default })
        let cert_chain_bytes = std::fs::read(cert_path)
            .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;

        let cert_chain = {
            rustls_pemfile::certs(&mut &cert_chain_bytes[..])
                .try_collect()
                .with_context(|| {
                    format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
                })?
        };

        self.add_cert(priv_key, cert_chain, is_default)
    }

    fn add_cert_path(&mut self, key_path: &str, cert_path: &str) -> anyhow::Result<()> {
        let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
        self.add_cert(priv_key, cert_chain)
    }

    fn add_cert(
    pub fn add_cert(
        &mut self,
        priv_key: PrivateKeyDer<'static>,
        cert_chain: Vec<CertificateDer<'static>>,
        is_default: bool,
    ) -> anyhow::Result<()> {
        let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
        let key = sign::any_supported_type(&priv_key).context("invalid private key")?;

        let first_cert = &cert_chain[0];
        let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;

        let certificate = SliceReader::new(first_cert)
            .context("Failed to parse cerficiate")?
            .decode::<x509_cert::Certificate>()
            .context("Failed to parse cerficiate")?;

        let common_name = certificate.tbs_certificate.subject.to_string();

        // We need to get the canonical name for this certificate so we can match them against any domain names
        // seen within the proxy codebase.
        //
        // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
        // We need to remove the wildcard prefix for the purposes of certificate selection.
        //
        // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
        // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
        //
        // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
        // validation, so let's we can continue with any common-name
        let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
            s.to_string()
        } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
            s.to_string()
        } else if let Some(s) = common_name.strip_prefix("CN=") {
            s.to_string()
        } else {
            bail!("Failed to parse common name from certificate")
        };

        let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));

        if is_default {
            self.default = Some((cert.clone(), tls_server_end_point));
        }

        self.certs.insert(common_name, (cert, tls_server_end_point));

        Ok(())
    }
@@ -123,82 +177,12 @@ impl CertResolver {
        }
    }

fn parse_key_cert(
    key_path: &str,
    cert_path: &str,
) -> anyhow::Result<(PrivateKeyDer<'static>, Vec<CertificateDer<'static>>)> {
    let priv_key = {
        let key_bytes = std::fs::read(key_path)
            .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
        rustls_pemfile::private_key(&mut &key_bytes[..])
            .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
            .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
    };

    let cert_chain_bytes = std::fs::read(cert_path)
        .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;

    let cert_chain = {
        rustls_pemfile::certs(&mut &cert_chain_bytes[..])
            .try_collect()
            .with_context(|| {
                format!(
                    "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
                )
            })?
    };

    Ok((priv_key, cert_chain))
}

fn process_key_cert(
    priv_key: PrivateKeyDer<'static>,
    cert_chain: Vec<CertificateDer<'static>>,
) -> anyhow::Result<(String, Arc<CertifiedKey>, TlsServerEndPoint)> {
    let key = sign::any_supported_type(&priv_key).context("invalid private key")?;

    let first_cert = &cert_chain[0];
    let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;

    let certificate = SliceReader::new(first_cert)
        .context("Failed to parse cerficiate")?
        .decode::<x509_cert::Certificate>()
        .context("Failed to parse cerficiate")?;

    let common_name = certificate.tbs_certificate.subject.to_string();

    // We need to get the canonical name for this certificate so we can match them against any domain names
    // seen within the proxy codebase.
    //
    // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
    // We need to remove the wildcard prefix for the purposes of certificate selection.
    //
    // auth-broker does not use SNI and instead uses the Neon-Connection-String header.
    // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
    //
    // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
    // validation, so let's we can continue with any common-name
    let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
        s.to_string()
    } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
        s.to_string()
    } else if let Some(s) = common_name.strip_prefix("CN=") {
        s.to_string()
    } else {
        bail!("Failed to parse common name from certificate")
    };

    let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));

    Ok((common_name, cert, tls_server_end_point))
}

impl rustls::server::ResolvesServerCert for CertResolver {
    fn resolve(
        &self,
        client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        Some(self.resolve(client_hello.server_name()).0)
        self.resolve(client_hello.server_name()).map(|x| x.0)
    }
}

@@ -206,7 +190,7 @@ impl CertResolver {
    pub fn resolve(
        &self,
        server_name: Option<&str>,
    ) -> (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint) {
    ) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
        // loop here and cut off more and more subdomains until we find
        // a match to get a proper wildcard support. OTOH, we now do not
        // use nested domains, so keep this simple for now.
@@ -216,17 +200,12 @@ impl CertResolver {
        if let Some(mut sni_name) = server_name {
            loop {
                if let Some(cert) = self.certs.get(sni_name) {
                    return cert.clone();
                    return Some(cert.clone());
                }
                if let Some((_, rest)) = sni_name.split_once('.') {
                    sni_name = rest;
                } else {
                    // The customer has some custom DNS mapping - just return
                    // a default certificate.
                    //
                    // This will error if the customer uses anything stronger
                    // than sslmode=require. That's a choice they can make.
                    return self.default.clone();
                    return None;
                }
            }
        } else {
@@ -202,8 +202,6 @@ def test_pageserver_gc_compaction_preempt(
    env = neon_env_builder.init_start(initial_tenant_conf=conf)

    env.pageserver.allowed_errors.append(".*The timeline or pageserver is shutting down.*")
    env.pageserver.allowed_errors.append(".*flush task cancelled.*")
    env.pageserver.allowed_errors.append(".*failed to pipe.*")

    tenant_id = env.initial_tenant
    timeline_id = env.initial_timeline
11 test_runner/regress/test_flaky.py Normal file
@@ -0,0 +1,11 @@
"""Test for detecting new flaky tests"""

import random


def test_flaky1():
    assert random.random() > 0.05


def no_test_flaky2():
    assert random.random() > 0.05
@@ -11,6 +11,9 @@ if TYPE_CHECKING:
# Test that pageserver and safekeeper can restart quickly.
# This is a regression test, see https://github.com/neondatabase/neon/issues/2247
def test_fixture_restart(neon_env_builder: NeonEnvBuilder):
    import random

    assert random.random() > 0.05
    env = neon_env_builder.init_start()

    for _ in range(3):
@@ -20,3 +23,9 @@ def test_fixture_restart(neon_env_builder: NeonEnvBuilder):
    for _ in range(3):
        env.safekeepers[0].stop()
        env.safekeepers[0].start()


def test_flaky3():
    import random

    assert random.random() > 0.05