Compare commits

..

6 Commits

Author            SHA1        Message                                                                        Date
Alexander Lakhin  eba6f85909  Add flaky tests to test how testing for flaky tests works                     2025-05-14 11:31:45 +03:00
Alexander Lakhin  52888e06a0  Run testing of new tests within "Build and Test" workflow                     2025-05-14 11:31:44 +03:00
Alexander Lakhin  2b51b7cbb1  Transform "Build and Run Selected Test" workflow into "Build and Test Tests"  2025-05-14 11:31:44 +03:00
Alexander Lakhin  5895abf495  Allow for adding optional postfix to allure report name                       2025-05-14 11:31:44 +03:00
Alexander Lakhin  05559539f1  Adjust regress-tests step to pass list of selected tests                      2025-05-14 11:31:43 +03:00
Alexander Lakhin  8f2808864b  Allow for running multiple selected tests                                     2025-05-14 11:31:43 +03:00
18 changed files with 452 additions and 229 deletions

View File

@@ -43,7 +43,7 @@ runs:
BUCKET: neon-github-public-dev
run: |
if [ -n "${PR_NUMBER}" ]; then
BRANCH_OR_PR=pr-${PR_NUMBER}
BRANCH_OR_PR=pr-${PR_NUMBER}${REPORT_EXT-}
elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || \
[ "${GITHUB_REF_NAME}" = "release-proxy" ] || [ "${GITHUB_REF_NAME}" = "release-compute" ]; then
# Shortcut for special branches

View File

@@ -23,7 +23,7 @@ runs:
REPORT_DIR: ${{ inputs.report-dir }}
run: |
if [ -n "${PR_NUMBER}" ]; then
BRANCH_OR_PR=pr-${PR_NUMBER}
BRANCH_OR_PR=pr-${PR_NUMBER}${REPORT_EXT-}
elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || \
[ "${GITHUB_REF_NAME}" = "release-proxy" ] || [ "${GITHUB_REF_NAME}" = "release-compute" ]; then
# Shortcut for special branches
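
Both report actions gain the same optional suffix; a minimal sketch of the resulting prefix for a PR-triggered run (values illustrative):

# Extended-testing run: PR_NUMBER=1234, REPORT_EXT=-ext
BRANCH_OR_PR=pr-1234-ext
# Regular run: REPORT_EXT is unset, so ${REPORT_EXT-} expands to nothing
BRANCH_OR_PR=pr-1234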

View File

@@ -12,6 +12,10 @@ inputs:
description: 'Arbitrary parameters to pytest. For example "-s" to prevent capturing stdout/stderr'
required: false
default: ''
extended_testing:
description: 'Set to true if the test results should be stored and processed separately'
required: false
default: 'false'
needs_postgres_source:
description: 'Set to true if the test suite requires postgres source checked out'
required: false
@@ -135,13 +139,15 @@ runs:
PERF_REPORT_DIR="$(realpath test_runner/perf-report-local)"
echo "PERF_REPORT_DIR=${PERF_REPORT_DIR}" >> ${GITHUB_ENV}
rm -rf $PERF_REPORT_DIR
TEST_SELECTION="test_runner/${{ inputs.test_selection }}"
EXTRA_PARAMS="${{ inputs.extra_params }}"
TEST_SELECTION="${{ inputs.test_selection }}"
if [ -z "$TEST_SELECTION" ]; then
echo "test_selection must be set"
exit 1
fi
if [[ $TEST_SELECTION != test_runner/* ]]; then
TEST_SELECTION="test_runner/$TEST_SELECTION"
fi
EXTRA_PARAMS="${{ inputs.extra_params }}"
if [[ "${{ inputs.run_in_parallel }}" == "true" ]]; then
# -n sets the number of parallel processes that pytest-xdist will run
EXTRA_PARAMS="-n12 $EXTRA_PARAMS"
@@ -244,3 +250,5 @@ runs:
report-dir: /tmp/test_output/allure/results
unique-key: ${{ inputs.build_type }}-${{ inputs.pg_version }}-${{ runner.arch }}
aws-oidc-role-arn: ${{ inputs.aws-oidc-role-arn }}
env:
REPORT_EXT: ${{ inputs.extended_testing == 'true' && '-ext' || '' }}
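
A minimal sketch of how the reworked inputs behave in this action (values illustrative):

# test_selection: "regress"                                                  -> TEST_SELECTION="test_runner/regress"
# test_selection: "test_runner/regress/test_pg_regress.py::test_pg_regress" -> used as-is
# test_selection: ""                                                         -> the step fails with "test_selection must be set"
# extended_testing: "true"                                                   -> REPORT_EXT="-ext", giving the Allure report a separate prefix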

.github/scripts/detect-updated-pytests.py (vendored executable file, 143 lines added)
View File

@@ -0,0 +1,143 @@
import os
import re
import shutil
import subprocess
import sys
commit_sha = os.getenv("COMMIT_SHA")
base_sha = os.getenv("BASE_SHA")
cmd = ["git", "merge-base", base_sha, commit_sha]
print(f"Running: {' '.join(cmd)}...")
result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if result.returncode != 0 or not (baseline := result.stdout.strip()):
print("Baseline commit for PR is not found, detection skipped.")
sys.exit(0)
print(f"Baseline commit: {baseline}")
cmd = ["git", "diff", "--name-only", f"{baseline}..{commit_sha}", "test_runner/regress/"]
print(f"Running: {' '.join(cmd)}...")
result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if result.returncode != 0:
print(f"Git diff returned code {result.returncode}\n{result.stdout}\nDetection skipped.")
sys.exit(0)
def collect_tests(test_file_name):
cmd = ["./scripts/pytest", "--collect-only", "-q", test_file_name]
print(f"Running: {' '.join(cmd)}...")
result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if result.returncode != 0:
print(
f"pytest --collect-only returned code {result.returncode}\n{result.stdout}\nDetection skipped."
)
sys.exit(0)
tests = []
for test_item in result.stdout.split("\n"):
if not test_item.startswith(test_file_name):
break
test_name = re.sub(r"(.*::)([^\[]+)(\[.*)", r"\2", test_item)
if test_name not in tests:
tests.append(test_name)
return tests
all_new_tests = []
all_updated_tests = []
temp_test_file = "test_runner/regress/__temp__.py"
temp_file = None
for test_file in result.stdout.split("\n"):
if not test_file:
continue
print(f"Test file modified: {test_file}.")
# Get and compare two lists of items collected by pytest to detect new tests in the PR
if temp_file:
temp_file.close()
temp_file = open(temp_test_file, "w")
cmd = ["git", "show", f"{baseline}:{test_file}"]
print(f"Running: {' '.join(cmd)}...")
result = subprocess.run(cmd, text=True, stdout=temp_file)
if result.returncode != 0:
tests0 = []
else:
tests0 = collect_tests(temp_test_file)
tests1 = collect_tests(test_file)
new_tests = set(tests1).difference(tests0)
for test_name in new_tests:
all_new_tests.append(f"{test_file}::{test_name}")
# Detect pre-existing test functions updated in the PR
cmd = ["git", "diff", f"{baseline}..{commit_sha}", test_file]
print(f"Running: {' '.join(cmd)}...")
result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if result.returncode != 0:
print(f"Git diff returned code {result.returncode}\n{result.stdout}\nDetection skipped.")
sys.exit(0)
updated_funcs = []
for diff_line in result.stdout.split("\n"):
print(diff_line)
# TODO: detect functions with added/modified parameters
if not diff_line.startswith("@@"):
continue
# Extract names of functions with updated content relying on hunk header
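# Illustrative hunk header as emitted by git diff when the enclosing context is a function:
#   @@ -20,3 +23,9 @@ def test_fixture_restart(neon_env_builder: NeonEnvBuilder):
# the second capture group of the regex below then yields "test_fixture_restart".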
m = re.match(r"^(@@[0-9, +-]+@@ def )([^(]+)(.*)", diff_line)
if not m:
continue
func_name = m.group(2)
print(func_name)
# Ignore functions not collected by pytest
if func_name not in tests1:
continue
if func_name not in updated_funcs:
updated_funcs.append(func_name)
for func_name in updated_funcs:
print(f"Function modified: {func_name}.")
# Extract changes within the function
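# `git log -L :<funcname>:<file>` limits the output to hunks that touch the named function,
# so the collected patch covers only the changes made to this function in the PR range.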
cmd = ["git", "log", f"{baseline}..{commit_sha}", "-L", f":{func_name}:{test_file}"]
print(f"Running: {' '.join(cmd)}...")
result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if result.returncode != 0:
continue
patch_contents = result.stdout
# Revert changes to get the file with only this function updated
# (applying the patch might fail if it contains a change for the next function declaration)
shutil.copy(test_file, temp_test_file)
cmd = ["patch", "-R", "-p1", "--no-backup-if-mismatch", "-r", "/dev/null", temp_test_file]
print(f"Running: {' '.join(cmd)}...")
result = subprocess.run(
cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, input=patch_contents
)
print(f"result: {result.returncode}; {result.stdout}")
if result.returncode != 0:
continue
# Ignore whitespace-only changes
cmd = ["diff", "-w", test_file, temp_test_file]
print(f"Running: {' '.join(cmd)}...")
result = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if result.returncode == 0:
continue
all_updated_tests.append(f"{test_file}::{func_name}")
if temp_file:
temp_file.close()
if os.path.exists(temp_test_file):
os.remove(temp_test_file)
if github_output := os.getenv("GITHUB_OUTPUT"):
with open(github_output, "a") as f:
if all_new_tests or all_updated_tests:
f.write("tests=")
f.write(" ".join(all_new_tests + all_updated_tests))
f.write("\n")
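
A minimal usage sketch for the detection script; the SHAs and the resulting test IDs are illustrative:

COMMIT_SHA=<pr-head-sha> BASE_SHA=<pr-base-sha> python3 .github/scripts/detect-updated-pytests.py
# On success it appends a space-separated list of new and updated tests to $GITHUB_OUTPUT, e.g.:
# tests=test_runner/regress/test_example.py::test_new_case test_runner/regress/test_example.py::test_updated_case
# On any git or pytest failure it prints a diagnostic and exits 0, so detection is skipped instead of failing the job.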

View File

@@ -266,7 +266,7 @@ jobs:
role-duration-seconds: 18000 # 5 hours
- name: Run rust tests
if: ${{ inputs.sanitizers != 'enabled' }}
if: ${{ inputs.sanitizers != 'enabled' && inputs.test-selection == '' }}
env:
NEXTEST_RETRIES: 3
run: |
@@ -386,7 +386,8 @@ jobs:
timeout-minutes: ${{ inputs.sanitizers != 'enabled' && 75 || 180 }}
with:
build_type: ${{ inputs.build-type }}
test_selection: regress
test_selection: ${{ inputs.test-selection != '' && inputs.test-selection || 'regress' }}
extended_testing: ${{ inputs.test-selection != '' && 'true' || 'false' }}
needs_postgres_source: true
run_with_real_s3: true
real_s3_bucket: neon-github-ci-tests
@@ -399,9 +400,7 @@ jobs:
# Attempt to stop tests gracefully to generate test reports
# until they are forcibly stopped by the stricter `timeout-minutes` limit.
extra_params: --session-timeout=${{ inputs.sanitizers != 'enabled' && 3000 || 10200 }} --count=${{ inputs.test-run-count }}
${{ inputs.test-selection != '' && format('-k "{0}"', inputs.test-selection) || '' }}
env:
TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
CHECK_ONDISK_DATA_COMPATIBILITY: nonempty
BUILD_TAG: ${{ inputs.build-tag }}
PAGESERVER_VIRTUAL_FILE_IO_ENGINE: tokio-epoll-uring

View File

@@ -199,6 +199,12 @@ jobs:
build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
secrets: inherit
build-and-test-new-tests:
needs: [ meta, build-build-tools-image ]
if: github.event_name == 'pull_request'
uses: ./.github/workflows/build_and_test_tests.yml
secrets: inherit
build-and-test-locally:
needs: [ meta, build-build-tools-image ]
# We do need to run this in `.*-rc-pr` because of hotfixes.

View File

@@ -1,10 +1,10 @@
name: Build and Run Selected Test
name: Build and Run Selected Tests
on:
workflow_dispatch:
inputs:
test-selection:
description: 'Specification of selected test(s), as accepted by pytest -k'
description: 'Specification of selected test(s), e.g. test_runner/regress/test_pg_regress.py::test_pg_regress'
required: true
type: string
run-count:
@@ -26,6 +26,8 @@ on:
default: '[{"pg_version":"v17"}]'
required: true
type: string
workflow_call:
pull_request: # TODO: remove before merge
defaults:
run:
@@ -42,26 +44,71 @@ jobs:
github-event-name: ${{ github.event_name }}
github-event-json: ${{ toJSON(github.event) }}
build-and-test-locally:
needs: [ meta ]
choose-test-parameters:
runs-on: [ self-hosted, small ]
container:
image: ghcr.io/neondatabase/build-tools:pinned-bookworm
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
options: --init
outputs:
tests: ${{ inputs.test-selection != '' && inputs.test-selection || steps.detect_tests_to_test.outputs.tests }}
archs: ${{ inputs.test-selection != '' && inputs.archs || '["x64", "arm64"]' }}
build-types: ${{ inputs.test-selection != '' && inputs.build-types || '["release"]' }}
pg-versions: ${{ inputs.test-selection != '' && inputs.pg-versions || '[{"pg_version":"v14"}, {"pg_version":"v17"}]' }}
run-count: ${{ inputs.test-selection != '' && inputs.run-count || 5 }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
if: inputs.test-selection == ''
with:
submodules: false
clean: false
fetch-depth: 1000
- name: Cache poetry deps
if: inputs.test-selection == ''
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry/virtualenvs
key: v2-${{ runner.os }}-${{ runner.arch }}-python-deps-bookworm-${{ hashFiles('poetry.lock') }}
- name: Install Python deps
if: inputs.test-selection == ''
shell: bash -euxo pipefail {0}
run: ./scripts/pysync
- name: Detect new and updated tests
id: detect_tests_to_test
if: github.event.pull_request.head.sha && inputs.test-selection == ''
env:
COMMIT_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
BASE_SHA: ${{ github.event.pull_request.base.sha || github.sha }}
run: python3 .github/scripts/detect-updated-pytests.py
build-and-test-tests:
needs: [ meta, choose-test-parameters ]
if: needs.choose-test-parameters.outputs.tests != ''
strategy:
fail-fast: false
matrix:
arch: ${{ fromJson(inputs.archs) }}
build-type: ${{ fromJson(inputs.build-types) }}
arch: ${{ fromJson(needs.choose-test-parameters.outputs.archs) }}
build-type: ${{ fromJson(needs.choose-test-parameters.outputs.build-types) }}
uses: ./.github/workflows/_build-and-test-locally.yml
with:
arch: ${{ matrix.arch }}
build-tools-image: ghcr.io/neondatabase/build-tools:pinned-bookworm
build-tag: ${{ needs.meta.outputs.build-tag }}
build-type: ${{ matrix.build-type }}
test-cfg: ${{ inputs.pg-versions }}
test-selection: ${{ inputs.test-selection }}
test-run-count: ${{ fromJson(inputs.run-count) }}
test-cfg: ${{ needs.choose-test-parameters.outputs.pg-versions }}
test-selection: ${{ needs.choose-test-parameters.outputs.tests }}
test-run-count: ${{ fromJson(needs.choose-test-parameters.outputs.run-count) }}
secrets: inherit
create-test-report:
needs: [ build-and-test-locally ]
needs: [ build-and-test-tests ]
if: ${{ !cancelled() }}
permissions:
id-token: write # aws-actions/configure-aws-credentials
@@ -96,6 +143,7 @@ jobs:
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
REGRESS_TEST_RESULT_CONNSTR_NEW: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_DEV }}
REPORT_EXT: '-ext'
- uses: actions/github-script@v7
if: ${{ !cancelled() }}
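
Since workflow_dispatch is still supported alongside the new workflow_call trigger, the renamed workflow can also be started by hand; a hedged sketch (the gh invocation and values are illustrative):

gh workflow run build_and_test_tests.yml --ref my-branch \
  -f test-selection='test_runner/regress/test_pg_regress.py::test_pg_regress' \
  -f run-count=10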

View File

@@ -168,35 +168,6 @@ pub fn write_postgres_conf(
writeln!(file, "# Managed by compute_ctl: end")?;
}
// Always add pgaudit to shared_preload_libraries.
//
// This is needed to handle the downgrade scenario.
// pgaudit extension creates event triggers that require library to be loaded.
// so, once extension was installed it must always be present in shared_preload_libraries.
let mut extra_shared_preload_libraries = String::new();
let libs = {
// We don't distribute pgaudit in the testing image,
// and don't pass shared_preload_libraries via spec,
// so disable this logic there.
#[cfg(feature = "testing")]
{
String::new()
}
#[cfg(not(feature = "testing"))]
{
spec.cluster
.settings
.find("shared_preload_libraries")
.expect("shared_preload_libraries setting is missing in the spec")
}
};
#[cfg(not(feature = "testing"))]
if !libs.contains("pgaudit") {
extra_shared_preload_libraries.push_str(",pgaudit");
};
// If base audit logging is enabled, configure it.
// In this setup, the audit log will be written to the standard postgresql log.
//
@@ -206,22 +177,29 @@ pub fn write_postgres_conf(
// This way we always override the settings from the spec
// and don't allow the user or the control plane admin to change them.
match spec.audit_log_level {
ComputeAudit::Disabled => {
// this is the default, but let's be explicit
writeln!(file, "pgaudit.log='none'")?;
}
ComputeAudit::Disabled => {}
ComputeAudit::Log | ComputeAudit::Base => {
writeln!(file, "# Managed by compute_ctl base audit settings: start")?;
writeln!(file, "pgaudit.log='ddl,role'")?;
// Disable logging of catalog queries to reduce the noise
writeln!(file, "pgaudit.log_catalog=off")?;
writeln!(
file,
"shared_preload_libraries='{}{}'",
libs, extra_shared_preload_libraries
)?;
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
let mut extra_shared_preload_libraries = String::new();
if !libs.contains("pgaudit") {
extra_shared_preload_libraries.push_str(",pgaudit");
}
writeln!(
file,
"shared_preload_libraries='{}{}'",
libs, extra_shared_preload_libraries
)?;
} else {
// Typically, this should be unreachable,
// because we always set at least some shared_preload_libraries in the spec,
// but let's handle it explicitly anyway.
writeln!(file, "shared_preload_libraries='neon,pgaudit'")?;
}
writeln!(file, "# Managed by compute_ctl base audit settings: end")?;
}
ComputeAudit::Hipaa | ComputeAudit::Extended | ComputeAudit::Full => {
@@ -250,15 +228,28 @@ pub fn write_postgres_conf(
// The caller who sets the flag is responsible for ensuring that the necessary
// shared_preload_libraries are present in the compute image,
// otherwise the compute start will fail.
if !libs.contains("pgauditlogtofile") {
extra_shared_preload_libraries.push_str(",pgauditlogtofile");
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
let mut extra_shared_preload_libraries = String::new();
if !libs.contains("pgaudit") {
extra_shared_preload_libraries.push_str(",pgaudit");
}
if !libs.contains("pgauditlogtofile") {
extra_shared_preload_libraries.push_str(",pgauditlogtofile");
}
writeln!(
file,
"shared_preload_libraries='{}{}'",
libs, extra_shared_preload_libraries
)?;
} else {
// Typically, this should be unreachable,
// because we always set at least some shared_preload_libraries in the spec,
// but let's handle it explicitly anyway.
writeln!(
file,
"shared_preload_libraries='neon,pgaudit,pgauditlogtofile'"
)?;
}
writeln!(
file,
"shared_preload_libraries='{}{}'",
libs, extra_shared_preload_libraries
)?;
writeln!(
file,
"# Managed by compute_ctl compliance audit settings: end"

View File

@@ -14,6 +14,8 @@
use std::fs::File;
use std::io::{Error, ErrorKind};
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
#[cfg(target_os = "linux")]
use std::os::unix::fs::OpenOptionsExt;
use std::sync::LazyLock;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering};
@@ -97,7 +99,7 @@ impl VirtualFile {
pub async fn open_with_options_v2<P: AsRef<Utf8Path>>(
path: P,
#[cfg_attr(not(target_os = "linux"), allow(unused_mut))] mut open_options: OpenOptions,
open_options: &OpenOptions,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let mode = get_io_mode();
@@ -110,16 +112,21 @@ impl VirtualFile {
#[cfg(target_os = "linux")]
(IoMode::DirectRw, _) => true,
};
if set_o_direct {
let open_options = open_options.clone();
let open_options = if set_o_direct {
#[cfg(target_os = "linux")]
{
open_options = open_options.custom_flags(nix::libc::O_DIRECT);
let mut open_options = open_options;
open_options.custom_flags(nix::libc::O_DIRECT);
open_options
}
#[cfg(not(target_os = "linux"))]
unreachable!(
"O_DIRECT is not supported on this platform, IoMode's that result in set_o_direct=true shouldn't even be defined"
);
}
} else {
open_options
};
let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
Ok(VirtualFile { inner, _mode: mode })
}
@@ -523,7 +530,7 @@ impl VirtualFileInner {
path: P,
ctx: &RequestContext,
) -> Result<VirtualFileInner, std::io::Error> {
Self::open_with_options(path.as_ref(), OpenOptions::new().read(true), ctx).await
Self::open_with_options(path.as_ref(), OpenOptions::new().read(true).clone(), ctx).await
}
/// Open a file with given options.
@@ -551,11 +558,10 @@ impl VirtualFileInner {
// It would perhaps be nicer to check just for the read and write flags
// explicitly, but OpenOptions doesn't contain any functions to read flags,
// only to set them.
let reopen_options = open_options
.clone()
.create(false)
.create_new(false)
.truncate(false);
let mut reopen_options = open_options.clone();
reopen_options.create(false);
reopen_options.create_new(false);
reopen_options.truncate(false);
let vfile = VirtualFileInner {
handle: RwLock::new(handle),
@@ -1301,7 +1307,7 @@ mod tests {
opts: OpenOptions,
ctx: &RequestContext,
) -> Result<MaybeVirtualFile, anyhow::Error> {
let vf = VirtualFile::open_with_options_v2(&path, opts, ctx).await?;
let vf = VirtualFile::open_with_options_v2(&path, &opts, ctx).await?;
Ok(MaybeVirtualFile::VirtualFile(vf))
}
}
@@ -1368,7 +1374,7 @@ mod tests {
let _ = file_a.read_string_at(0, 1, &ctx).await.unwrap_err();
// Close the file and re-open for reading
let mut file_a = A::open(path_a, OpenOptions::new().read(true), &ctx).await?;
let mut file_a = A::open(path_a, OpenOptions::new().read(true).to_owned(), &ctx).await?;
// cannot write to a file opened in read-only mode
let _ = file_a
@@ -1387,7 +1393,8 @@ mod tests {
.read(true)
.write(true)
.create(true)
.truncate(true),
.truncate(true)
.to_owned(),
&ctx,
)
.await?;
@@ -1405,7 +1412,12 @@ mod tests {
let mut vfiles = Vec::new();
for _ in 0..100 {
let mut vfile = A::open(path_b.clone(), OpenOptions::new().read(true), &ctx).await?;
let mut vfile = A::open(
path_b.clone(),
OpenOptions::new().read(true).to_owned(),
&ctx,
)
.await?;
assert_eq!("FOOBAR", vfile.read_string_at(0, 6, &ctx).await?);
vfiles.push(vfile);
}
@@ -1454,7 +1466,7 @@ mod tests {
for _ in 0..VIRTUAL_FILES {
let f = VirtualFileInner::open_with_options(
&test_file_path,
OpenOptions::new().read(true),
OpenOptions::new().read(true).clone(),
&ctx,
)
.await?;
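
A minimal sketch of the call pattern after the OpenOptions builder methods moved to &mut self (path and flags illustrative):

// Options are built on a mutable value and passed by reference:
let mut opts = OpenOptions::new();
opts.read(true).write(true).create(true);
let vfile = VirtualFile::open_with_options_v2(&path, &opts, ctx).await?;
// open_with_options_v2 clones the options internally before optionally adding O_DIRECT.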

View File

@@ -1,7 +1,6 @@
//! Enum-dispatch to the `OpenOptions` type of the respective [`super::IoEngineKind`];
use std::os::fd::OwnedFd;
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;
use super::io_engine::IoEngine;
@@ -44,7 +43,7 @@ impl OpenOptions {
self.write
}
pub fn read(mut self, read: bool) -> Self {
pub fn read(&mut self, read: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.read(read);
@@ -57,7 +56,7 @@ impl OpenOptions {
self
}
pub fn write(mut self, write: bool) -> Self {
pub fn write(&mut self, write: bool) -> &mut OpenOptions {
self.write = write;
match &mut self.inner {
Inner::StdFs(x) => {
@@ -71,7 +70,7 @@ impl OpenOptions {
self
}
pub fn create(mut self, create: bool) -> Self {
pub fn create(&mut self, create: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.create(create);
@@ -84,7 +83,7 @@ impl OpenOptions {
self
}
pub fn create_new(mut self, create_new: bool) -> Self {
pub fn create_new(&mut self, create_new: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.create_new(create_new);
@@ -97,7 +96,7 @@ impl OpenOptions {
self
}
pub fn truncate(mut self, truncate: bool) -> Self {
pub fn truncate(&mut self, truncate: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.truncate(truncate);
@@ -125,8 +124,10 @@ impl OpenOptions {
}
}
}
}
pub fn mode(mut self, mode: u32) -> Self {
impl std::os::unix::prelude::OpenOptionsExt for OpenOptions {
fn mode(&mut self, mode: u32) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.mode(mode);
@@ -139,7 +140,7 @@ impl OpenOptions {
self
}
pub fn custom_flags(mut self, flags: i32) -> Self {
fn custom_flags(&mut self, flags: i32) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.custom_flags(flags);

View File

@@ -32,6 +32,12 @@ pub(crate) enum ComputeUserInfoParseError {
option: EndpointId,
},
#[error(
"Common name inferred from SNI ('{}') is not known",
.cn,
)]
UnknownCommonName { cn: String },
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
MalformedProjectName(EndpointId),
}
@@ -60,15 +66,22 @@ impl ComputeUserInfoMaybeEndpoint {
}
}
pub(crate) fn endpoint_sni(sni: &str, common_names: &HashSet<String>) -> Option<EndpointId> {
let (subdomain, common_name) = sni.split_once('.')?;
pub(crate) fn endpoint_sni(
sni: &str,
common_names: &HashSet<String>,
) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
let Some((subdomain, common_name)) = sni.split_once('.') else {
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
};
if !common_names.contains(common_name) {
return None;
return Err(ComputeUserInfoParseError::UnknownCommonName {
cn: common_name.into(),
});
}
if subdomain == SERVERLESS_DRIVER_SNI {
return None;
return Ok(None);
}
Some(EndpointId::from(subdomain))
Ok(Some(EndpointId::from(subdomain)))
}
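// Illustrative contract of the reworked endpoint_sni (endpoint and domain names are made up):
//   endpoint_sni("ep-steep-frost-123.example.com", {"example.com"}) -> Ok(Some("ep-steep-frost-123"))
//   an SNI whose subdomain equals SERVERLESS_DRIVER_SNI             -> Ok(None)
//   endpoint_sni("project.localhost", {"example.com"})              -> Err(UnknownCommonName { cn: "localhost" })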
impl ComputeUserInfoMaybeEndpoint {
@@ -100,8 +113,15 @@ impl ComputeUserInfoMaybeEndpoint {
})
.map(|name| name.into());
let endpoint_from_domain =
sni.and_then(|sni_str| common_names.and_then(|cn| endpoint_sni(sni_str, cn)));
let endpoint_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
endpoint_sni(sni_str, cn)?
} else {
None
}
} else {
None
};
let endpoint = match (endpoint_option, endpoint_from_domain) {
// Invariant: if we have both project name variants, they should match.
@@ -404,34 +424,21 @@ mod tests {
}
#[test]
fn parse_unknown_sni() {
fn parse_inconsistent_sni() {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let ctx = RequestContext::test();
let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.unwrap();
assert!(info.endpoint_id.is_none());
}
#[test]
fn parse_unknown_sni_with_options() {
let options = StartupMessageParams::new([
("user", "john_doe"),
("options", "endpoint=foo-bar-baz-1234"),
]);
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let ctx = RequestContext::test();
let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.unwrap();
assert_eq!(info.endpoint_id.as_deref(), Some("foo-bar-baz-1234"));
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
}
_ => panic!("bad error: {err:?}"),
}
}
#[test]

View File

@@ -24,6 +24,9 @@ pub(crate) enum HandshakeError {
#[error("protocol violation")]
ProtocolViolation,
#[error("missing certificate")]
MissingCertificate,
#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),
@@ -39,6 +42,10 @@ impl ReportableError for HandshakeError {
match self {
HandshakeError::EarlyData => crate::error::ErrorKind::User,
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
// This error should not happen, but will if we have no default certificate and
// the client sends no SNI extension.
// If they provide SNI then we can be sure there is a certificate that matches.
HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
@@ -139,7 +146,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// try parse endpoint
let ep = conn_info
.server_name()
.and_then(|sni| endpoint_sni(sni, &tls.common_names));
.and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
if let Some(ep) = ep {
ctx.set_endpoint_id(ep);
}
@@ -154,8 +161,10 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
}
let (_, tls_server_end_point) =
tls.cert_resolver.resolve(conn_info.server_name());
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(conn_info.server_name())
.ok_or(HandshakeError::MissingCertificate)?;
stream = PqStream {
framed: Framed {

View File

@@ -98,7 +98,8 @@ fn generate_tls_config<'a>(
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key.clone_key())?;
let cert_resolver = CertResolver::new(key, vec![cert])?;
let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert(key, vec![cert], true)?;
let common_names = cert_resolver.get_common_names();

View File

@@ -199,7 +199,8 @@ fn get_conn_info(
let endpoint = match connection_url.host() {
Some(url::Host::Domain(hostname)) => {
if let Some(tls) = tls {
endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
endpoint_sni(hostname, &tls.common_names)?
.ok_or(ConnInfoError::MalformedEndpoint)?
} else {
hostname
.split_once('.')

View File

@@ -5,7 +5,6 @@ use anyhow::{Context, bail};
use itertools::Itertools;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::sign::CertifiedKey;
use x509_cert::der::{Reader, SliceReader};
use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
@@ -26,8 +25,10 @@ pub fn configure_tls(
certs_dir: Option<&String>,
allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
let mut cert_resolver = CertResolver::new();
// add default certificate
let mut cert_resolver = CertResolver::parse_new(key_path, cert_path)?;
cert_resolver.add_cert_path(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
@@ -39,8 +40,11 @@ pub fn configure_tls(
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver
.add_cert_path(&key_path.to_string_lossy(), &cert_path.to_string_lossy())?;
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
)?;
}
}
}
@@ -79,42 +83,92 @@ pub fn configure_tls(
})
}
#[derive(Debug)]
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint),
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
fn parse_new(key_path: &str, cert_path: &str) -> anyhow::Result<Self> {
let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
Self::new(priv_key, cert_chain)
pub fn new() -> Self {
Self::default()
}
pub fn new(
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
) -> anyhow::Result<Self> {
let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
is_default: bool,
) -> anyhow::Result<()> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};
let mut certs = HashMap::new();
let default = (cert.clone(), tls_server_end_point);
certs.insert(common_name, (cert, tls_server_end_point));
Ok(Self { certs, default })
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};
self.add_cert(priv_key, cert_chain, is_default)
}
fn add_cert_path(&mut self, key_path: &str, cert_path: &str) -> anyhow::Result<()> {
let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
self.add_cert(priv_key, cert_chain)
}
fn add_cert(
pub fn add_cert(
&mut self,
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
is_default: bool,
) -> anyhow::Result<()> {
let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let certificate = SliceReader::new(first_cert)
.context("Failed to parse cerficiate")?
.decode::<x509_cert::Certificate>()
.context("Failed to parse cerficiate")?;
let common_name = certificate.tbs_certificate.subject.to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so we can continue with any common name.
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some((cert.clone(), tls_server_end_point));
}
self.certs.insert(common_name, (cert, tls_server_end_point));
Ok(())
}
@@ -123,82 +177,12 @@ impl CertResolver {
}
}
fn parse_key_cert(
key_path: &str,
cert_path: &str,
) -> anyhow::Result<(PrivateKeyDer<'static>, Vec<CertificateDer<'static>>)> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!(
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
)
})?
};
Ok((priv_key, cert_chain))
}
fn process_key_cert(
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
) -> anyhow::Result<(String, Arc<CertifiedKey>, TlsServerEndPoint)> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let certificate = SliceReader::new(first_cert)
.context("Failed to parse cerficiate")?
.decode::<x509_cert::Certificate>()
.context("Failed to parse cerficiate")?;
let common_name = certificate.tbs_certificate.subject.to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
Ok((common_name, cert, tls_server_end_point))
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
Some(self.resolve(client_hello.server_name()).0)
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
@@ -206,7 +190,7 @@ impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint) {
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
@@ -216,17 +200,12 @@ impl CertResolver {
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return cert.clone();
return Some(cert.clone());
}
if let Some((_, rest)) = sni_name.split_once('.') {
sni_name = rest;
} else {
// The customer has some custom DNS mapping - just return
// a default certificate.
//
// This will error if the customer uses anything stronger
// than sslmode=require. That's a choice they can make.
return self.default.clone();
return None;
}
}
} else {
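
A minimal sketch of the refactored CertResolver lifecycle (paths and the SNI value are illustrative):

let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert_path("server.key", "server.crt", /* is_default */ true)?;
cert_resolver.add_cert_path("extra.key", "extra.crt", /* is_default */ false)?;
// resolve() now returns Option instead of silently falling back to the default certificate;
// the handshake path maps a None result to HandshakeError::MissingCertificate.
let (cert, tls_server_end_point) = cert_resolver
    .resolve(Some("ep-steep-frost-123.example.com"))
    .expect("no certificate matches this SNI");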

View File

@@ -202,8 +202,6 @@ def test_pageserver_gc_compaction_preempt(
env = neon_env_builder.init_start(initial_tenant_conf=conf)
env.pageserver.allowed_errors.append(".*The timeline or pageserver is shutting down.*")
env.pageserver.allowed_errors.append(".*flush task cancelled.*")
env.pageserver.allowed_errors.append(".*failed to pipe.*")
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline

View File

@@ -0,0 +1,11 @@
"""Test for detecting new flaky tests"""
import random
def test_flaky1():
assert random.random() > 0.05
def no_test_flaky2():
assert random.random() > 0.05
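# Note: pytest only collects functions whose names start with "test", so no_test_flaky2 above is
# presumably there to check that the detection script ignores functions pytest does not collect.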

View File

@@ -11,6 +11,9 @@ if TYPE_CHECKING:
# Test that pageserver and safekeeper can restart quickly.
# This is a regression test, see https://github.com/neondatabase/neon/issues/2247
def test_fixture_restart(neon_env_builder: NeonEnvBuilder):
import random
assert random.random() > 0.05
env = neon_env_builder.init_start()
for _ in range(3):
@@ -20,3 +23,9 @@ def test_fixture_restart(neon_env_builder: NeonEnvBuilder):
for _ in range(3):
env.safekeepers[0].stop()
env.safekeepers[0].start()
def test_flaky3():
import random
assert random.random() > 0.05